UNET SEGMENTATION

  • UNet is a fully convolutional network(FCN) that does image segmentation. Its goal is to predict each pixel's class.
  • UNet is built upon the FCN and modified in a way that it yields better segmentation in medical imaging.

1.1 Architecture

UNet Architecture has 3 parts:

  1. The Contracting/Downsampling Path
  2. Bottleneck
  3. The Expanding/Upsampling Path

Downsampling Path:

  1. It consists of two 3x3 convolutions (unpadded convolutions), each followed by a rectified linear unit (ReLU) and a 2x2 max pooling operation with stride 2 for downsampling.
  2. At each downsampling step we double the number of feature channels.

Upsampling Path:

  1. Every step in the expansive path consists of an upsampling of the feature map followed by a 2x2 convolution (“up-convolution”), a concatenation with the correspondingly feature map from the downsampling path, and two 3x3 convolutions, each followed by a ReLU.

Skip Connection:

The skip connection from the downsampling path are concatenated with feature map during upsampling path. These skip connection provide local information to global information while upsampling.

Final Layer:

At the final layer a 1x1 convolution is used to map each feature vector to the desired number of classes.

1.2 Advantages

Advantages:

  1. The UNet combines the location information from the downsampling path to finally obtain a general information combining localisation and context, which is necessary to predict a good segmentation map.
  2. No Dense layer is used, so image sizes can be used.

1.3 Dataset

Link: Data Science Bowl 2018 Find the nuclei in divergent images to advance medical discovery

1.4 Code

In [1]:
## Imports
import os
import sys
import random

import numpy as np
import cv2
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from keras import backend as K

## Seeding 
seed = 2019
random.seed = seed
np.random.seed = seed
tf.seed = seed
Using TensorFlow backend.

Data Generator

In [2]:
from os import listdir
from os.path import isfile,join
class GetData():
	def __init__(self, data_dir):	
		images_list =[]		
		labels_list = []		
		self.source_list = []
		label_dir = os.path.join(data_dir, "Labels")
		image_dir = os.path.join(data_dir, "Images")
		self.image_size = 128
		examples = 0
		print("loading images")
		onlyImagefiles = [f for f in listdir(image_dir) if isfile(join(image_dir, f))]
		onlyLabelfiles = [f for f in listdir(label_dir) if isfile(join(label_dir, f))]
		onlyImagefiles.sort()
		onlyLabelfiles.sort()

		for i in range (len(onlyImagefiles)):
			image = cv2.imread(os.path.join(image_dir,onlyImagefiles[i]))
			#im = Image.open(os.path.join(label_dir,onlyLabelfiles[i]),cv2.IMREAD_GRAYSCALE)
			#label = np.array(im)
			label = cv2.imread(os.path.join(label_dir,onlyLabelfiles[i]),cv2.IMREAD_GRAYSCALE)
			#image= cv2.resize(image, (self.image_size, self.image_size))
			#label= cv2.resize(label, (self.image_size, self.image_size))
			image = image[96:224,96:224]
			label = label[96:224,96:224]
			#cv2.imwrite("Pre_"+str(i)+".jpg",label)
			#image = image[...,0][...,None]/255
			label = label>40
			image = image/255
			#image = image[...,None]
			label = label[...,None]
			label = label.astype(np.int32)
			#label = label*255
			#cv2.imwrite("Post_"+str(i)+".jpg",label)
			images_list.append(image)
			labels_list.append(label)
			examples = examples +1
							
		print("finished loading images")
		self.examples = examples
		print("Number of examples found: ", examples)
		self.images = np.array(images_list)
		self.labels = np.array(labels_list)
	def next_batch(self, batch_size):
	
		if len(self.source_list) < batch_size:
			new_source = list(range(self.examples))
			random.shuffle(new_source)
			self.source_list.extend(new_source)

		examples_idx = self.source_list[:batch_size]
		del self.source_list[:batch_size]

		return self.images[examples_idx,...], self.labels[examples_idx,...]
