Image Super-Resolution with ESRGAN

Enhance image resolution by 4-8x using Enhanced Super-Resolution Generative Adversarial Networks. Transform low-resolution images into high-quality outputs with realistic details.

Project Overview

This project implements Single Image Super-Resolution (SISR) using the ESRGAN (Enhanced Super-Resolution GAN) architecture. Given a low-resolution image, the model generates a high-resolution version with enhanced textures and details that appear photo-realistic. Unlike traditional upscaling methods such as bicubic interpolation, which produce blurry results, ESRGAN leverages deep learning and perceptual loss functions to create sharp, natural-looking images.
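For contrast, the traditional bicubic baseline takes one line with OpenCV (the file names below are illustrative); the ESRGAN output later in this guide should look markedly sharper than this:

import cv2

# Hypothetical input path; replace with your own low-resolution image
img = cv2.imread('input.png', cv2.IMREAD_COLOR)

# Classic 4x bicubic upscale: smooth, but typically blurry
upscaled = cv2.resize(img, None, fx=4, fy=4, interpolation=cv2.INTER_CUBIC)
cv2.imwrite('bicubic_x4.png', upscaled)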

- 4-8x Upscaling: increase image resolution by up to 8 times while maintaining quality.
- Detail Enhancement: generate realistic textures and fine details that were not present in the original.
- GPU Accelerated: fast inference using CUDA acceleration on Google Colab.

What is SRGAN?

SRGAN (Super-Resolution Generative Adversarial Network) was introduced in the landmark 2017 paper "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network" by Ledig et al. It revolutionized image upscaling by using adversarial training to generate photo-realistic textures instead of just minimizing pixel-wise reconstruction error.

Key Innovation: Perceptual Loss

Traditional super-resolution methods optimize for pixel-by-pixel accuracy using MSE (Mean Squared Error), which leads to overly smooth, blurry outputs. SRGAN introduces a perceptual loss function that compares high-level features extracted from a pre-trained VGG network. This encourages the generator to produce images that match the perceptual characteristics of real high-resolution images, resulting in sharper edges and more realistic textures.
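As an illustration, a minimal VGG19 perceptual loss can be sketched in PyTorch as follows. This is a simplified version, not the official implementation: the truncation point and the MSE criterion follow common practice, input normalization is omitted for brevity, and truncating just before the last ReLU mirrors ESRGAN's use of features taken before the activation (described in the next section).

import torch.nn as nn
import torchvision.models as models

class VGGPerceptualLoss(nn.Module):
    """Compares images in VGG19 feature space instead of pixel space."""
    def __init__(self):
        super().__init__()
        # torchvision >= 0.13 weights API; features[:35] ends at conv5_4,
        # i.e. before its activation, as in ESRGAN
        vgg = models.vgg19(weights='IMAGENET1K_V1')
        self.features = nn.Sequential(*list(vgg.features)[:35]).eval()
        for p in self.features.parameters():
            p.requires_grad = False  # VGG stays frozen
        self.criterion = nn.MSELoss()

    def forward(self, sr, hr):
        # Distance between feature maps of generated and real images
        return self.criterion(self.features(sr), self.features(hr))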

ESRGAN: Enhanced SRGAN

ESRGAN (Enhanced SRGAN), published in 2018 by Wang et al., improves upon the original SRGAN architecture with three key enhancements:

  1. Network architecture: the basic residual block is replaced by the Residual-in-Residual Dense Block (RRDB), and all batch normalization layers are removed, which avoids BN-induced artifacts and supports deeper networks.
  2. Discriminator: the standard discriminator is replaced by a relativistic average discriminator (RaGAN), which predicts whether a real image is relatively more realistic than a generated one.
  3. Perceptual loss: VGG features are compared before the activation function rather than after, giving denser supervision and more consistent brightness.
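To make the first enhancement concrete, here is a minimal, simplified RRDB sketch showing the key ideas: dense connections, no batch normalization, and residual scaling. The official RRDBNet_arch.py in the ESRGAN repo follows the same structure but uses five convolutions per dense block:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualDenseBlock(nn.Module):
    """Simplified dense block (3 convs here; the official code uses 5).
    Each conv sees all previous feature maps; no batch normalization."""
    def __init__(self, nf=64, gc=32):
        super().__init__()
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1)
        self.conv3 = nn.Conv2d(nf + 2 * gc, nf, 3, 1, 1)

    def forward(self, x):
        x1 = F.leaky_relu(self.conv1(x), 0.2)
        x2 = F.leaky_relu(self.conv2(torch.cat((x, x1), 1)), 0.2)
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        return x + 0.2 * x3  # residual scaling stabilizes training

class RRDB(nn.Module):
    """Residual-in-Residual Dense Block: three dense blocks wrapped in
    an outer residual connection, also scaled by 0.2."""
    def __init__(self, nf=64, gc=32):
        super().__init__()
        self.rdb1 = ResidualDenseBlock(nf, gc)
        self.rdb2 = ResidualDenseBlock(nf, gc)
        self.rdb3 = ResidualDenseBlock(nf, gc)

    def forward(self, x):
        out = self.rdb3(self.rdb2(self.rdb1(x)))
        return x + 0.2 * out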

Architecture Overview

ESRGAN follows the classic GAN paradigm with a Generator network that creates high-resolution images and a Discriminator that tries to distinguish them from real images:

Low-Res Image --> Generator (RRDB) --> Generated HR Image

Real HR Image ----+
                  +--> Discriminator --> Real or Fake?
Generated HR  ----+

Generator Network (RRDB-Net)

The generator uses Residual-in-Residual Dense Blocks (RRDB), which combine residual learning with dense connections to enable deeper networks without vanishing gradients:

import torch
import torch.nn as nn
import torch.nn.functional as F

# RRDB and make_layer are helper definitions from the official ESRGAN
# repo (RRDBNet_arch.py); make_layer is assumed here to forward the
# nf/gc keyword arguments to each RRDB block.

class RRDBNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, gc=32):
        super(RRDBNet, self).__init__()

        # First convolution extracts shallow features
        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)

        # Trunk of RRDB blocks (typically nb = 23)
        self.RRDB_trunk = make_layer(RRDB, nb, nf=nf, gc=gc)

        # Trunk convolution
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        # Upsampling layers (4x = two 2x upsample blocks)
        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        # Final high-resolution convolutions
        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        fea = self.conv_first(x)
        trunk = self.trunk_conv(self.RRDB_trunk(fea))
        fea = fea + trunk  # Global residual connection

        # Upsample 2x (nearest-neighbor interpolation + convolution)
        fea = self.lrelu(self.upconv1(F.interpolate(
            fea, scale_factor=2, mode='nearest')))
        # Upsample 2x again (total 4x)
        fea = self.lrelu(self.upconv2(F.interpolate(
            fea, scale_factor=2, mode='nearest')))

        out = self.conv_last(self.lrelu(self.HRconv(fea)))
        return out
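A quick shape check (a hypothetical smoke test, assuming RRDB and make_layer are importable) confirms that the two nearest-neighbor upsampling stages produce a 4x larger output:

x = torch.randn(1, 3, 64, 64)  # dummy low-res input
net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23)
with torch.no_grad():
    y = net(x)
print(y.shape)  # expected: torch.Size([1, 3, 256, 256])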

Discriminator Network

The discriminator is a VGG-style convolutional network that classifies images as real (original high-resolution) or fake (generated by the generator). In ESRGAN it is trained in relativistic average form (RaGAN), estimating whether a real image looks relatively more realistic than a generated one rather than making an absolute real/fake call:

class VGGStyleDiscriminator(nn.Module):
    def __init__(self, input_channels=3):
        super(VGGStyleDiscriminator, self).__init__()

        # Conv blocks with increasing channels: 64 → 128 → 256 → 512.
        # Each stage pairs a stride-1 conv with a stride-2 conv that
        # halves the spatial resolution.
        self.features = nn.Sequential(
            nn.Conv2d(input_channels, 64, 3, 1, 1),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(64, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(64, 128, 3, 1, 1),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(128, 128, 4, 2, 1),
            nn.LeakyReLU(0.2, True),

            # ... more layers up to 512 channels
        )

        # Fully connected layers for binary classification.
        # The 512 * 7 * 7 size assumes a fixed input patch size; it must
        # match the flattened feature map produced by the layers above.
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, 1)
        )

    def forward(self, x):
        features = self.features(x)
        features = features.view(features.size(0), -1)  # flatten to (N, C*H*W)
        return self.classifier(features)
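The relativistic average losses from the ESRGAN paper can be written compactly in PyTorch. This is a minimal sketch; the function and variable names are illustrative, not from the official code:

import torch
import torch.nn.functional as F

def ragan_losses(d_real, d_fake):
    """Relativistic average GAN losses (Wang et al. 2018).
    d_real, d_fake: raw discriminator logits for real / generated batches."""
    # D_Ra(x_r, x_f) = sigmoid(C(x_r) - E[C(x_f)]), and vice versa
    real_rel = d_real - d_fake.mean()
    fake_rel = d_fake - d_real.mean()

    # Discriminator: real images should score above the average fake
    d_loss = 0.5 * (
        F.binary_cross_entropy_with_logits(real_rel, torch.ones_like(real_rel))
        + F.binary_cross_entropy_with_logits(fake_rel, torch.zeros_like(fake_rel)))

    # Generator: the symmetric objective with the labels swapped
    g_loss = 0.5 * (
        F.binary_cross_entropy_with_logits(real_rel, torch.zeros_like(real_rel))
        + F.binary_cross_entropy_with_logits(fake_rel, torch.ones_like(fake_rel)))
    return d_loss, g_loss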

Loss Functions

ESRGAN combines multiple loss functions to achieve both pixel-accurate reconstruction and perceptual realism:

| Loss Type | Mathematical Form | Purpose |
|---|---|---|
| Perceptual Loss | \(L_{percep} = \frac{1}{W_i H_i} \sum \|VGG(I^{HR}) - VGG(G(I^{LR}))\|^2\) | Matches feature representations from a pre-trained VGG19 network |
| Adversarial Loss | \(L_{G} = -\mathbb{E}_{x}[\log(D(G(x)))]\) | Encourages the generator to fool the discriminator |
| L1 Loss | \(L_{1} = \|I^{HR} - G(I^{LR})\|_1\) | Ensures pixel-level similarity to the ground truth |
| Total Loss | \(L_{total} = L_{percep} + \lambda L_{G} + \eta L_{1}\) | Weighted combination of all losses |

Loss Function Weighting

The perceptual loss typically dominates with \(\lambda = 5 \times 10^{-3}\) for adversarial loss and \(\eta = 1 \times 10^{-2}\) for L1 loss. These weights balance realism (perceptual), fooling the discriminator (adversarial), and pixel accuracy (L1). Fine-tuning these hyperparameters is crucial for optimal results.
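Putting it together, here is a minimal sketch of the weighted generator objective, assuming the perceptual and adversarial helpers sketched earlier:

import torch.nn.functional as F

LAMBDA_ADV = 5e-3  # adversarial weight from the ESRGAN paper
ETA_L1 = 1e-2      # L1 weight from the ESRGAN paper

def total_generator_loss(sr, hr, adv_loss, perceptual_loss_fn):
    """sr: generated batch, hr: ground-truth batch,
    adv_loss: precomputed adversarial term (e.g. from ragan_losses),
    perceptual_loss_fn: a VGG feature-space loss such as VGGPerceptualLoss."""
    l_percep = perceptual_loss_fn(sr, hr)
    l_1 = F.l1_loss(sr, hr)
    return l_percep + LAMBDA_ADV * adv_loss + ETA_L1 * l_1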

Implementation Steps

  1. Setup Environment

    Install PyTorch and clone the ESRGAN repository with pre-trained models:

    # Install dependencies
    !pip install torch torchvision opencv-python
    
    # Clone ESRGAN repo
    !git clone https://github.com/xinntao/ESRGAN
    %cd ESRGAN
  2. Download Pre-trained Weights

    Load the RRDB_ESRGAN_x4 model trained on the DIV2K dataset:

    import torch
    
    # Download pre-trained model
    !wget https://github.com/xinntao/ESRGAN/releases/download/v0.0/RRDB_ESRGAN_x4.pth
    
    # Load model architecture
    from RRDBNet_arch import RRDBNet
    
    model = RRDBNet(3, 3, 64, 23, gc=32)
    model.load_state_dict(torch.load('RRDB_ESRGAN_x4.pth'), strict=True)
    model.eval()
    model = model.cuda()  # Move to GPU
  3. Prepare Input Image

    Upload and preprocess your low-resolution image:

    import cv2
    import numpy as np
    from google.colab import files
    
    # Upload image
    uploaded = files.upload()
    img_path = list(uploaded.keys())[0]
    
    # Read and normalize to [0, 1]
    img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    img = img.astype(np.float32) / 255.0
    
    # Convert BGR to RGB and transpose to CHW format
    img = torch.from_numpy(
        np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))
    ).float()
    
    # Add batch dimension and move to GPU
    img_LR = img.unsqueeze(0).cuda()
  4. Run Super-Resolution

    Generate the high-resolution output using the trained model:

    with torch.no_grad():
        output = model(img_LR).data.squeeze().float().cpu()
        output = output.clamp_(0, 1).numpy()
    
    # Convert back to HWC format and BGR color space
    output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
    output = (output * 255.0).round().astype(np.uint8)
    
    # Save and download result
    cv2.imwrite('output_ESRGAN.png', output)
    files.download('output_ESRGAN.png')
  5. Compare Results

    Visualize the original low-resolution and upscaled high-resolution images side by side:

    import matplotlib.pyplot as plt
    
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    
    # Original low-resolution
    axes[0].imshow(cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB))
    axes[0].set_title('Low Resolution')
    axes[0].axis('off')
    
    # Super-resolved high-resolution
    axes[1].imshow(cv2.cvtColor(output, cv2.COLOR_BGR2RGB))
    axes[1].set_title('ESRGAN 4x Upscaled')
    axes[1].axis('off')
    
    plt.tight_layout()
    plt.show()

Performance Evaluation

ESRGAN produces significantly sharper and more realistic results compared to traditional interpolation methods and even the original SRGAN:

| Method | PSNR (dB) | SSIM | Perceptual Quality |
|---|---|---|---|
| Bicubic Interpolation | 28.42 | 0.8104 | Blurry, lacks detail |
| SRGAN | 26.02 | 0.7397 | Sharp but with artifacts |
| ESRGAN | 26.68 | 0.7720 | Best visual quality |

Understanding the Metrics

PSNR (Peak Signal-to-Noise Ratio) measures pixel-wise fidelity in decibels, while SSIM (Structural Similarity Index) scores structural agreement between 0 and 1. Note that ESRGAN scores lower than bicubic interpolation on both metrics yet produces far more realistic textures that human observers prefer; this perception-distortion trade-off is the key insight behind perceptual loss functions.
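To compute both metrics yourself, scikit-image provides ready-made implementations. A sketch, assuming img_hr and img_sr are same-sized uint8 images and using illustrative file names:

import cv2
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

# Hypothetical file names for a ground-truth / super-resolved pair
img_hr = cv2.imread('ground_truth.png')
img_sr = cv2.imread('output_ESRGAN.png')

psnr = peak_signal_noise_ratio(img_hr, img_sr, data_range=255)
# channel_axis requires scikit-image >= 0.19
ssim = structural_similarity(img_hr, img_sr, channel_axis=2, data_range=255)
print(f'PSNR: {psnr:.2f} dB, SSIM: {ssim:.4f}')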

Applications

ESRGAN-based super-resolution has numerous real-world applications:

- Photo Restoration: enhance old or low-quality photographs.
- Video Upscaling: convert SD content to HD/4K quality.
- Medical Imaging: improve the resolution of MRI and CT scans.
- Satellite Imagery: enhance the quality of remote-sensing data.

Run the Complete Implementation

Experience ESRGAN super-resolution with GPU acceleration on Google Colab. Upload your own images and watch the transformation in real time.

References & Further Reading

- C. Ledig et al., "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network," CVPR 2017. arXiv:1609.04802
- X. Wang et al., "ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks," ECCV Workshops 2018. arXiv:1809.00219
- Official ESRGAN implementation: https://github.com/xinntao/ESRGAN