Low Light Image Enhancement Using Zero-DCE

Low-light image enhancement is a computer vision technique aimed at improving the visibility and quality of images captured in dim or poorly lit environments. This is a simple guide to enhancing low-light images using Zero-DCE: it begins with an introduction to the method and then walks you through its implementation in PyTorch.

Introduction

Zero-Reference Deep Curve Estimation (Zero-DCE) is a novel method designed for low-light image enhancement. A key distinguishing feature of Zero-DCE is its zero-reference learning strategy, meaning it does not require paired or unpaired reference images during the training process. This is made possible by training a lightweight deep network called DCE-Net using a set of carefully formulated non-reference loss functions that implicitly measure the quality of the enhanced output. The method reformulates light enhancement as an image-specific curve estimation task, where the DCE-Net estimates pixel-wise and high-order curves for dynamic range adjustment. This simple non-linear curve mapping is efficient and generalizes well to diverse lighting conditions.

The Network

The goal of this network is to predict the curve parameters r, which are used to enhance the input low-light image through an iterative formula:

 x_{i+1} = x_i + r * (x_i^2 - x_i)

Here x_i is the image at iteration i (with pixel values normalized to [0, 1]) and r is a per-pixel curve parameter in [-1, 1]. The formula adjusts pixel values to enhance brightness and contrast.
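As a quick numeric sketch (the values are illustrative, not taken from the paper), one application of the curve to a dark pixel looks like this:

# One application of x_{i+1} = x_i + r * (x_i^2 - x_i) to a dark pixel.
x = 0.2            # normalized pixel intensity in [0, 1]
r = -0.8           # curve parameter in [-1, 1]; negative values brighten
x_next = x + r * (x**2 - x)
print(round(x_next, 3))   # 0.328 -> brighter than the input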

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

import numpy as np

class enhance_net_nopool(nn.Module):

	def __init__(self):
		super(enhance_net_nopool, self).__init__()

		self.relu = nn.ReLU(inplace=True)

		number_f = 32  # number of feature channels in the intermediate layers
		self.e_conv1 = nn.Conv2d(3,number_f,3,1,1,bias=True)
		self.e_conv2 = nn.Conv2d(number_f,number_f,3,1,1,bias=True)
		self.e_conv3 = nn.Conv2d(number_f,number_f,3,1,1,bias=True)
		self.e_conv4 = nn.Conv2d(number_f,number_f,3,1,1,bias=True)
		self.e_conv5 = nn.Conv2d(number_f*2,number_f,3,1,1,bias=True)
		self.e_conv6 = nn.Conv2d(number_f*2,number_f,3,1,1,bias=True)
		self.e_conv7 = nn.Conv2d(number_f*2,24,3,1,1,bias=True)  # 24 channels = 8 iterations x 3 RGB curve maps

		# Defined here but not used in the forward pass below.
		self.maxpool = nn.MaxPool2d(2, stride=2, return_indices=False, ceil_mode=False)
		self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)



	def forward(self, x):

		x1 = self.relu(self.e_conv1(x))
		
		x2 = self.relu(self.e_conv2(x1))
		
		x3 = self.relu(self.e_conv3(x2))
		
		x4 = self.relu(self.e_conv4(x3))

		x5 = self.relu(self.e_conv5(torch.cat([x3,x4],1)))
	
		x6 = self.relu(self.e_conv6(torch.cat([x2,x5],1)))

		x_r = torch.tanh(self.e_conv7(torch.cat([x1,x6],1)))  # tanh keeps the curve parameters in [-1, 1]
		r1,r2,r3,r4,r5,r6,r7,r8 = torch.split(x_r, 3, dim=1)


		x = x + r1*(torch.pow(x,2)-x)
		x = x + r2*(torch.pow(x,2)-x)
		x = x + r3*(torch.pow(x,2)-x)
		enhance_image_1 = x + r4*(torch.pow(x,2)-x)
		x = enhance_image_1 + r5*(torch.pow(enhance_image_1,2)-enhance_image_1)
		x = x + r6*(torch.pow(x,2)-x)
		x = x + r7*(torch.pow(x,2)-x)
		enhance_image = x + r8*(torch.pow(x,2)-x)
		r = torch.cat([r1,r2,r3,r4,r5,r6,r7,r8],1)
		return enhance_image_1,enhance_image,r

Feature Extraction Layers

x1 = self.relu(self.e_conv1(x))
x2 = self.relu(self.e_conv2(x1))
x3 = self.relu(self.e_conv3(x2))
x4 = self.relu(self.e_conv4(x3))

Each of these layers extracts increasingly abstract features from the input image. They all keep the same spatial resolution without any downsampling.

Four convolutional layers are enough to extract meaningful features without being computationally heavy. Since this is not a classification task, no downsampling is done: each pixel must be adjusted independently and precisely.

Pooling or downsampling would lose fine details and could introduce artifacts or make the enhancement coarse.

The curve parameters ‘r’ are spatially variant: each pixel gets its own ‘r’ value. So, preserving full resolution is critical.
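As a small standalone check (not part of the network itself), a 3x3 convolution with stride 1 and padding 1, the same settings used by e_conv1, preserves the spatial size, so every pixel keeps its own position in the feature maps:

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)  # same settings as e_conv1
dummy = torch.randn(1, 3, 256, 256)                          # a dummy low-light image batch
print(conv(dummy).shape)   # torch.Size([1, 32, 256, 256]) -> H and W are unchanged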

Feature Fusion

x5 = self.relu(self.e_conv5(torch.cat([x3,x4],1)))
x6 = self.relu(self.e_conv6(torch.cat([x2,x5],1)))

The network concatenates intermediate feature maps to allow information flow across different depth levels. This improves gradient flow and feature reuse.
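A quick sketch of what the concatenation produces: joining two 32-channel maps along the channel dimension gives the 64-channel (number_f*2) input that e_conv5 and e_conv6 expect:

import torch

x3 = torch.randn(1, 32, 256, 256)    # dummy feature maps with number_f = 32 channels each
x4 = torch.randn(1, 32, 256, 256)
fused = torch.cat([x3, x4], dim=1)   # skip-style fusion along the channel axis
print(fused.shape)                   # torch.Size([1, 64, 256, 256])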

Curve Estimation

x_r = torch.tanh(self.e_conv7(torch.cat([x1,x6],1)))

This layer outputs the curve parameters, which are split into 8 sets of 3 channels each (one 3-channel tensor per enhancement iteration).

Each 3-channel group represents a set of enhancement curves to be applied to RGB channels.

Iterative Enhancement Using Curve Parameters

r1,r2,r3,r4,r5,r6,r7,r8 = torch.split(x_r, 3, dim=1)

Each ‘r’ tensor has shape [B,3,H,W].
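A small sketch of the split on a dummy tensor (the 24-channel output of e_conv7 yields eight 3-channel curve maps):

import torch

x_r = torch.randn(1, 24, 256, 256)    # dummy output of e_conv7 (after tanh)
curves = torch.split(x_r, 3, dim=1)   # eight tensors, one per enhancement iteration
print(len(curves), curves[0].shape)   # 8 torch.Size([1, 3, 256, 256])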

The enhancement function is applied 8 times.

x = x + r1*(torch.pow(x,2)-x)
x = x + r2*(torch.pow(x,2)-x)
x = x + r3*(torch.pow(x,2)-x)
enhance_image_1 = x + r4*(torch.pow(x,2)-x)
x = enhance_image_1 + r5*(torch.pow(enhance_image_1,2)-enhance_image_1)
x = x + r6*(torch.pow(x,2)-x)
x = x + r7*(torch.pow(x,2)-x)
enhance_image = x + r8*(torch.pow(x,2)-x)
r = torch.cat([r1,r2,r3,r4,r5,r6,r7,r8],1)
return enhance_image_1,enhance_image,r

This is a differentiable curve-based enhancement function, where:

  • x^2 - x: stretches contrast for low pixel values (it is negative for x in (0, 1))
  • r: controls how aggressively, and in which direction, this adjustment is applied.
  • The iterative process refines the image step by step.

enhance_image_1 is a partially enhanced image which is used for intermediate loss supervision.

enhance_image is the final enhanced image, while r contains the learned enhancement curves used for loss computation and analysis.
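A small sketch, with illustrative values, of how the eight iterations compound the adjustment on a single dark pixel (a constant r is used here only for simplicity; the network predicts a different map per step):

import torch

x = torch.tensor(0.2)        # a dark, normalized pixel value
r = torch.tensor(-0.5)       # a moderate curve parameter, held constant for illustration
for step in range(8):
    x = x + r * (x**2 - x)   # same update the forward pass applies with r1..r8
    print(step + 1, round(x.item(), 3))   # the value rises step by step (0.28, 0.381, ...)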

The Loss Functions

In the Zero-Reference Deep Curve Estimation (Zero-DCE) method, a key feature is its zero-reference learning strategy, which means the network is trained without requiring any paired or unpaired reference images. To enable this, the DCE-Net is trained using a set of carefully formulated differentiable non-reference loss functions. These losses implicitly measure the enhancement quality of the output image and guide the learning process of the network.

Spatial Consistency Loss (Lspa):

This loss aims to encourage spatial coherence in the enhanced image. It works by preserving the difference of neighboring regions between the input image and its enhanced version. This is important for maintaining the local contrast of the image.

class L_spa(nn.Module):
    def __init__(self):
        super(L_spa, self).__init__()
        self.weight_left = nn.Parameter(torch.tensor([[0, 0, 0], [-1, 1, 0], [0, 0, 0]]).float().cuda().unsqueeze(0).unsqueeze(0), requires_grad=False)
        self.weight_right = nn.Parameter(torch.tensor([[0, 0, 0], [0, 1, -1], [0, 0, 0]]).float().cuda().unsqueeze(0).unsqueeze(0), requires_grad=False)
        self.weight_up = nn.Parameter(torch.tensor([[0, -1, 0], [0, 1, 0], [0, 0, 0]]).float().cuda().unsqueeze(0).unsqueeze(0), requires_grad=False)
        self.weight_down = nn.Parameter(torch.tensor([[0, 0, 0], [0, 1, 0], [0, -1, 0]]).float().cuda().unsqueeze(0).unsqueeze(0), requires_grad=False)
        self.pool = nn.AvgPool2d(4)

    def forward(self, org, enhance):
        org_mean = torch.mean(org, 1, keepdim=True)
        enhance_mean = torch.mean(enhance, 1, keepdim=True)
        org_pool = self.pool(org_mean)
        enhance_pool = self.pool(enhance_mean)

        D_left = (F.conv2d(org_pool, self.weight_left, padding=1) - F.conv2d(enhance_pool, self.weight_left, padding=1)) ** 2
        D_right = (F.conv2d(org_pool, self.weight_right, padding=1) - F.conv2d(enhance_pool, self.weight_right, padding=1)) ** 2
        D_up = (F.conv2d(org_pool, self.weight_up, padding=1) - F.conv2d(enhance_pool, self.weight_up, padding=1)) ** 2
        D_down = (F.conv2d(org_pool, self.weight_down, padding=1) - F.conv2d(enhance_pool, self.weight_down, padding=1)) ** 2

        return D_left + D_right + D_up + D_down

Exposure Control Loss (Lexp):

Designed to prevent under- or over-exposed regions, this loss controls the overall exposure level of the enhanced image. It measures the distance between the average intensity value of a local region in the enhanced image and a predefined well-exposedness level (set to 0.6 in their experiments). Removing this loss can lead to a failure to recover low-light regions.

class L_exp(nn.Module):
    def __init__(self, patch_size, mean_val):
        super(L_exp, self).__init__()
        self.pool = nn.AvgPool2d(patch_size)
        self.mean_val = mean_val

    def forward(self, x):
        mean = self.pool(torch.mean(x, 1, keepdim=True))
        return torch.mean((mean - self.mean_val) ** 2)

Color Constancy Loss (Lcol):

This loss is included to correct potential color deviations in the enhanced image and build relationships among the three adjusted RGB channels. It is inspired by the Gray-World color constancy hypothesis, which assumes that the average color in each sensor channel should be gray over the entire image. Discarding this loss results in severe color casts.

class L_color(nn.Module):
    def __init__(self):
        super(L_color, self).__init__()

    def forward(self, x):
        mean_rgb = torch.mean(x, [2, 3], keepdim=True)
        mr, mg, mb = torch.split(mean_rgb, 1, dim=1)
        Drg = (mr - mg) ** 2
        Drb = (mr - mb) ** 2
        Dgb = (mb - mg) ** 2
        k = torch.sqrt(Drg ** 2 + Drb ** 2 + Dgb ** 2)
        return k

Illumination Smoothness Loss (LtvA):

To preserve the monotonic relationships between neighboring pixels and avoid artifacts, this loss adds a smoothness constraint to each curve parameter map estimated by the network. Removing this loss hampers correlations between neighboring regions and leads to obvious artifacts.

class L_TV(nn.Module):
    def __init__(self, TVLoss_weight=1):
        super(L_TV, self).__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self, x):
        h_x, w_x = x.size(2), x.size(3)
        count_h = (h_x - 1) * w_x
        count_w = h_x * (w_x - 1)
        h_tv = torch.pow(x[:, :, 1:, :] - x[:, :, :-1, :], 2).sum()
        w_tv = torch.pow(x[:, :, :, 1:] - x[:, :, :, :-1], 2).sum()
        return self.TVLoss_weight * 2 * (h_tv / count_h + w_tv / count_w) / x.size(0)
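Before training, it can help to sanity-check the four losses on random tensors, using the classes defined above (a quick sketch; note that the kernels inside L_spa are created with .cuda(), so this assumes a GPU is available):

import torch

img = torch.rand(2, 3, 256, 256).cuda()   # fake low-light batch
out = torch.rand(2, 3, 256, 256).cuda()   # fake "enhanced" batch
A = torch.rand(2, 24, 256, 256).cuda()    # fake curve parameter maps (8 x 3 channels)

spa, exp, col, tv = L_spa().cuda(), L_exp(16, 0.6).cuda(), L_color().cuda(), L_TV().cuda()

print(torch.mean(spa(img, out)).item())   # spatial consistency
print(exp(out).item())                    # exposure control
print(torch.mean(col(out)).item())        # color constancy
print(tv(A).item())                       # illumination smoothness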



Dataset Class

This class loads the images from the given directory and preprocesses them into normalized tensors.

import glob
import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

class LowlightDataset(Dataset):
    def __init__(self, image_dir, size=256):
        # Collect all PNG images from the given directory.
        self.image_paths = sorted(glob.glob(os.path.join(image_dir, "*.png")))
        self.size = size

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        img = Image.open(self.image_paths[index]).convert("RGB")
        img = img.resize((self.size, self.size), Image.Resampling.LANCZOS)
        img = np.asarray(img) / 255.0
        img_tensor = torch.from_numpy(img).permute(2, 0, 1).float()
        return img_tensor

Training

The training loop for Zero-DCE in PyTorch begins by initializing the enhance_net_nopool network and loading pretrained weights if specified. The model is moved to the GPU and set to training mode. A custom LowlightDataset is used to load low-light images with the PyTorch DataLoader. Several loss functions are defined, including color loss, spatial consistency loss, exposure control loss, and total variation loss, each weighted appropriately. An Adam optimizer is initialized with a learning rate and weight decay from the configuration. During each epoch, batches of low-light images are passed through the network to produce enhanced outputs. The total loss is calculated by combining the individual losses, backpropagated, and used to update the model parameters. Gradients are clipped to stabilize training. Periodically, training progress is printed, and model checkpoints are saved at specified intervals.
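The loop below reads its hyperparameters from a config object that is not defined in this guide. A minimal stand-in could look like the following sketch; the attribute names simply mirror the ones the loop uses, and the values are illustrative, so adjust paths and settings for your setup:

from types import SimpleNamespace

# Hypothetical configuration object; values are examples only.
config = SimpleNamespace(
    lowlight_images_path="data/train_data/",        # folder of low-light training .png images
    lr=1e-4,
    weight_decay=1e-4,
    grad_clip_norm=0.1,
    num_epochs=100,
    train_batch_size=8,
    num_workers=4,
    display_iter=10,
    snapshot_iter=10,
    snapshots_folder="snapshots/",
    load_pretrain=False,
    pretrain_dir="snapshots/Epoch100_Iter100.pth",
)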



import os

import torch
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam
from torch.utils.data import DataLoader

DCE_net = enhance_net_nopool().cuda()
if config.load_pretrain:
    DCE_net.load_state_dict(torch.load(config.pretrain_dir))

DCE_net.train()
train_dataset = LowlightDataset(config.lowlight_images_path)
train_loader = DataLoader(train_dataset, batch_size=config.train_batch_size, shuffle=True, num_workers=config.num_workers)

L_color_loss = L_color().cuda()
L_spa_loss = L_spa().cuda()
L_exp_loss = L_exp(16, 0.6).cuda()
L_TV_loss = L_TV().cuda()


optimizer = Adam(DCE_net.parameters(), lr=config.lr, weight_decay=config.weight_decay)

for epoch in range(config.num_epochs):
    for iteration, img_lowlight in enumerate(train_loader):
        img_lowlight = img_lowlight.cuda()
        enhanced_1, enhanced_img, A = DCE_net(img_lowlight)

        loss_tv = 200 * L_TV_loss(A)
        loss_spa = torch.mean(L_spa_loss(img_lowlight, enhanced_img))
        loss_col = 5 * torch.mean(L_color_loss(enhanced_img))
        loss_exp = 10 * L_exp_loss(enhanced_img)

        total_loss = loss_tv + loss_spa + loss_col + loss_exp

        optimizer.zero_grad()
        total_loss.backward()
        clip_grad_norm_(DCE_net.parameters(), config.grad_clip_norm)
        optimizer.step()

        if (iteration + 1) % config.display_iter == 0:
            print(f"Epoch [{epoch+1}/{config.num_epochs}], Step [{iteration+1}], Loss: {total_loss.item():.4f}")

        if (iteration + 1) % config.snapshot_iter == 0:
            torch.save(DCE_net.state_dict(), os.path.join(config.snapshots_folder, f"Epoch{epoch+1}_Iter{iteration+1}.pth"))

Evaluation

The code first determines whether a GPU is available and sets the computation device accordingly. It then initializes the enhance_net_nopool model and moves it to the selected device (GPU or CPU). The trained model weights are loaded from a saved checkpoint file (Epoch100_Iter100.pth), with map_location ensuring compatibility with the selected device. Finally, the model is set to evaluation mode using model.eval(), which switches layers such as dropout and batch normalization to their inference behavior. This prepares the model for enhancing new low-light images.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = enhance_net_nopool().to(device)
model.load_state_dict(torch.load('/content/drive/MyDrive/Colab Notebooks/Computer Vision/Zero_DCE_Models/Epoch100_Iter100.pth', map_location=device))
model.eval()

Testing

A test image is loaded and processed before being passed through the trained Zero-DCE model in evaluation mode. The output is the enhanced version of the input image. Finally, both the original and enhanced images are converted to NumPy arrays and displayed side by side using Matplotlib for visual comparison.

from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load and preprocess test image
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor()
])

img = Image.open('/content/cup.jpg').convert('RGB')  # Replace with your file path
input_img = transform(img).unsqueeze(0).to(device)  # Shape: [1, 3, H, W]

# Forward pass through the model
with torch.no_grad():
    enhance_image_1, enhance_image, r = model(input_img)

output_img = enhance_image  # final enhanced image from the model

# Convert tensors to numpy images for display
original_np = input_img.squeeze().permute(1, 2, 0).cpu().numpy()
enhanced_np = output_img.squeeze().permute(1, 2, 0).cpu().numpy().clip(0, 1)  # clip to the displayable [0, 1] range

# Plot side by side
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

axes[0].imshow(original_np)
axes[0].set_title('Original Image')
axes[0].axis('off')

axes[1].imshow(enhanced_np)
axes[1].set_title('Enhanced Image')
axes[1].axis('off')

plt.show()
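To keep the result rather than only display it, the enhanced tensor can also be written back to disk, for example with torchvision (a short sketch; the output path is just an example):

from torchvision.utils import save_image

# Clamp to the valid range and write the enhanced image to an example path.
save_image(output_img.clamp(0, 1), '/content/cup_enhanced.png')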

Output

[Figure: side-by-side comparison of the original low-light image and the Zero-DCE enhanced result]