In [3]:
# Base Directory Directory 
base_dir= 'Data'

# Training and Test Directories 
train_dir = os.path.join(base_dir,'Train')
test_dir = os.path.join(base_dir,'Test')
real_dir = os.path.join(base_dir,'Real')

BATCH_SIZE = 1
BUFFER_SIZE = 1000
image_size = 128
EPOCHS = 20
def PreProcessImages():
	train_data = GetData(train_dir)
	test_data = GetData(test_dir)
	real_data = GetData(real_dir)

	return train_data,  test_data, real_data

Hyperparameters

In [4]:
image_size = 128
epochs = 50
batch_size = 8
In [5]:
train_data,  test_data, real_data= PreProcessImages()
fig = plt.figure()
fig.subplots_adjust(hspace=0.4, wspace=0.4)
ax = fig.add_subplot(1, 2, 1)
ax.imshow(train_data.images[3,:,:,:])
ax = fig.add_subplot(1, 2, 2)
ax.imshow(np.reshape(train_data.labels[3,:,:,:], (image_size, image_size)), cmap="gray")
loading images
finished loading images
Number of examples found:  582
loading images
finished loading images
Number of examples found:  20
loading images
finished loading images
Number of examples found:  20
Out[5]:
<matplotlib.image.AxesImage at 0x7f8260cf9450>
In [6]:
def ImportImages(train_data,  test_data):
	train_dataset=tf.data.Dataset.from_tensor_slices((train_data.images, train_data.labels)).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
	test_dataset = tf.data.Dataset.from_tensor_slices((test_data.images, test_data.labels)).batch(BATCH_SIZE)
	
	return train_dataset, test_dataset

Different Convolutional Blocks

In [7]:
def down_block(x, filters, kernel_size=(3, 3), padding="same", strides=1):
    c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(x)
    c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(c)
    p = keras.layers.MaxPool2D((2, 2), (2, 2))(c)
    return c, p

def up_block(x, skip, filters, kernel_size=(3, 3), padding="same", strides=1):
    us = keras.layers.UpSampling2D((2, 2))(x)
    concat = keras.layers.Concatenate()([us, skip])
    c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(concat)
    c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(c)
    return c

def bottleneck(x, filters, kernel_size=(3, 3), padding="same", strides=1):
    c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(x)
    c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(c)
    return c

UNet Model

In [8]:
def UNet():
    f = [16, 32, 64, 128, 256]
    inputs = keras.layers.Input((image_size, image_size, 3))
    
    p0 = inputs
    c1, p1 = down_block(p0, f[0]) #128 -> 64
    c2, p2 = down_block(p1, f[1]) #64 -> 32
    c3, p3 = down_block(p2, f[2]) #32 -> 16
    c4, p4 = down_block(p3, f[3]) #16->8
    
    bn = bottleneck(p4, f[4])
    
    u1 = up_block(bn, c4, f[3]) #8 -> 16
    u2 = up_block(u1, c3, f[2]) #16 -> 32
    u3 = up_block(u2, c2, f[1]) #32 -> 64
    u4 = up_block(u3, c1, f[0]) #64 -> 128
    
    outputs = keras.layers.Conv2D(1, (1, 1), padding="same", activation="sigmoid")(u4)
    model = keras.models.Model(inputs, outputs)
    return model
In [9]:
def f1_metric(y_true, y_pred):
    y_true = y_true >0.4
    y_pred = y_pred>0.4
    y_true = tf.dtypes.cast(y_true,tf.float32)
    y_pred = tf.dtypes.cast(y_pred,tf.float32)

    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    precision = true_positives / (predicted_positives + K.epsilon())
    recall = true_positives / (possible_positives + K.epsilon())
    f1_val = 2*(precision*recall)/(precision+recall+K.epsilon())
    return f1_val
