
Guide to Retinal Blood Vessel Segmentation with U-Net in PyTorch
We are going to build a U-Net in PyTorch, an architecture in which pairs of encoder and decoder blocks are cascaded around a central bottleneck.
The Retina Blood Vessel dataset is used in medical imaging and computer vision for detecting and segmenting blood vessels in retinal images. Such datasets are valuable for developing algorithms that assist in diagnosing and monitoring retinal and systemic diseases such as diabetic retinopathy, glaucoma, and hypertension.
We have 80 pairs of training images and masks and 20 pairs of testing images and masks. The task is to build and train a model so that, when a retinal image is fed to it, it generates the corresponding vessel mask.
The following portion of this post contains a step-by-step guide to retinal image segmentation using PyTorch.
Downloading The Data From Kaggle and Saving It Locally
import kagglehub
import os
# Download latest version
path = kagglehub.dataset_download("abdallahwagih/retina-blood-vessel")
print("Path to dataset files:", path)
import shutil
# Specify the destination folder
destination_folder = os.path.expanduser("/content/drive/MyDrive/Colab Notebooks/Computer Vision/Retina/")
# Create the destination folder if it doesn't exist
os.makedirs(destination_folder, exist_ok=True)
# Move the downloaded files to the "downloads" folder
for item in os.listdir(path):
    source = os.path.join(path, item)
    destination = os.path.join(destination_folder, item)
    shutil.move(source, destination)
print("Files moved to:", destination_folder)
Imports
import os
import cv2
import time
import random
import numpy as np
from glob import glob
from tqdm import tqdm
from operator import add
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, precision_score, recall_score
Dataset Class
The DriveDataset class is a custom PyTorch Dataset designed for loading and preprocessing images and their corresponding segmentation masks. Here's a brief explanation of its components:
Initialization (__init__ method):
- Takes two inputs: images_path (a list of file paths to the images) and masks_path (a list of file paths to the corresponding masks).
- Stores the paths and calculates the total number of samples (n_samples).
Data Loading (__getitem__ method):
- Reads an image from the specified file path using OpenCV in color format (cv2.IMREAD_COLOR).
- Normalizes the image pixel values to the range [0, 1] by dividing by 255.
- Rearranges the image dimensions to channel-first format (C, H, W) for compatibility with PyTorch.
- Converts the image to a PyTorch tensor (torch.from_numpy).
- Reads the corresponding mask in grayscale format (cv2.IMREAD_GRAYSCALE).
- Normalizes the mask pixel values to the range [0, 1].
- Expands the mask dimensions to add a channel dimension, giving shape (1, H, W).
- Converts the mask to a PyTorch tensor.
Dataset Size (__len__ method):
- Returns the total number of samples in the dataset.
class DriveDataset(Dataset):
    def __init__(self, images_path, masks_path):
        self.images_path = images_path
        self.masks_path = masks_path
        self.n_samples = len(images_path)

    def __getitem__(self, index):
        # Reading the image: BGR, (H, W, C), uint8 -> normalized, (C, H, W), float32
        image = cv2.imread(self.images_path[index], cv2.IMREAD_COLOR)
        image = image / 255.0
        image = np.transpose(image, (2, 0, 1))
        image = image.astype(np.float32)
        image = torch.from_numpy(image)

        # Reading the mask: grayscale, (H, W), uint8 -> normalized, (1, H, W), float32
        mask = cv2.imread(self.masks_path[index], cv2.IMREAD_GRAYSCALE)
        mask = mask / 255.0
        mask = np.expand_dims(mask, axis=0)
        mask = mask.astype(np.float32)
        mask = torch.from_numpy(mask)

        return image, mask

    def __len__(self):
        return self.n_samples
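As a quick sanity check (a minimal sketch; the glob patterns assume the Kaggle folder layout used later in this post), you can index the dataset directly and inspect the tensor shapes:
sample_images = sorted(glob("Data/train/image/*"))
sample_masks = sorted(glob("Data/train/mask/*"))
dataset = DriveDataset(sample_images, sample_masks)
image, mask = dataset[0]
print(image.shape, image.dtype)  # expected: torch.Size([3, 512, 512]) torch.float32
print(mask.shape, mask.dtype)    # expected: torch.Size([1, 512, 512]) torch.float32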
Convolution Layer Class
The conv_block class is a PyTorch nn.Module that implements a basic convolutional block commonly used in neural networks for feature extraction. Here's a concise breakdown:
Initialization (__init__ method):
- Takes two inputs: in_c (number of input channels) and out_c (number of output channels).
- Creates a 2D convolutional layer (nn.Conv2d) with a kernel size of 3 and padding of 1, which preserves the spatial dimensions.
- Defines a ReLU activation function (nn.ReLU) to introduce non-linearity.
Forward Pass (forward method):
- Applies the convolutional layer (self.conv1) to the input tensor.
- Passes the result through the ReLU activation function (self.relu).
- Returns the activated output.
class conv_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.relu(x)
        return x
The conv_block serves as the fundamental building block of both the encoder and the decoder of the U-Net.
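Because the 3x3 kernel uses padding of 1, the block changes only the channel count, never the spatial size. A minimal sketch to confirm this:
block = conv_block(3, 64)
out = block(torch.randn(1, 3, 512, 512))
print(out.shape)  # expected: torch.Size([1, 64, 512, 512])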
The Encoder Block
The encoder_block class is a PyTorch nn.Module that implements an encoding block, commonly used in encoder-decoder architectures like U-Net. Here's a breakdown:
Initialization (__init__ method):
- Takes two inputs: in_c (number of input channels) and out_c (number of output channels).
- Creates a conv_block (defined earlier) for feature extraction, which applies a convolution followed by a ReLU activation.
- Creates a max pooling layer (nn.MaxPool2d) with a pool size of (2, 2) to downsample the spatial dimensions by a factor of 2.
Forward Pass (forward method):
- Processes the input through the convolutional block (self.conv) to extract features.
- Applies the max pooling operation (self.pool) to reduce the spatial dimensions of the extracted features.
- Returns two outputs: x, the feature map after the convolution block (used for skip connections in U-Net), and p, the downsampled feature map after pooling (used as input to the next layer in the encoder).
class encoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = conv_block(in_c, out_c)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, inputs):
        x = self.conv(inputs)
        p = self.pool(x)
        return x, p
This block is used in the encoder part of the U-Net architecture to progressively reduce the spatial resolution while capturing hierarchical features. The feature map x is used in skip connections, while p is passed to the next layer in the encoder.
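A minimal sketch of the two outputs, assuming a 512x512 input:
enc = encoder_block(3, 64)
x, p = enc(torch.randn(1, 3, 512, 512))
print(x.shape)  # expected: torch.Size([1, 64, 512, 512]) -- kept for the skip connection
print(p.shape)  # expected: torch.Size([1, 64, 256, 256]) -- passed down the encoder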
The Decoder Block
The decoder_block class is a PyTorch module used in neural networks, particularly in encoder-decoder architectures like U-Net. Here's a brief description:
- Upsampling (self.up): Uses a transposed convolution (nn.ConvTranspose2d) to upsample the input tensor (inputs) from a lower resolution to a higher resolution, doubling the spatial dimensions of the tensor.
- Concatenation: Combines the upsampled tensor (x) with a skip connection tensor (skip) along the channel dimension (dim=1). This helps in recovering spatial information lost during downsampling in the encoder part of the network.
- Convolution (self.conv): Applies a convolutional block (conv_block) to the concatenated tensor to learn features from the combined input. In this implementation, the conv_block consists of a convolution followed by a ReLU activation.
class decoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        self.conv = conv_block(out_c + out_c, out_c)

    def forward(self, inputs, skip):
        x = self.up(inputs)
        x = torch.cat([x, skip], dim=1)
        x = self.conv(x)
        return x
The decoder_block is designed to gradually reconstruct the spatial resolution and feature details in the decoding phase of a neural network, integrating features from corresponding encoder layers via skip connections. This is commonly used in tasks like image segmentation or reconstruction.
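A minimal sketch of how the shapes line up, assuming bottleneck features at one quarter of the input resolution and a matching skip tensor:
dec = decoder_block(256, 128)
bottleneck = torch.randn(1, 256, 128, 128)
skip = torch.randn(1, 128, 256, 256)  # encoder features at twice the bottleneck resolution
out = dec(bottleneck, skip)
print(out.shape)  # expected: torch.Size([1, 128, 256, 256])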
The U-Net Architecture
The build_unet class defines a PyTorch implementation of a U-Net model, commonly used in image segmentation tasks. Here's a breakdown:
- Encoder: Composed of encoder_blocks (self.e1, self.e2) that extract hierarchical features. Each block outputs two tensors: a skip connection tensor (s) and a downsampled tensor (p).
- Bottleneck: A conv_block (self.b) is used to learn features at the lowest resolution (the "bottleneck").
- Decoder: Contains decoder_blocks (self.d3, self.d4) that upsample the bottleneck features and merge them with corresponding skip connections from the encoder.
- Classifier: A 1×1 convolution layer (self.outputs) maps the final feature maps to the desired number of output channels (e.g., 1 for binary segmentation).
class build_unet(nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder
        self.e1 = encoder_block(3, 64)
        self.e2 = encoder_block(64, 128)
        # Bottleneck
        self.b = conv_block(128, 256)
        # Decoder
        self.d3 = decoder_block(256, 128)
        self.d4 = decoder_block(128, 64)
        # Classifier
        self.outputs = nn.Conv2d(64, 1, kernel_size=1, padding=0)

    def forward(self, inputs):
        # Encoder
        s1, p1 = self.e1(inputs)
        s2, p2 = self.e2(p1)
        # Bottleneck
        b = self.b(p2)
        # Decoder
        d3 = self.d3(b, s2)
        d4 = self.d4(d3, s1)
        outputs = self.outputs(d4)
        return outputs
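Note that the model returns raw logits; the sigmoid is applied inside the loss functions below. A minimal sketch to verify that the output mask matches the input resolution:
net = build_unet()
logits = net(torch.randn(1, 3, 512, 512))
print(logits.shape)  # expected: torch.Size([1, 1, 512, 512]) -- one channel of raw logits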
Dice Loss
The DiceLoss class implements the Dice coefficient-based loss function, which is widely used for evaluating and optimizing models in segmentation tasks. Here's a detailed breakdown:
Initialization:
- Inherits from nn.Module.
- Allows optional weight for class balancing and size_average, though they are not actively used in this implementation.
Forward Method:
- Activation: Applies the sigmoid function to the input (inputs) to map raw predictions to the range [0, 1]. This step should be commented out if the model's output already includes a sigmoid layer.
- Flattening: Flattens both inputs and targets tensors into 1D arrays to simplify computation.
- Intersection: Calculates the intersection between predictions and ground truth (inputs * targets), summed over all elements.
- Dice Coefficient: Uses the formula Dice = (2 * intersection + smooth) / (sum(inputs) + sum(targets) + smooth). A smoothing constant (smooth) is added to prevent division by zero and stabilize training.
- Loss: Returns 1 - Dice, since a higher Dice similarity corresponds to better predictions and the loss needs to decrease during training.
class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        # Comment out if your model contains a sigmoid or equivalent activation layer
        inputs = torch.sigmoid(inputs)

        # Flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        intersection = (inputs * targets).sum()
        dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
        return 1 - dice
The Dice loss focuses on overlap between predicted and ground-truth segmentation masks, making it especially suitable for imbalanced datasets, where the foreground (object of interest) is much smaller than the background.
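A quick numeric sanity check (a minimal sketch on made-up tensors): a confident, correct prediction drives the loss toward 0, while a confidently wrong one drives it up.
dice = DiceLoss()
targets = torch.tensor([1.0, 0.0, 1.0, 0.0])
good_logits = torch.tensor([10.0, -10.0, 10.0, -10.0])  # sigmoid ~ [1, 0, 1, 0]
bad_logits = -good_logits                               # sigmoid ~ [0, 1, 0, 1]
print(dice(good_logits, targets).item())  # ~0.0
print(dice(bad_logits, targets).item())   # ~0.8 (the smooth term keeps it below 1)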
Dice + Binary Cross-Entropy Loss
The DiceBCELoss class is a combination loss function that merges the Dice loss with binary cross-entropy (BCE) loss. This combination leverages the strengths of both loss functions, making it well suited for binary segmentation tasks.
Structure and Components
Initialization:
- Inherits from nn.Module.
- Allows optional weight and size_average parameters, though they are not actively used here.
Forward Method:
- Activation: Applies the sigmoid function to the raw model outputs (inputs) to map predictions to probabilities in [0, 1]. This step should be skipped if the model already includes a sigmoid activation layer.
- Flattening: Converts inputs and targets into 1D tensors to simplify element-wise operations.
- Dice Loss: Calculates the Dice coefficient-based loss, which penalizes poor overlap between predictions and targets: dice_loss = 1 - (2 * intersection + smooth) / (sum(inputs) + sum(targets) + smooth). A smoothing constant (smooth) is included to avoid division by zero.
- Binary Cross-Entropy (BCE): Computes the standard BCE loss, which measures the difference between predicted probabilities and ground-truth labels.
- Combined Loss: Adds the two together: Combined Loss = BCE + Dice Loss.
class DiceBCELoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceBCELoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        # Comment out if your model contains a sigmoid or equivalent activation layer
        inputs = torch.sigmoid(inputs)

        # Flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        intersection = (inputs * targets).sum()
        dice_loss = 1 - (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
        BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
        Dice_BCE = BCE + dice_loss
        return Dice_BCE
This combined loss function works well in cases where using Dice loss alone might lead to unstable training, especially in early epochs. It balances the strengths of both components for robust segmentation performance.
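Reusing the dummy tensors from the DiceLoss sketch above, a minimal sanity check of the combined loss:
dice_bce = DiceBCELoss()
print(dice_bce(good_logits, targets).item())  # ~0.0 (both terms near zero)
print(dice_bce(bad_logits, targets).item())   # large: BCE grows quickly on confident mistakes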
A Few Helper Functions
Function to seed the random number generators for reproducibility
def seeding(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
Function to create a directory if it does not exist
def create_dir(path):
    if not os.path.exists(path):
        os.makedirs(path)
Function to calculate time for each epoch
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs
Function to Train the Model
The train function is a typical training loop for a PyTorch model. Here's a concise explanation:
Inputs:
- model: The neural network being trained.
- loader: The data loader providing batches of input (x) and target (y) tensors.
- optimizer: The optimizer used for updating the model's parameters (e.g., Adam, SGD).
- loss_fn: The loss function to measure prediction error (e.g., MSELoss, CrossEntropyLoss).
- device: The device (CPU or GPU) where computations will occur.
Process:
- Initialization: Sets epoch_loss to 0 to accumulate loss over the epoch.
- Training Mode: Puts the model in training mode (model.train()), enabling layers like dropout and batch normalization.
- Batch Loop:
  - Transfers inputs (x) and targets (y) to the specified device.
  - Clears previous gradients (optimizer.zero_grad()).
  - Makes predictions (y_pred) by passing x through the model.
  - Computes the loss between predictions and targets (loss_fn(y_pred, y)).
  - Backpropagates gradients (loss.backward()).
  - Updates the model parameters (optimizer.step()).
  - Accumulates the batch loss into epoch_loss.
- Epoch Loss: Divides the total loss by the number of batches (len(loader)) to get the average epoch loss.
Output:
- Returns the average loss for the epoch, which is a measure of the model's performance during training.
def train(model, loader, optimizer, loss_fn, device):
    epoch_loss = 0.0
    model.train()
    for x, y in loader:
        x = x.to(device, dtype=torch.float32)
        y = y.to(device, dtype=torch.float32)

        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    epoch_loss = epoch_loss / len(loader)
    return epoch_loss
Function to Evaluate the Model
The evaluate function is used to assess the performance of a PyTorch model on a validation or test dataset without updating the model's parameters. Here's a brief explanation:
Inputs:
- model: The trained or partially trained neural network to evaluate.
- loader: The data loader providing batches of input (x) and target (y) tensors.
- loss_fn: The loss function to compute prediction errors (e.g., MSELoss, CrossEntropyLoss).
- device: The device (CPU or GPU) where computations will take place.
Process:
- Initialization: Sets epoch_loss to 0 for accumulating the loss over all batches.
- Evaluation Mode: Puts the model in evaluation mode (model.eval()), disabling dropout and freezing batch normalization statistics.
- No Gradient Computation: Wraps the evaluation process in torch.no_grad() to save memory and computation by not calculating gradients.
- Batch Loop:
  - Transfers inputs (x) and targets (y) to the specified device.
  - Passes inputs through the model to generate predictions (y_pred).
  - Computes the loss between predictions and targets using the loss function.
  - Accumulates the batch loss into epoch_loss.
- Epoch Loss: Divides the total loss by the number of batches (len(loader)) to get the average loss for the epoch.
Output:
- Returns the average loss for the epoch, which indicates how well the model performs on the given dataset.
def evaluate(model, loader, loss_fn, device):
    epoch_loss = 0.0
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device, dtype=torch.float32)
            y = y.to(device, dtype=torch.float32)

            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            epoch_loss += loss.item()

        epoch_loss = epoch_loss / len(loader)
    return epoch_loss
Get Data Paths
# Seeding
seeding(42)
# Create files directory to store checkpoint file
create_dir("files")
path = "/content/drive/MyDrive/Colab Notebooks/Computer Vision/Retina"
# Get data paths
train_x = sorted(glob(path+"/Data/train/image/*"))
train_y = sorted(glob(path+"/Data/train/mask/*"))
valid_x = sorted(glob(path+"/Data/test/image/*"))
valid_y = sorted(glob(path+"/Data/test/mask/*"))
data_str = f"Dataset Size:\nTrain: {len(train_x)} - Valid: {len(valid_x)}\n"
print(data_str)
Set Hyperparameters
H = 512
W = 512
size = (H, W)
batch_size = 2
lr = 1e-4
checkpoint_path = "files/retina_checkpoint.pth"  # stored in the "files" directory created above
Load Dataset
train_dataset = DriveDataset(train_x, train_y)
valid_dataset = DriveDataset(valid_x, valid_y)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2
)
valid_loader = DataLoader(
    dataset=valid_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2
)
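A minimal sketch to confirm what each batch looks like before training:
images, masks = next(iter(train_loader))
print(images.shape)  # expected: torch.Size([2, 3, 512, 512])
print(masks.shape)   # expected: torch.Size([2, 1, 512, 512])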
Instantiate U-Net
This code sets up the training environment for a U-Net model, including the device, model, optimizer, learning rate scheduler, and loss function.
# Set cuda device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Build model
model = build_unet()
model = model.to(device)
# Set Optimizer and Loss
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, verbose=True)
loss_fn = DiceBCELoss()
Load the trained model if it exists
try:
    # Load the model's state dictionary (map_location handles CPU/GPU mismatches)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
except Exception as e:
    print(e)
Model Training
This code snippet implements a training loop for a PyTorch model with validation and checkpointing. Here's a concise explanation:
Initialization:
- best_valid_loss: Set to a high initial value (float("inf")) to track the best validation loss so far.
- num_epochs: Number of epochs for training.
- start_time: Captures the start time of the epoch for timing purposes.
Training and Validation:
- Training Loss: Calls the train function to perform backpropagation and update the model parameters using the training dataset (train_loader).
- Validation Loss: Calls the evaluate function to assess the model's performance on the validation dataset (valid_loader).
- LR Scheduling: Steps the ReduceLROnPlateau scheduler on the validation loss so the learning rate is reduced when improvement stalls.
Model Checkpointing:
- If the validation loss improves (i.e., valid_loss < best_valid_loss):
  - Logs the improvement.
  - Saves the model's state to the specified checkpoint_path using torch.save.
  - Updates best_valid_loss to the new valid_loss.
Epoch Timing:
- Measures and formats the time taken for the epoch using the epoch_time function.
Logging Results:
- Logs the epoch number, time taken, training loss, and validation loss in a formatted string.
best_valid_loss = float("inf")
num_epochs = 3
for epoch in range(num_epochs):
    start_time = time.time()

    train_loss = train(model, train_loader, optimizer, loss_fn, device)
    valid_loss = evaluate(model, valid_loader, loss_fn, device)

    # Step the LR scheduler on the monitored metric (required for ReduceLROnPlateau)
    scheduler.step(valid_loss)

    # Saving the model
    if valid_loss < best_valid_loss:
        print(f"Valid loss improved from {best_valid_loss:2.4f} to {valid_loss:2.4f}. Saving checkpoint: {checkpoint_path}")
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), checkpoint_path)

    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    data_str = f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s\n'
    data_str += f'\tTrain Loss: {train_loss:.3f}\n'
    data_str += f'\tVal. Loss: {valid_loss:.3f}\n'
    print(data_str)
Helper Functions for Testing
Function to Calculate Metrics
def calculate_metrics(y_true, y_pred):
    # Ground truth: threshold and flatten to a 1D binary array
    y_true = y_true.cpu().numpy()
    y_true = y_true > 0.5
    y_true = y_true.astype(np.uint8)
    y_true = y_true.reshape(-1)

    # Prediction: threshold and flatten to a 1D binary array
    y_pred = y_pred.cpu().numpy()
    y_pred = y_pred > 0.5
    y_pred = y_pred.astype(np.uint8)
    y_pred = y_pred.reshape(-1)

    score_jaccard = jaccard_score(y_true, y_pred)
    score_f1 = f1_score(y_true, y_pred)
    score_recall = recall_score(y_true, y_pred)
    score_precision = precision_score(y_true, y_pred)
    score_acc = accuracy_score(y_true, y_pred)
    return [score_jaccard, score_f1, score_recall, score_precision, score_acc]
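A minimal sketch on made-up tensors (three of four pixels agree after thresholding):
y_true = torch.tensor([1.0, 0.0, 1.0, 1.0])
y_pred = torch.tensor([0.9, 0.1, 0.8, 0.2])  # thresholded at 0.5 -> [1, 0, 1, 0]
print(calculate_metrics(y_true, y_pred))
# expected: [0.6667 (Jaccard), 0.8 (F1), 0.6667 (recall), 1.0 (precision), 0.75 (accuracy)]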
Function to parse the mask
def mask_parse(mask):
    mask = np.expand_dims(mask, axis=-1)               ## (512, 512, 1)
    mask = np.concatenate([mask, mask, mask], axis=-1) ## (512, 512, 3)
    return mask
Get Test Data
# Seeding
seeding(42)
# Folders
create_dir("results")
# Load dataset
test_x = sorted(glob(path+"/Data/test/image/*"))
test_y = sorted(glob(path+"/Data/test/mask/*"))
Load Checkpoint
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = build_unet()
model = model.to(device)
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()
Test
This code evaluates the segmentation model on the test dataset and saves the results with metrics. Here's a step-by-step explanation:
Initialization:
- metrics_score: A list to accumulate metrics for evaluation (Jaccard, F1, recall, precision, accuracy).
- time_taken: A list to store the inference time for each image, used later for the FPS calculation.
Loop Over Test Dataset:
- Iterates through test_x (image paths) and test_y (mask paths) using enumerate and tqdm for progress tracking.
Preprocessing:
- Input Image:
  - Reads the image (cv2.imread).
  - Normalizes pixel values to [0, 1].
  - Transposes dimensions from (H, W, C) to (C, H, W) for PyTorch compatibility.
  - Expands to (1, C, H, W) to include a batch dimension and converts to a torch.Tensor.
  - Moves the tensor to the computation device.
- Ground Truth Mask:
  - Reads the mask as a grayscale image.
  - Normalizes pixel values to [0, 1].
  - Reshapes to (1, 1, H, W) to include channel and batch dimensions.
  - Converts to a torch.Tensor and moves it to the computation device.
Model Prediction:
- Uses torch.no_grad() to disable gradient computation during inference.
- Records the time taken for the prediction to calculate FPS later.
- Applies the model and computes the sigmoid activation to get pixel-wise probabilities.
- Thresholds the prediction to a binary output (> 0.5).
Metrics Calculation:
- Calls calculate_metrics to evaluate the predictions against the ground-truth mask.
- Accumulates scores using map(add, metrics_score, score).
Saving Results:
- Converts predicted and ground-truth masks to visually interpretable 3-channel images (mask_parse).
- Creates a concatenated image showing the original image, the ground-truth mask, and the predicted mask, separated by gray divider lines.
- Adds text labels and saves the concatenated image using cv2.imwrite.
metrics_score = [0.0, 0.0, 0.0, 0.0, 0.0]
time_taken = []

for i, (x, y) in tqdm(enumerate(zip(test_x, test_y)), total=len(test_x)):
    # Extract the name
    name = x.split("/")[-1].split(".")[0]

    # Reading image
    image = cv2.imread(x, cv2.IMREAD_COLOR)  ## (512, 512, 3)
    x = np.transpose(image, (2, 0, 1))       ## (3, 512, 512)
    x = x / 255.0
    x = np.expand_dims(x, axis=0)            ## (1, 3, 512, 512)
    x = x.astype(np.float32)
    x = torch.from_numpy(x)
    x = x.to(device)

    # Reading mask
    mask = cv2.imread(y, cv2.IMREAD_GRAYSCALE)  ## (512, 512)
    y = np.expand_dims(mask, axis=0)            ## (1, 512, 512)
    y = y / 255.0
    y = np.expand_dims(y, axis=0)               ## (1, 1, 512, 512)
    y = y.astype(np.float32)
    y = torch.from_numpy(y)
    y = y.to(device)

    with torch.no_grad():
        # Prediction and calculating FPS
        start_time = time.time()
        pred_y = model(x)
        pred_y = torch.sigmoid(pred_y)
        total_time = time.time() - start_time
        time_taken.append(total_time)

        score = calculate_metrics(y, pred_y)
        metrics_score = list(map(add, metrics_score, score))

        pred_y = pred_y[0].cpu().numpy()     ## (1, 512, 512)
        pred_y = np.squeeze(pred_y, axis=0)  ## (512, 512)
        pred_y = pred_y > 0.5
        pred_y = np.array(pred_y, dtype=np.uint8)

    # Saving masks
    ori_mask = mask_parse(mask)
    pred_y = mask_parse(pred_y)
    # Gray divider line, kept as uint8 so the concatenated image saves correctly
    line = (np.ones((size[1], 10, 3)) * 128).astype(np.uint8)

    # Concatenate images: original | ground truth | prediction
    cat_images = np.concatenate(
        [image, line, ori_mask, line, pred_y * 255], axis=1
    )

    # Add labels to the images
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1
    font_color = (255, 255, 255)
    thickness = 2
    line_type = cv2.LINE_AA

    cv2.putText(cat_images, "Original", (10, 50), font, font_scale, font_color, thickness, line_type)
    cv2.putText(cat_images, "Ground Truth", (image.shape[1] + 20, 50), font, font_scale, font_color, thickness, line_type)
    cv2.putText(cat_images, "Prediction", (2 * image.shape[1] + 40, 50), font, font_scale, font_color, thickness, line_type)

    # Save the resulting image
    cv2.imwrite(f"results/{name}.png", cat_images)
Test Evaluation
This snippet calculates and prints evaluation metrics and the inference speed (frames per second, FPS) of the model.
jaccard = metrics_score[0] / len(test_x)
f1 = metrics_score[1] / len(test_x)
recall = metrics_score[2] / len(test_x)
precision = metrics_score[3] / len(test_x)
acc = metrics_score[4] / len(test_x)
print(f"Jaccard: {jaccard:1.4f} - F1: {f1:1.4f} - Recall: {recall:1.4f} - Precision: {precision:1.4f} - Acc: {acc:1.4f}")
fps = 1 / np.mean(time_taken)
print("FPS: ", fps)
Show Result
import matplotlib.pyplot as plt
img = plt.imread('results/4.png')  # results are saved in the local "results" directory
plt.figure(figsize=(15, 8))
plt.imshow(img)
plt.axis('off')
plt.show()

Visualizing the Model Architecture
!pip3 install torchview
from torchview import draw_graph
# Instantiate the model
unet = build_unet()
# Define an input tensor
input_tensor = torch.randn(1, 3, 224, 224) # Example input tensor with batch size 1 and 224x224 image
# Draw the graph
model_graph = draw_graph(unet, input_data=input_tensor, expand_nested=True)
# Render and save the graph using graphviz
model_graph.visual_graph.render(filename="unet_model_graph", format="png", cleanup=True)
model_graph.visual_graph

Alternative way
!pip3 install onnx
# Export model to ONNX
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(unet, dummy_input, path + "/unet.onnx", opset_version=11)
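Optionally (a minimal sketch using the onnx package's model checker), you can verify the exported file before opening it in Netron:
import onnx
onnx_model = onnx.load(path + "/unet.onnx")
onnx.checker.check_model(onnx_model)  # raises an error if the exported graph is malformed
print("ONNX export looks valid.")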
Open the “.onnx” file in Netron
