# SRGAN (Super-Resolution Generative Adversarial Network)
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
import numpy as np
import cv2
import matplotlib.pyplot as plt
# NOTE(review): fragment of a generator class. The enclosing `class`, `__init__`
# and `forward` headers, the earlier layers that produce `out`, and the definition
# of `self.conv_hr` are not part of this snippet; pasted at module level as-is it
# is not runnable.
# Upsampling layers: two sub-pixel stages, each doubling H and W (4x overall).
# This variant uses LeakyReLU(0.2), whereas UpsampleBlock elsewhere in this file uses PReLU.
self.upsample1 = nn.Sequential(
# Expand 64 -> 256 channels so PixelShuffle(2) can fold 4 channels into each
# 2x2 output patch, returning to 64 channels at twice the resolution.
nn.Conv2d(64, 256, 3, 1, 1),
nn.PixelShuffle(2), # Upsample by factor of 2
nn.LeakyReLU(0.2, inplace=True)
)
# Second stage: same shape transform; its input must again have 64 channels.
self.upsample2 = nn.Sequential(
nn.Conv2d(64, 256, 3, 1, 1),
nn.PixelShuffle(2), # Upsample by factor of 2
nn.LeakyReLU(0.2, inplace=True)
)
# Forward-pass tail: apply both upsampling stages, then the final conv.
# Assumes `out` is a 64-channel feature map at this point (required by Conv2d(64, ...)).
out = self.upsample1(out)
out = self.upsample2(out)
out = self.conv_hr(out)
return out
# Super-resolve the image; inference only, so gradient tracking is disabled.
# NOTE(review): `model`, `img` and `lr_img` are defined outside this snippet.
with torch.no_grad():
    sr_img = model(img)

plt.figure(figsize=(12, 6))

# Left panel: the low-resolution input.
plt.subplot(1, 2, 1)
plt.title('Low Resolution')
plt.imshow(lr_img)

# Right panel: the super-resolved output. Plot the name the result was actually
# bound to above (`sr_img`); `sr_image` was undefined and raised NameError.
# NOTE(review): if `model` returns an (N, C, H, W) tensor, it likely needs
# squeeze/permute/.cpu().numpy() into an (H, W, C) array before imshow - confirm.
plt.subplot(1, 2, 2)
plt.title('Super-Resolved')
plt.imshow(sr_img)
plt.show()
import torch
from torch import nn
class ConvBlock(nn.Module):
    """Conv2d -> (BatchNorm2d | Identity) -> optional activation.

    Args:
        in_channels: Input channels of the convolution.
        out_channels: Output channels of the convolution.
        discrininator: If True, activate with LeakyReLU(0.2); otherwise with a
            per-channel PReLU. The misspelled name is kept for backward
            compatibility; see ``discriminator``.
        use_act: Whether to apply the activation after normalization.
        use_bn: Whether to apply BatchNorm2d. When True, the convolution is
            built without a bias.
        discriminator: Keyword-only, correctly spelled alias for
            ``discrininator``; either flag selects LeakyReLU.
        **kwargs: Forwarded to ``nn.Conv2d`` (e.g. kernel_size, stride, padding).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        discrininator=False,
        use_act=True,
        use_bn=True,
        *,
        discriminator=False,
        **kwargs,
    ):
        super().__init__()
        # forward() reads this flag; it must be stored on the instance.
        self.use_act = use_act
        # BatchNorm has its own learnable shift, making a conv bias redundant.
        self.cnn = nn.Conv2d(in_channels, out_channels, **kwargs, bias=not use_bn)
        self.bn = nn.BatchNorm2d(out_channels) if use_bn else nn.Identity()
        self.act = (
            nn.LeakyReLU(0.2, inplace=True)
            if (discrininator or discriminator)
            else nn.PReLU(num_parameters=out_channels)
        )

    def forward(self, x):
        out = self.bn(self.cnn(x))
        return self.act(out) if self.use_act else out


# Backward-compatible alias for the original (misspelled) class name.
ConvBLock = ConvBlock
class UpsampleBlock(nn.Module):
    """Sub-pixel upsampling block: Conv2d -> PixelShuffle -> per-channel PReLU.

    The 3x3/stride-1/padding-1 conv keeps H and W while expanding channels to
    ``in_c * scale_factor**2``. PixelShuffle then rearranges them back to
    ``in_c`` channels, so (N, in_c, H, W) -> (N, in_c, H*s, W*s) with
    s = ``scale_factor``.

    Args:
        in_c: Number of input (and output) channels.
        scale_factor: Integer spatial upscaling factor.
    """

    def __init__(self, in_c, scale_factor):
        super().__init__()
        self.conv = nn.Conv2d(in_c, in_c * scale_factor ** 2, 3, 1, 1)
        # (N, in_c * s^2, H, W) -> (N, in_c, H * s, W * s)
        self.ps = nn.PixelShuffle(scale_factor)
        self.act = nn.PReLU(num_parameters=in_c)

    def forward(self, x):
        return self.act(self.ps(self.conv(x)))
class ResidualBlock(nn.Module):
    """Two 3x3 ConvBlocks with an identity skip connection: returns ``f(x) + x``.

    Both convs use stride 1 and padding 1 with ``in_channels`` in and out, so
    the output has the same shape as the input. The second ConvBlock is built
    with ``use_act=False``, so the residual sum is returned without a final
    activation.

    Args:
        in_channels: Channel count of the input, preserved through the block.
    """

    def __init__(self, in_channels):
        super().__init__()
        # `kernel_size` must be spelled exactly as nn.Conv2d expects; the former
        # `kerneL_size` was forwarded via **kwargs and rejected with TypeError.
        self.block1 = ConvBlock(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        # Attribute name must match what forward() reads (`self.block2`).
        self.block2 = ConvBlock(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            use_act=False,
        )

    def forward(self, x):
        out = self.block1(x)
        out = self.block2(out)
        return out + x
def test():
    """Smoke-test Generator and Discriminator on a random low-resolution batch.

    Feeds a random tensor of shape (5, 3, 24, 24) through ``Generator``, then
    feeds the generator output through ``Discriminator``, and prints both
    output shapes.
    """
    low_resolution = 24  # 96x96 high-res crops -> 24x24 low-res inputs (4x)
    # torch.cuda.amp.autocast() is deprecated in favour of torch.autocast.
    # Enable it only when CUDA exists, to avoid the "CUDA is not available" warning.
    # NOTE: the models and input are never moved to a GPU here, so CUDA autocast
    # does not affect these CPU ops; it is kept to mirror the training setup.
    # The test is forward-only, so gradients are not needed.
    with torch.no_grad(), torch.autocast(
        device_type="cuda", enabled=torch.cuda.is_available()
    ):
        x = torch.randn((5, 3, low_resolution, low_resolution))
        gen = Generator()
        gen_out = gen(x)
        disc = Discriminator()
        disc_out = disc(gen_out)
        print(gen_out.shape)
        print(disc_out.shape)