In [10]:
def dice_coef(y_true, y_pred, smooth=1):
	y_true = y_true >0.4
	y_pred = y_pred>0.4
	y_true = tf.dtypes.cast(y_true,tf.float32)
	y_pred = tf.dtypes.cast(y_pred,tf.float32)

	intersection = K.sum(K.abs(y_true * y_pred), axis=-1)
	return(2. * intersection + smooth) / (K.sum(K.square(y_true),-1) + K.sum(K.square(y_pred),-1) + smooth)
In [11]:
model = UNet()
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=[f1_metric,dice_coef])
model.summary()
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 128, 128, 3) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 128, 128, 16) 448         input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 128, 128, 16) 2320        conv2d[0][0]                     
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 64, 64, 16)   0           conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 64, 64, 32)   4640        max_pooling2d[0][0]              
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 64, 64, 32)   9248        conv2d_2[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 32, 32, 32)   0           conv2d_3[0][0]                   
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 32, 32, 64)   18496       max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 32, 32, 64)   36928       conv2d_4[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 16, 16, 64)   0           conv2d_5[0][0]                   
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 16, 16, 128)  73856       max_pooling2d_2[0][0]            
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 16, 16, 128)  147584      conv2d_6[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 8, 8, 128)    0           conv2d_7[0][0]                   
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 8, 8, 256)    295168      max_pooling2d_3[0][0]            
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 8, 8, 256)    590080      conv2d_8[0][0]                   
__________________________________________________________________________________________________
up_sampling2d (UpSampling2D)    (None, 16, 16, 256)  0           conv2d_9[0][0]                   
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 16, 16, 384)  0           up_sampling2d[0][0]              
                                                                 conv2d_7[0][0]                   
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 16, 16, 128)  442496      concatenate[0][0]                
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 16, 16, 128)  147584      conv2d_10[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 32, 32, 128)  0           conv2d_11[0][0]                  
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 32, 32, 192)  0           up_sampling2d_1[0][0]            
                                                                 conv2d_5[0][0]                   
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 32, 32, 64)   110656      concatenate_1[0][0]              
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 32, 32, 64)   36928       conv2d_12[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D)  (None, 64, 64, 64)   0           conv2d_13[0][0]                  
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 64, 64, 96)   0           up_sampling2d_2[0][0]            
                                                                 conv2d_3[0][0]                   
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 64, 64, 32)   27680       concatenate_2[0][0]              
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 64, 64, 32)   9248        conv2d_14[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_3 (UpSampling2D)  (None, 128, 128, 32) 0           conv2d_15[0][0]                  
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 128, 128, 48) 0           up_sampling2d_3[0][0]            
                                                                 conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 128, 128, 16) 6928        concatenate_3[0][0]              
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 128, 128, 16) 2320        conv2d_16[0][0]                  
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 128, 128, 1)  17          conv2d_17[0][0]                  
==================================================================================================
Total params: 1,962,625
Trainable params: 1,962,625
Non-trainable params: 0
__________________________________________________________________________________________________

Training the model

In [12]:
train_data,  test_data, real_data= PreProcessImages()
train_dataset, test_dataset= ImportImages(train_data,  test_data)

train_steps = len(train_data.labels)//batch_size
valid_steps = len(test_data.labels)//batch_size

#model.fit_generator(train_data, validation_data=test_data, steps_per_epoch=train_steps, validation_steps=valid_steps, 
 #                   epochs=epochs)
    
