Guide to Retinal Blood Vessel Segmentation with U-Net in PyTorch

Code on GitHub

We are going to build a U-Net in PyTorch, an encoder-decoder architecture in which each encoder stage is paired with a corresponding decoder stage through skip connections.

The Retina Blood Vessel dataset is used in medical imaging and computer vision for detecting and segmenting blood vessels in retinal images. These datasets are valuable for developing algorithms that assist in diagnosing and monitoring retinal and systemic diseases like diabetic retinopathy, glaucoma and hypertension.

We have 80 pairs of training images and masks and 20 pairs of test images and masks. The task is to train a model so that, when a retinal image is fed to it, it generates the corresponding segmentation mask.

The following portion of this post contains a step-by-step guide to retinal image segmentation using PyTorch.

Downloading The Data From Kaggle and Saving It Locally

import kagglehub
import os

# Download latest version
path = kagglehub.dataset_download("abdallahwagih/retina-blood-vessel")

print("Path to dataset files:", path)

import shutil
# Specify the destination folder
destination_folder = os.path.expanduser("/content/drive/MyDrive/Colab Notebooks/Computer Vision/Retina/")

# Create the destination folder if it doesn't exist
os.makedirs(destination_folder, exist_ok=True)

# Move the downloaded files to the destination folder
for item in os.listdir(path):
    source = os.path.join(path, item)
    destination = os.path.join(destination_folder, item)
    shutil.move(source, destination)

print("Files moved to:", destination_folder)

Imports

import os
import cv2
import time
import random
import numpy as np
from glob import glob
from tqdm import tqdm
from operator import add
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, precision_score, recall_score

Dataset Class

The DriveDataset class is a custom PyTorch Dataset designed for loading and preprocessing images and their corresponding segmentation masks. Here’s a brief explanation of its components:

Initialization (__init__ method):

  • Takes two inputs: images_path (list of file paths to the images) and masks_path (list of file paths to the corresponding masks).
  • Stores the paths and calculates the total number of samples (n_samples).

Data Loading (__getitem__ method):

  • Reads an image from the specified file path using OpenCV in color format (cv2.IMREAD_COLOR).
  • Normalizes the image pixel values to the range [0, 1] by dividing by 255.
  • Rearranges the image dimensions to channel-first format (C, H, W) for compatibility with PyTorch.
  • Converts the image to a PyTorch tensor (torch.from_numpy).
  • Reads the corresponding mask in grayscale format (cv2.IMREAD_GRAYSCALE).
  • Normalizes the mask pixel values to the range [0, 1].
  • Expands the mask dimensions to add a channel dimension, converting it to shape (1, H, W).
  • Converts the mask to a PyTorch tensor.

Dataset Size (__len__ method):

  • Returns the total number of samples in the dataset.
class DriveDataset(Dataset):
    def __init__(self, images_path, masks_path):
        self.images_path = images_path
        self.masks_path = masks_path
        self.n_samples = len(images_path)

    def __getitem__(self, index):
        # Reading images
        image = cv2.imread(self.images_path[index], cv2.IMREAD_COLOR)
        image = image / 255.0
        image = np.transpose(image, (2, 0, 1))
        image = image.astype(np.float32)
        image = torch.from_numpy(image)

        # Reading masks
        mask = cv2.imread(self.masks_path[index], cv2.IMREAD_GRAYSCALE)
        mask = mask / 255.0
        mask = np.expand_dims(mask, axis=0)
        mask = mask.astype(np.float32)
        mask = torch.from_numpy(mask)

        return image, mask

    def __len__(self):
        return self.n_samples
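As a quick sanity check, the same preprocessing steps can be applied to a synthetic NumPy array standing in for cv2.imread output (the random values below are purely illustrative):

```python
import numpy as np
import torch

# Stand-in for cv2.imread output: uint8 array of shape (H, W, C)
img = np.random.randint(0, 256, size=(512, 512, 3), dtype=np.uint8)

x = img / 255.0                 # normalize to [0, 1]
x = np.transpose(x, (2, 0, 1))  # (H, W, C) -> (C, H, W)
x = x.astype(np.float32)        # PyTorch models expect float32
x = torch.from_numpy(x)

print(x.shape)  # torch.Size([3, 512, 512])
print(x.dtype)  # torch.float32
```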

Convolution Layer Class

The conv_block class is a PyTorch nn.Module that implements a basic convolutional block commonly used in neural networks for feature extraction. Here’s a concise breakdown:

Initialization (__init__ method):

  • Takes two inputs:
    • in_c: Number of input channels.
    • out_c: Number of output channels.
  • Creates a 2D convolutional layer (nn.Conv2d) with a kernel size of 3 and padding of 1 (preserves input dimensions).
  • Defines a ReLU activation function (nn.ReLU) to introduce non-linearity.

Forward Pass (forward method):

  • Applies the convolutional layer (self.conv1) to the input tensor.
  • Passes the result through the ReLU activation function (self.relu).
  • Returns the activated output.
class conv_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.relu(x)
        return x

The conv_block serves as a fundamental component in the encoder and decoder of the U-Net.
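A quick standalone check that a 3×3 convolution with padding 1 (the same configuration as in conv_block) changes the channel count but preserves height and width:

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
x = torch.randn(1, 3, 512, 512)   # (batch, channels, H, W)
y = torch.relu(conv(x))           # convolution followed by ReLU, as in conv_block

print(y.shape)  # torch.Size([1, 64, 512, 512]) -- H and W unchanged
```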

The Encoder Block

The encoder_block class is a PyTorch nn.Module that implements an encoding block, commonly used in encoder-decoder architectures like U-Net. Here’s a breakdown:

  1. Initialization (__init__ method):
    • Takes two inputs:
      • in_c: Number of input channels.
      • out_c: Number of output channels.
    • Creates:
      • A conv_block (defined earlier) for feature extraction, which applies a convolution followed by a ReLU activation.
      • A max pooling layer (nn.MaxPool2d) with a pool size of (2, 2) to downsample the spatial dimensions by a factor of 2.
  2. Forward Pass (forward method):
    • Processes the input through the convolutional block (self.conv) to extract features.
    • Applies the max pooling operation (self.pool) to reduce the spatial dimensions of the extracted features.
    • Returns two outputs:
      • x: The feature map after the convolution block (used for skip connections in U-Net).
      • p: The downsampled feature map after pooling (used as input to the next layer in the encoder).
class encoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = conv_block(in_c, out_c)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, inputs):
        x = self.conv(inputs)
        p = self.pool(x)
        return x, p

This block is used in the encoder part of the U-Net architecture to progressively reduce the spatial resolution while capturing hierarchical features. The feature map x is used in skip connections, while p is passed to the next layer in the encoder.
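The pooling step can be sketched in isolation: the 2×2 max pool halves the spatial dimensions, while the pre-pool map keeps its size for the skip connection.

```python
import torch
import torch.nn as nn

pool = nn.MaxPool2d((2, 2))
x = torch.randn(1, 64, 512, 512)  # feature map from the conv block (skip connection)
p = pool(x)                       # downsampled map passed to the next encoder stage

print(x.shape)  # torch.Size([1, 64, 512, 512])
print(p.shape)  # torch.Size([1, 64, 256, 256]) -- spatial dims halved
```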

The Decoder Block

The decoder_block class is a PyTorch module used in neural networks, particularly in encoder-decoder architectures like U-Net. Here’s a brief description:

  1. Upsampling (self.up): Uses a transposed convolution (nn.ConvTranspose2d) to upsample the input tensor (inputs) from lower resolution to a higher resolution. This increases the spatial dimensions of the tensor.
  2. Concatenation: Combines the upsampled tensor (x) with a skip connection tensor (skip) along the channel dimension (axis=1). This helps in recovering spatial information lost during downsampling in the encoder part of the network.
  3. Convolution (self.conv): Applies a convolutional block (conv_block) on the concatenated tensor to learn features from the combined input. The conv_block typically consists of layers like convolution, normalization, and activation.
class decoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        self.conv = conv_block(out_c + out_c, out_c)

    def forward(self, inputs, skip):
        x = self.up(inputs)
        x = torch.cat([x, skip], dim=1)
        x = self.conv(x)
        return x

The decoder_block is designed to gradually reconstruct the spatial resolution and feature details in the decoding phase of a neural network, integrating features from corresponding encoder layers via skip connections. This is commonly used in tasks like image segmentation or reconstruction.
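The shape arithmetic of the decoder can be verified in isolation: the transposed convolution doubles the spatial dimensions, and the concatenation with the skip tensor doubles the channel count, which is why conv_block receives out_c + out_c input channels.

```python
import torch
import torch.nn as nn

# A transposed convolution with kernel 2, stride 2 doubles H and W
up = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2, padding=0)

bottleneck = torch.randn(1, 256, 128, 128)
x = up(bottleneck)                    # (1, 128, 256, 256)
skip = torch.randn(1, 128, 256, 256)  # matching encoder feature map
merged = torch.cat([x, skip], dim=1)  # channels add up: 128 + 128 = 256

print(merged.shape)  # torch.Size([1, 256, 256, 256])
```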

The U-Net Architecture

The build_unet class defines a PyTorch implementation of a U-Net model, commonly used in image segmentation tasks. Here’s a breakdown:

  1. Encoder:
    • Composed of encoder_blocks (self.e1, self.e2) that extract hierarchical features.
    • Each block outputs two tensors: a skip connection tensor (s) and a downsampled tensor (p).
  2. Bottleneck:
    • A conv_block (self.b) is used to learn features at the lowest resolution (the “bottleneck”).
  3. Decoder:
    • Contains decoder_blocks (self.d3, self.d4) to upsample the bottleneck features and merge them with corresponding skip connections from the encoder.
  4. Classifier:
    • A 1×1 convolution layer (self.outputs) to map the final feature maps to the desired number of output channels (e.g., 1 for binary segmentation).
class build_unet(nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder
        self.e1 = encoder_block(3, 64)
        self.e2 = encoder_block(64, 128)
        
        # Bottleneck
        self.b = conv_block(128,256)

        # Decoder
        self.d3 = decoder_block(256, 128)
        self.d4 = decoder_block(128, 64)

        # Classifier
        self.outputs = nn.Conv2d(64, 1, kernel_size=1, padding=0)

    def forward(self, inputs):
        # Encoder
        s1, p1 = self.e1(inputs)
        s2, p2 = self.e2(p1)
        

        # Bottleneck
        b = self.b(p2)

        # Decoder
        d3 = self.d3(b, s2)
        d4 = self.d4(d3, s1)

        outputs = self.outputs(d4)

        return outputs
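To confirm the end-to-end shapes, the blocks defined above are repeated here so the snippet runs standalone (a 256×256 input is used just to keep the check fast):

```python
import torch
import torch.nn as nn

# The blocks from this post, repeated inline for a self-contained check
class conv_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
    def forward(self, x):
        return self.relu(self.conv1(x))

class encoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = conv_block(in_c, out_c)
        self.pool = nn.MaxPool2d((2, 2))
    def forward(self, x):
        s = self.conv(x)
        return s, self.pool(s)

class decoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2)
        self.conv = conv_block(out_c + out_c, out_c)
    def forward(self, x, skip):
        return self.conv(torch.cat([self.up(x), skip], dim=1))

class build_unet(nn.Module):
    def __init__(self):
        super().__init__()
        self.e1 = encoder_block(3, 64)
        self.e2 = encoder_block(64, 128)
        self.b = conv_block(128, 256)
        self.d3 = decoder_block(256, 128)
        self.d4 = decoder_block(128, 64)
        self.outputs = nn.Conv2d(64, 1, kernel_size=1)
    def forward(self, x):
        s1, p1 = self.e1(x)
        s2, p2 = self.e2(p1)
        b = self.b(p2)
        d = self.d4(self.d3(b, s2), s1)
        return self.outputs(d)

model = build_unet()
y = model(torch.randn(1, 3, 256, 256))
print(y.shape)  # torch.Size([1, 1, 256, 256]) -- one logit map per pixel
```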

Dice Coefficient

The DiceLoss class implements the Dice coefficient-based loss function, which is widely used for evaluating and optimizing models in segmentation tasks. Here’s a detailed breakdown:

  1. Initialization:
    • Inherits from nn.Module.
    • Accepts optional weight and size_average parameters, though they are not used in this implementation.
  2. Forward Method:
    • Activation:
      • Applies the sigmoid function to the input (inputs) to map raw predictions to the range [0, 1].
      • This step should be removed if the model’s output already includes a sigmoid layer.
    • Flattening:
      • Flattens both the inputs and targets tensors into 1D arrays to simplify computation.
    • Intersection:
      • Calculates the intersection between predictions and ground truth (inputs * targets), summed over all elements.
    • Dice Coefficient:
      • Uses the formula: Dice = (2 × Intersection + smooth) / (Sum of Predictions + Sum of Targets + smooth)
      • A smoothing constant (smooth) is added to prevent division by zero and stabilize training.
    • Loss:
      • Returns 1 − Dice, since higher Dice similarity corresponds to better predictions and the loss must decrease during training.


class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):

        #comment out if your model contains a sigmoid or equivalent activation layer
        inputs = torch.sigmoid(inputs)

        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        intersection = (inputs * targets).sum()
        dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)

        return 1 - dice

The Dice loss focuses on overlap between predicted and ground-truth segmentation masks, making it especially suitable for imbalanced datasets, where the foreground (object of interest) is much smaller than the background.
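A tiny numeric example with made-up probabilities shows the computation:

```python
import torch

pred = torch.tensor([0.9, 0.8, 0.1, 0.2])    # predicted probabilities (already sigmoided)
target = torch.tensor([1.0, 1.0, 0.0, 0.0])  # ground-truth labels
smooth = 1.0

intersection = (pred * target).sum()  # 0.9 + 0.8 = 1.7
dice = (2 * intersection + smooth) / (pred.sum() + target.sum() + smooth)  # 4.4 / 5.0
loss = 1 - dice

print(f"{loss.item():.2f}")  # 0.12
```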

Dice Binary Cross Entropy Coefficient

The DiceBCELoss class is a combination loss function that merges the Dice loss with Binary Cross Entropy (BCE) loss. This combination leverages the strengths of both loss functions, making it well-suited for binary segmentation tasks.

Structure and Components

  1. Initialization:
    • Inherits from nn.Module.
    • Allows optional weight and size_average parameters, though they are not actively used here.
  2. Forward Method:
    • Activation:
      • Applies the sigmoid function to raw model outputs (inputs) to map predictions to probabilities in [0, 1].
      • This step should be skipped if the model already includes a sigmoid activation layer.
    • Flattening:
      • Converts inputs and targets into 1D tensors to simplify element-wise operations.
    • Dice Loss:
      • Calculates the Dice coefficient-based loss, which penalizes poor overlap between predictions and targets: Dice Loss = 1 − (2 × Intersection + smooth) / (Sum of Predictions + Sum of Targets + smooth)
      • A smoothing constant (smooth) is included to avoid division by zero.
    • Binary Cross Entropy (BCE):
      • Computes the standard BCE loss, which measures the difference between predicted probabilities and ground-truth labels.
    • Combined Loss:
      • Adds the Dice loss and BCE loss: Combined Loss = BCE + Dice Loss

class DiceBCELoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceBCELoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):

        #comment out if your model contains a sigmoid or equivalent activation layer
        inputs = torch.sigmoid(inputs)

        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        intersection = (inputs * targets).sum()
        dice_loss = 1 - (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)
        BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
        Dice_BCE = BCE + dice_loss

        return Dice_BCE

This combined loss function works well in cases where using Dice loss alone might lead to unstable training, especially in early epochs. It balances the strengths of both components for robust segmentation performance.
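With the same kind of made-up toy tensors as before, the two components and their sum can be computed directly:

```python
import torch
import torch.nn.functional as F

pred = torch.tensor([0.9, 0.8, 0.1, 0.2])    # predicted probabilities
target = torch.tensor([1.0, 1.0, 0.0, 0.0])  # ground-truth labels
smooth = 1.0

intersection = (pred * target).sum()
dice_loss = 1 - (2 * intersection + smooth) / (pred.sum() + target.sum() + smooth)
bce = F.binary_cross_entropy(pred, target, reduction='mean')
combined = bce + dice_loss

# BCE=0.164  Dice=0.120  Combined=0.284
print(f"BCE={bce.item():.3f}  Dice={dice_loss.item():.3f}  Combined={combined.item():.3f}")
```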

A Few Helper Functions

Function to set random seeds

def seeding(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

Function to create directory if it does not exist

def create_dir(path):
    if not os.path.exists(path):
        os.makedirs(path)

Function to calculate time for each epoch

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

Function to Train the Data

The train function is a typical training loop for a PyTorch model. Here’s a concise explanation:

Inputs:

  • model: The neural network being trained.
  • loader: The data loader providing batches of input (x) and target (y) tensors.
  • optimizer: The optimizer used for updating the model’s parameters (e.g., Adam, SGD).
  • loss_fn: The loss function to measure prediction error (e.g., MSELoss, CrossEntropyLoss).
  • device: The device (CPU or GPU) where computations will occur.

Process:

  • Initialization: Sets epoch_loss to 0 to accumulate loss over the epoch.
  • Training Mode: Puts the model in training mode (model.train()), enabling layers like dropout or batch normalization.
  • Batch Loop:
    • Transfers inputs (x) and targets (y) to the specified device.
    • Clears previous gradients (optimizer.zero_grad()).
    • Makes predictions (y_pred) by passing x through the model.
    • Computes the loss between predictions and targets (loss_fn(y_pred, y)).
    • Backpropagates gradients (loss.backward()).
    • Updates the model parameters (optimizer.step()).
    • Accumulates the batch loss into epoch_loss.
  • Epoch Loss: Divides the total loss by the number of batches (len(loader)) for the average epoch loss.

Output:

  • Returns the average loss for the epoch, which is a measure of the model’s performance during training.
def train(model, loader, optimizer, loss_fn, device):
    epoch_loss = 0.0

    model.train()
    for x, y in loader:
        x = x.to(device, dtype=torch.float32)
        y = y.to(device, dtype=torch.float32)

        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    epoch_loss = epoch_loss/len(loader)
    return epoch_loss
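The loop can be exercised end to end on toy data; the train function is repeated here, and the single-conv model, random tensors, and BCEWithLogitsLoss below are stand-ins chosen only so the snippet runs standalone:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def train(model, loader, optimizer, loss_fn, device):
    epoch_loss = 0.0
    model.train()
    for x, y in loader:
        x = x.to(device, dtype=torch.float32)
        y = y.to(device, dtype=torch.float32)
        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(loader)

# Toy stand-ins: a 1x1 conv "model" and a random 8-sample dataset
device = torch.device("cpu")
model = nn.Conv2d(3, 1, kernel_size=1)
data = TensorDataset(torch.randn(8, 3, 16, 16), torch.rand(8, 1, 16, 16))
loader = DataLoader(data, batch_size=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.BCEWithLogitsLoss()

loss = train(model, loader, optimizer, loss_fn, device)
print(loss)  # average loss over the 4 batches, a positive float
```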

Function to Evaluate the Model

The evaluate function is used to assess the performance of a PyTorch model on a validation or test dataset without updating the model’s parameters. Here’s a brief explanation:

Inputs:

  • model: The trained or partially trained neural network to evaluate.
  • loader: The data loader providing batches of input (x) and target (y) tensors.
  • loss_fn: The loss function to compute prediction errors (e.g., MSELoss, CrossEntropyLoss).
  • device: The device (CPU or GPU) where computations will take place.

Process:

  • Initialization: Sets epoch_loss to 0 for accumulating the loss over all batches.
  • Evaluation Mode: Puts the model in evaluation mode (model.eval()), disabling layers like dropout and batch normalization updates.
  • No Gradient Computation: Wraps the evaluation process in torch.no_grad() to save memory and computation by not calculating gradients.
  • Batch Loop:
    • Transfers inputs (x) and targets (y) to the specified device.
    • Passes inputs through the model to generate predictions (y_pred).
    • Computes the loss between predictions and targets using the loss function.
    • Accumulates the batch loss into epoch_loss.
  • Epoch Loss: Divides the total loss by the number of batches (len(loader)) to get the average loss for the epoch.

Output:

  • Returns the average loss for the epoch, which indicates how well the model performs on the given dataset.
def evaluate(model, loader, loss_fn, device):
    epoch_loss = 0.0

    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device, dtype=torch.float32)
            y = y.to(device, dtype=torch.float32)

            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            epoch_loss += loss.item()

        epoch_loss = epoch_loss / len(loader)
    return epoch_loss

Get Data Paths

# Seeding
seeding(42)

# Create files directory to store checkpoint file
create_dir("files")
path = "/content/drive/MyDrive/Colab Notebooks/Computer Vision/Retina"
#  Get data paths
train_x = sorted(glob(path+"/Data/train/image/*"))
train_y = sorted(glob(path+"/Data/train/mask/*"))

valid_x = sorted(glob(path+"/Data/test/image/*"))
valid_y = sorted(glob(path+"/Data/test/mask/*"))

data_str = f"Dataset Size:\nTrain: {len(train_x)} - Valid: {len(valid_x)}\n"
print(data_str)

Set Hyperparameters

H = 512
W = 512
size = (H, W)
batch_size = 2
lr = 1e-4
checkpoint_path = "files/retina_checkpoint.pth"

Load Dataset

train_dataset = DriveDataset(train_x, train_y)
valid_dataset = DriveDataset(valid_x, valid_y)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2
)

valid_loader = DataLoader(
    dataset=valid_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2
)

Instantiate U-Net

This code sets up the training environment for a U-Net model, including the device, model, optimizer, learning rate scheduler, and loss function.

# Set cuda device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build model
model = build_unet()
model = model.to(device)

# Set Optimizer and Loss
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5)
loss_fn = DiceBCELoss()

Load the trained model if it exists

try:
    # Load the model's state dictionary if a checkpoint exists
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    print("Loaded checkpoint:", checkpoint_path)
except Exception as e:
    print("No checkpoint loaded:", e)

Model Training

This code snippet implements a training loop for a PyTorch model with validation and checkpointing. Here’s a concise explanation:

Initialization:

  • best_valid_loss: Set to a high initial value (float("inf")) to track the best validation loss so far.
  • num_epochs: Number of epochs for training.
  • start_time: Captures the start time of the epoch for timing purposes.

Training and Validation:

  • Training Loss: Calls the train function to perform backpropagation and update the model parameters using the training dataset (train_loader).
  • Validation Loss: Calls the evaluate function to assess the model’s performance on the validation dataset (valid_loader).

Model Checkpointing:

  • If the validation loss improves (i.e., valid_loss < best_valid_loss):
    • Logs the improvement.
    • Saves the model’s state to the specified checkpoint_path using torch.save.
    • Updates best_valid_loss to the new valid_loss.

Epoch Timing:

  • Measures and formats the time taken for the epoch using the epoch_time function.

Logging Results:

  • Logs the epoch number, time taken, training loss, and validation loss in a formatted string.
best_valid_loss = float("inf")
num_epochs = 3

for epoch in range(num_epochs):
    start_time = time.time()

    train_loss = train(model, train_loader, optimizer, loss_fn, device)
    valid_loss = evaluate(model, valid_loader, loss_fn, device)
    scheduler.step(valid_loss)  # reduce the learning rate when validation loss plateaus

    # Saving the model
    if valid_loss < best_valid_loss:
        print(f"Valid loss improved from {best_valid_loss:2.4f} to {valid_loss:2.4f}. Saving checkpoint: {checkpoint_path}")

        best_valid_loss = valid_loss
        torch.save(model.state_dict(), checkpoint_path)

    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    data_str = f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s\n'
    data_str += f'\tTrain Loss: {train_loss:.3f}\n'
    data_str += f'\tVal. Loss: {valid_loss:.3f}\n'
    print(data_str)

Helper Functions for Testing

Function to Calculate Metrics

def calculate_metrics(y_true, y_pred):
    # Ground truth
    y_true = y_true.cpu().numpy()
    y_true = y_true > 0.5
    y_true = y_true.astype(np.uint8)
    y_true = y_true.reshape(-1)

    # Prediction
    y_pred = y_pred.cpu().numpy()
    y_pred = y_pred > 0.5
    y_pred = y_pred.astype(np.uint8)
    y_pred = y_pred.reshape(-1)

    score_jaccard = jaccard_score(y_true, y_pred)
    score_f1 = f1_score(y_true, y_pred)
    score_recall = recall_score(y_true, y_pred)
    score_precision = precision_score(y_true, y_pred)
    score_acc = accuracy_score(y_true, y_pred)

    return [score_jaccard, score_f1, score_recall, score_precision, score_acc]
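On a small pair of made-up binarized masks, the sklearn metrics behave as expected:

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, jaccard_score

y_true = np.array([1, 1, 0, 0, 1, 0], dtype=np.uint8)
y_pred = np.array([1, 0, 0, 0, 1, 1], dtype=np.uint8)

# TP=2, FP=1, FN=1, TN=2
print(jaccard_score(y_true, y_pred))   # 2 / (2+1+1) = 0.5
print(f1_score(y_true, y_pred))        # 2*2 / (2*2+1+1) ≈ 0.667
print(accuracy_score(y_true, y_pred))  # 4 / 6 ≈ 0.667
```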

Function to parse the mask

def mask_parse(mask):
    mask = np.expand_dims(mask, axis=-1)    ## (512, 512, 1)
    mask = np.concatenate([mask, mask, mask], axis=-1)  ## (512, 512, 3)
    return mask

Get Test Data

#  Seeding
seeding(42)

# Folders
create_dir("results")

# Load dataset
test_x = sorted(glob(path+"/Data/test/image/*"))
test_y = sorted(glob(path+"/Data/test/mask/*"))

Load Checkpoint

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = build_unet()
model = model.to(device)
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()

Test

This code evaluates a segmentation model on a test dataset and saves the results with metrics. Here’s a step-by-step explanation:

Initialization:

  • metrics_score: A list to accumulate metrics for evaluation (e.g., accuracy, precision, recall, F1 score, IoU).
  • time_taken: List to store inference time for each image, useful for FPS calculation.

Loop Over Test Dataset:

  • Iterates through test_x (image paths) and test_y (mask paths) using enumerate and tqdm for progress tracking.

Preprocessing:

  • Input Image:
    • Reads the image (cv2.imread).
    • Normalizes pixel values to [0, 1].
    • Transposes dimensions from (H, W, C) to (C, H, W) for PyTorch compatibility.
    • Expands to (1, C, H, W) to include batch dimension and converts to a torch.Tensor.
    • Moves the tensor to the computation device.
  • Ground Truth Mask:
    • Reads the mask as a grayscale image.
    • Normalizes pixel values to [0, 1].
    • Reshapes to (1, 1, H, W) to include channel and batch dimensions.
    • Converts to torch.Tensor and moves to the computation device.

Model Prediction:

  • Uses torch.no_grad() to disable gradient computation during inference.
  • Records the time taken for prediction to calculate FPS later.
  • Applies the model and computes the sigmoid activation to get pixel-wise probabilities.
  • Thresholds the prediction to binary output (> 0.5).

Metrics Calculation:

  • Calls calculate_metrics to evaluate the predictions against the ground truth mask.
  • Accumulates scores using map(add, metrics_score, score).

Saving Results:

  • Converts predicted and ground truth masks to visually interpretable formats (mask_parse).
  • Creates a concatenated image showing:
    • Original image.
    • Ground truth mask.
    • Predicted mask.
  • Saves the concatenated image using cv2.imwrite.
metrics_score = [0.0, 0.0, 0.0, 0.0, 0.0]
time_taken = []

for i, (x, y) in tqdm(enumerate(zip(test_x, test_y)), total=len(test_x)):
    # Extract the name
    name = x.split("/")[-1].split(".")[0]

    # Reading image
    image = cv2.imread(x, cv2.IMREAD_COLOR) ## (512, 512, 3)
    x = np.transpose(image, (2, 0, 1))      ## (3, 512, 512)
    x = x / 255.0
    x = np.expand_dims(x, axis=0)           ## (1, 3, 512, 512)
    x = x.astype(np.float32)
    x = torch.from_numpy(x)
    x = x.to(device)

    # Reading mask
    mask = cv2.imread(y, cv2.IMREAD_GRAYSCALE)  ## (512, 512)
    y = np.expand_dims(mask, axis=0)            ## (1, 512, 512)
    y = y / 255.0
    y = np.expand_dims(y, axis=0)               ## (1, 1, 512, 512)
    y = y.astype(np.float32)
    y = torch.from_numpy(y)
    y = y.to(device)

    with torch.no_grad():
        # Prediction and Calculating FPS
        start_time = time.time()
        pred_y = model(x)
        pred_y = torch.sigmoid(pred_y)
        total_time = time.time() - start_time
        time_taken.append(total_time)

        score = calculate_metrics(y, pred_y)
        metrics_score = list(map(add, metrics_score, score))
        pred_y = pred_y[0].cpu().numpy()        ## (1, 512, 512)
        pred_y = np.squeeze(pred_y, axis=0)     ## (512, 512)
        pred_y = pred_y > 0.5
        pred_y = np.array(pred_y, dtype=np.uint8)

    # Saving masks
    ori_mask = mask_parse(mask)
    pred_y = mask_parse(pred_y)
    line = np.ones((size[0], 10, 3), dtype=np.uint8) * 128  # grey separator column

    # Concatenate images side by side: original | ground truth | prediction
    cat_images = np.concatenate(
        [image, line, ori_mask, line, pred_y * 255], axis=1
    ).astype(np.uint8)

    # Add labels to the images
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1
    font_color = (255, 255, 255)
    thickness = 2
    line_type = cv2.LINE_AA

    # Add text to the original image
    cv2.putText(cat_images, "Original", (10, 50), font, font_scale, font_color, thickness, line_type)

    # Add text to the ground truth mask
    cv2.putText(cat_images, "Ground Truth", (image.shape[1] + 20, 50), font, font_scale, font_color, thickness, line_type)

    # Add text to the predicted mask
    cv2.putText(cat_images, "Prediction", (2 * image.shape[1] + 40, 50), font, font_scale, font_color, thickness, line_type)

    # Save the resulting image
    cv2.imwrite(f"results/{name}.png", cat_images)

Test Evaluation

This snippet calculates and prints evaluation metrics and the inference speed (frames per second, FPS) of the model.

jaccard = metrics_score[0] / len(test_x)
f1 = metrics_score[1] / len(test_x)
recall = metrics_score[2] / len(test_x)
precision = metrics_score[3] / len(test_x)
acc = metrics_score[4] / len(test_x)
print(f"Jaccard: {jaccard:1.4f} - F1: {f1:1.4f} - Recall: {recall:1.4f} - Precision: {precision:1.4f} - Acc: {acc:1.4f}")

fps = 1 / np.mean(time_taken)
print("FPS: ", fps)
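With made-up timing values standing in for the measurements collected above, the FPS computation looks like this:

```python
import numpy as np

# Made-up per-image inference times in seconds (the real values come from the test loop)
time_taken = [0.05, 0.04, 0.05, 0.06]
fps = 1 / np.mean(time_taken)
print(round(fps, 2))  # 20.0
```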

Show Result

import matplotlib.pyplot as plt

img = plt.imread('results/4.png')
plt.figure(figsize=(15, 8))
plt.imshow(img)
plt.axis('off')

plt.show()

Visualizing the Model

from torchview import draw_graph
# Instantiate the model
unet = build_unet()

# Define an input tensor
input_tensor = torch.randn(1, 3, 224, 224)  # Example input tensor with batch size 1 and 224x224 image

# Draw the graph
model_graph = draw_graph(unet, input_data=input_tensor, expand_nested=True)

# Render and save the graph using graphviz
model_graph.visual_graph.render(filename="unet_model_graph", format="png", cleanup=True)

model_graph.visual_graph

Alternative way

!pip3 install onnx
# Export model to ONNX
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(unet, dummy_input, path+"/unet.onnx", opset_version=11)

Open the ".onnx" file in Netron to inspect the exported model graph.