model_history = model.fit(train_data.images,train_data.labels,validation_split=0.3, epochs=epochs)
loading images
finished loading images
Number of examples found:  582
loading images
finished loading images
Number of examples found:  20
loading images
finished loading images
Number of examples found:  20
Train on 407 samples, validate on 175 samples
Epoch 1/50
407/407 [==============================] - 35s 85ms/sample - loss: 0.7600 - f1_metric: 0.2114 - dice_coef: 0.6658 - val_loss: 0.5871 - val_f1_metric: 0.1917 - val_dice_coef: 0.6995
Epoch 2/50
407/407 [==============================] - 33s 82ms/sample - loss: 0.5348 - f1_metric: 0.0172 - dice_coef: 0.8799 - val_loss: 0.4523 - val_f1_metric: 4.9930e-04 - val_dice_coef: 0.9158
Epoch 3/50
407/407 [==============================] - 34s 83ms/sample - loss: 0.4400 - f1_metric: 5.1791e-04 - dice_coef: 0.9111 - val_loss: 0.3476 - val_f1_metric: 0.0000e+00 - val_dice_coef: 0.9239
Epoch 4/50
407/407 [==============================] - 39s 96ms/sample - loss: 0.3573 - f1_metric: 0.0105 - dice_coef: 0.9152 - val_loss: 0.3058 - val_f1_metric: 0.1061 - val_dice_coef: 0.9269
Epoch 5/50
407/407 [==============================] - 35s 87ms/sample - loss: 0.3099 - f1_metric: 0.3913 - dice_coef: 0.9335 - val_loss: 0.3033 - val_f1_metric: 0.5119 - val_dice_coef: 0.9408
Epoch 6/50
407/407 [==============================] - 39s 95ms/sample - loss: 0.3032 - f1_metric: 0.6039 - dice_coef: 0.9317 - val_loss: 0.2725 - val_f1_metric: 0.5793 - val_dice_coef: 0.9472
Epoch 7/50
407/407 [==============================] - 37s 90ms/sample - loss: 0.2763 - f1_metric: 0.6402 - dice_coef: 0.9415 - val_loss: 0.2845 - val_f1_metric: 0.5675 - val_dice_coef: 0.9157
Epoch 8/50
407/407 [==============================] - 34s 83ms/sample - loss: 0.2765 - f1_metric: 0.6290 - dice_coef: 0.9335 - val_loss: 0.2500 - val_f1_metric: 0.6328 - val_dice_coef: 0.9511
Epoch 9/50
407/407 [==============================] - 37s 90ms/sample - loss: 0.2503 - f1_metric: 0.6760 - dice_coef: 0.9466 - val_loss: 0.2596 - val_f1_metric: 0.6682 - val_dice_coef: 0.9423
Epoch 10/50
407/407 [==============================] - 37s 90ms/sample - loss: 0.2285 - f1_metric: 0.6982 - dice_coef: 0.9502 - val_loss: 0.1828 - val_f1_metric: 0.7371 - val_dice_coef: 0.9601
Epoch 11/50
407/407 [==============================] - 36s 89ms/sample - loss: 0.1817 - f1_metric: 0.7699 - dice_coef: 0.9617 - val_loss: 0.1871 - val_f1_metric: 0.7333 - val_dice_coef: 0.9629
Epoch 12/50
407/407 [==============================] - 38s 92ms/sample - loss: 0.1836 - f1_metric: 0.7622 - dice_coef: 0.9617 - val_loss: 0.2170 - val_f1_metric: 0.6772 - val_dice_coef: 0.9572
Epoch 13/50
407/407 [==============================] - 36s 88ms/sample - loss: 0.1589 - f1_metric: 0.8008 - dice_coef: 0.9668 - val_loss: 0.1993 - val_f1_metric: 0.7193 - val_dice_coef: 0.9619
Epoch 14/50
407/407 [==============================] - 36s 88ms/sample - loss: 0.1519 - f1_metric: 0.8027 - dice_coef: 0.9678 - val_loss: 0.1714 - val_f1_metric: 0.7600 - val_dice_coef: 0.9652
Epoch 15/50
407/407 [==============================] - 36s 87ms/sample - loss: 0.1335 - f1_metric: 0.8333 - dice_coef: 0.9720 - val_loss: 0.1928 - val_f1_metric: 0.7503 - val_dice_coef: 0.9651
Epoch 16/50
407/407 [==============================] - 36s 87ms/sample - loss: 0.1266 - f1_metric: 0.8382 - dice_coef: 0.9731 - val_loss: 0.2394 - val_f1_metric: 0.7186 - val_dice_coef: 0.9628
Epoch 17/50
407/407 [==============================] - 38s 93ms/sample - loss: 0.1204 - f1_metric: 0.8482 - dice_coef: 0.9748 - val_loss: 0.1792 - val_f1_metric: 0.7581 - val_dice_coef: 0.9667
Epoch 18/50
407/407 [==============================] - 37s 91ms/sample - loss: 0.1119 - f1_metric: 0.8579 - dice_coef: 0.9762 - val_loss: 0.2039 - val_f1_metric: 0.7471 - val_dice_coef: 0.9664
Epoch 19/50
407/407 [==============================] - 37s 91ms/sample - loss: 0.1227 - f1_metric: 0.8435 - dice_coef: 0.9741 - val_loss: 0.1933 - val_f1_metric: 0.7515 - val_dice_coef: 0.9651
Epoch 20/50
407/407 [==============================] - 36s 89ms/sample - loss: 0.1090 - f1_metric: 0.8630 - dice_coef: 0.9766 - val_loss: 0.2324 - val_f1_metric: 0.7333 - val_dice_coef: 0.9648
Epoch 21/50
407/407 [==============================] - 40s 99ms/sample - loss: 0.1011 - f1_metric: 0.8709 - dice_coef: 0.9788 - val_loss: 0.1908 - val_f1_metric: 0.7668 - val_dice_coef: 0.9679
Epoch 22/50
407/407 [==============================] - 38s 93ms/sample - loss: 0.0880 - f1_metric: 0.8875 - dice_coef: 0.9809 - val_loss: 0.1554 - val_f1_metric: 0.7890 - val_dice_coef: 0.9705
Epoch 23/50
407/407 [==============================] - 39s 95ms/sample - loss: 0.0816 - f1_metric: 0.8958 - dice_coef: 0.9826 - val_loss: 0.2025 - val_f1_metric: 0.7781 - val_dice_coef: 0.9700
Epoch 24/50
407/407 [==============================] - 38s 94ms/sample - loss: 0.0768 - f1_metric: 0.8994 - dice_coef: 0.9832 - val_loss: 0.1802 - val_f1_metric: 0.7769 - val_dice_coef: 0.9703
Epoch 25/50
407/407 [==============================] - 40s 98ms/sample - loss: 0.0710 - f1_metric: 0.9075 - dice_coef: 0.9845 - val_loss: 0.2342 - val_f1_metric: 0.7647 - val_dice_coef: 0.9693
Epoch 26/50
407/407 [==============================] - 38s 93ms/sample - loss: 0.0791 - f1_metric: 0.8993 - dice_coef: 0.9827 - val_loss: 0.1672 - val_f1_metric: 0.7784 - val_dice_coef: 0.9707
Epoch 27/50
407/407 [==============================] - 37s 91ms/sample - loss: 0.0757 - f1_metric: 0.9025 - dice_coef: 0.9837 - val_loss: 0.2029 - val_f1_metric: 0.7684 - val_dice_coef: 0.9700
Epoch 28/50
407/407 [==============================] - 36s 88ms/sample - loss: 0.0762 - f1_metric: 0.9044 - dice_coef: 0.9841 - val_loss: 0.1603 - val_f1_metric: 0.7837 - val_dice_coef: 0.9708
Epoch 29/50
407/407 [==============================] - 39s 95ms/sample - loss: 0.0712 - f1_metric: 0.9097 - dice_coef: 0.9845 - val_loss: 0.1822 - val_f1_metric: 0.7908 - val_dice_coef: 0.9708
Epoch 30/50
407/407 [==============================] - 38s 93ms/sample - loss: 0.0602 - f1_metric: 0.9232 - dice_coef: 0.9869 - val_loss: 0.1939 - val_f1_metric: 0.7829 - val_dice_coef: 0.9702
Epoch 31/50
407/407 [==============================] - 36s 89ms/sample - loss: 0.0546 - f1_metric: 0.9274 - dice_coef: 0.9879 - val_loss: 0.2542 - val_f1_metric: 0.7720 - val_dice_coef: 0.9708
Epoch 32/50
407/407 [==============================] - 34s 84ms/sample - loss: 0.0538 - f1_metric: 0.9301 - dice_coef: 0.9882 - val_loss: 0.2463 - val_f1_metric: 0.7885 - val_dice_coef: 0.9716
Epoch 33/50
407/407 [==============================] - 35s 85ms/sample - loss: 0.0491 - f1_metric: 0.9363 - dice_coef: 0.9892 - val_loss: 0.2223 - val_f1_metric: 0.7939 - val_dice_coef: 0.9730
Epoch 34/50
407/407 [==============================] - 35s 85ms/sample - loss: 0.0453 - f1_metric: 0.9412 - dice_coef: 0.9899 - val_loss: 0.2263 - val_f1_metric: 0.7943 - val_dice_coef: 0.9723
Epoch 35/50
407/407 [==============================] - 34s 84ms/sample - loss: 0.0429 - f1_metric: 0.9433 - dice_coef: 0.9905 - val_loss: 0.2302 - val_f1_metric: 0.7946 - val_dice_coef: 0.9723
Epoch 36/50
407/407 [==============================] - 34s 84ms/sample - loss: 0.0450 - f1_metric: 0.9399 - dice_coef: 0.9899 - val_loss: 0.2243 - val_f1_metric: 0.7935 - val_dice_coef: 0.9728
Epoch 37/50
407/407 [==============================] - 35s 85ms/sample - loss: 0.0446 - f1_metric: 0.9417 - dice_coef: 0.9902 - val_loss: 0.2632 - val_f1_metric: 0.7853 - val_dice_coef: 0.9722
Epoch 38/50
407/407 [==============================] - 34s 83ms/sample - loss: 0.0428 - f1_metric: 0.9433 - dice_coef: 0.9905 - val_loss: 0.2038 - val_f1_metric: 0.8077 - val_dice_coef: 0.9744
Epoch 39/50
407/407 [==============================] - 34s 84ms/sample - loss: 0.0403 - f1_metric: 0.9462 - dice_coef: 0.9909 - val_loss: 0.2249 - val_f1_metric: 0.7941 - val_dice_coef: 0.9724
Epoch 40/50
407/407 [==============================] - 35s 86ms/sample - loss: 0.0416 - f1_metric: 0.9458 - dice_coef: 0.9907 - val_loss: 0.3088 - val_f1_metric: 0.7616 - val_dice_coef: 0.9688
Epoch 41/50
407/407 [==============================] - 34s 84ms/sample - loss: 0.0382 - f1_metric: 0.9488 - dice_coef: 0.9915 - val_loss: 0.3123 - val_f1_metric: 0.7906 - val_dice_coef: 0.9726
Epoch 42/50
407/407 [==============================] - 35s 86ms/sample - loss: 0.0355 - f1_metric: 0.9534 - dice_coef: 0.9921 - val_loss: 0.2102 - val_f1_metric: 0.8088 - val_dice_coef: 0.9737
Epoch 43/50
407/407 [==============================] - 34s 84ms/sample - loss: 0.0403 - f1_metric: 0.9466 - dice_coef: 0.9908 - val_loss: 0.2485 - val_f1_metric: 0.8020 - val_dice_coef: 0.9737
Epoch 44/50
407/407 [==============================] - 34s 84ms/sample - loss: 0.0354 - f1_metric: 0.9508 - dice_coef: 0.9920 - val_loss: 0.2636 - val_f1_metric: 0.7983 - val_dice_coef: 0.9732
Epoch 45/50
407/407 [==============================] - 34s 83ms/sample - loss: 0.0367 - f1_metric: 0.9506 - dice_coef: 0.9916 - val_loss: 0.2434 - val_f1_metric: 0.7988 - val_dice_coef: 0.9731
Epoch 46/50
407/407 [==============================] - 34s 83ms/sample - loss: 0.0362 - f1_metric: 0.9509 - dice_coef: 0.9917 - val_loss: 0.2577 - val_f1_metric: 0.7906 - val_dice_coef: 0.9729
Epoch 47/50
407/407 [==============================] - 34s 83ms/sample - loss: 0.0318 - f1_metric: 0.9576 - dice_coef: 0.9928 - val_loss: 0.3000 - val_f1_metric: 0.7959 - val_dice_coef: 0.9727
Epoch 48/50
407/407 [==============================] - 34s 83ms/sample - loss: 0.0301 - f1_metric: 0.9587 - dice_coef: 0.9931 - val_loss: 0.3510 - val_f1_metric: 0.7873 - val_dice_coef: 0.9722
Epoch 49/50
407/407 [==============================] - 35s 87ms/sample - loss: 0.0287 - f1_metric: 0.9607 - dice_coef: 0.9934 - val_loss: 0.3219 - val_f1_metric: 0.7995 - val_dice_coef: 0.9731
Epoch 50/50
407/407 [==============================] - 34s 83ms/sample - loss: 0.0273 - f1_metric: 0.9625 - dice_coef: 0.9937 - val_loss: 0.3915 - val_f1_metric: 0.7936 - val_dice_coef: 0.9732

Testing the model

In [13]:
## Save the Weights
model.save_weights("UNetW.h5")

## Dataset for prediction
resultCross = model.predict(test_data.images)

resultSame = model.predict(real_data.images)


resultCross = resultCross > 0.4

resultSame = resultSame >0.4

score = model.evaluate(test_data.images,test_data.labels)

print("Cross Domain Loss: "+str(score[0]))

print("Cross Domain F1 score: "+str(score[1]))

print("Cross Domain Accuracy: "+str(score[2]))

score = model.evaluate(real_data.images,real_data.labels)

print("Real Domain Loss: "+str(score[0]))

print("Real Domain F1 score: "+str(score[1]))

print("Real Domain Accuracy: "+str(score[2]))
20/1 [========================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================] - 0s 15ms/sample - loss: 1.7850 - f1_metric: 0.5947 - dice_coef: 0.9532
Cross Domain Loss: 1.7850427627563477
Cross Domain F1 score: 0.5947431
Cross Domain Accuracy: 0.953183
20/1 [========================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================] - 0s 15ms/sample - loss: 0.3139 - f1_metric: 0.8353 - dice_coef: 0.9737
Real Domain Loss: 0.31388407945632935
Real Domain F1 score: 0.83528936
Real Domain Accuracy: 0.973703
In [14]:
for i in range (20):
    fig = plt.figure(figsize=(15,15))
    #fig.subplots(1,3,figsize=(15,15))
    fig.subplots_adjust(hspace=1, wspace=1)

    ax = fig.add_subplot(1, 3, 1)
    ax.imshow(test_data.images[i,:,:,:])
    ax.title.set_text("CD Image " + str(i))

    ax = fig.add_subplot(1, 3, 2)
    ax.imshow(np.reshape(test_data.labels[i,:,:,:]*255, (image_size, image_size)), cmap="gray")
    ax.title.set_text("CD Ground Truth "+ str(i))

    ax = fig.add_subplot(1, 3, 3)
    ax.imshow(np.reshape(resultCross[i]*255, (image_size, image_size)), cmap="gray")
    ax.title.set_text("CD Predicted "+ str(i))

    ax = fig.add_subplot(2, 3, 1)
    ax.imshow(real_data.images[i,:,:,:])
    ax.title.set_text("SD Image "+ str(i))

    ax = fig.add_subplot(2, 3, 2)
    ax.imshow(np.reshape(real_data.labels[i,:,:,:]*255, (image_size, image_size)), cmap="gray")
    ax.title.set_text("SD Ground Truth " + str(i))

    ax = fig.add_subplot(2, 3, 3)
    ax.imshow(np.reshape(resultSame[i]*255, (image_size, image_size)), cmap="gray")
    ax.title.set_text("SD Predicted "+ str(i))
    
    
In [ ]:
 
In [ ]: