# isr/srcnn.py
# Provenance: Hugging Face Space by dragonSwing, commit 0fe2a53 ("Add application files").
from typing import Union
import cv2
import torch
import numpy as np
from torch import nn
from torchvision import transforms as T
class SRCNN(nn.Module):
    """Super-Resolution CNN (SRCNN) with patch-based sliding-window inference.

    The network maps a bicubically upscaled low-resolution patch of
    ``input_size`` pixels to a sharpened ``label_size`` patch; the valid
    (unpadded) 9/1/5 convolutions consume an ``(input_size - label_size)``
    border, split evenly on each side (``self.pad``).
    """

    def __init__(
        self,
        input_channels=3,
        output_channels=3,
        input_size=33,
        label_size=21,
        scale=2,
        device=None,
    ):
        """Build the three-layer SRCNN.

        Args:
            input_channels: channels of the input patches.
            output_channels: channels produced by the last convolution.
            input_size: side length of the square input window.
            label_size: side length of the predicted (valid) output patch.
            scale: the model's native upscaling factor.
            device: torch device; auto-selects CUDA when available if None.
        """
        super().__init__()
        self.input_size = input_size
        self.label_size = label_size
        # Border eaten by the valid convolutions on each side of a patch.
        self.pad = (self.input_size - self.label_size) // 2
        self.scale = scale
        self.model = nn.Sequential(
            nn.Conv2d(input_channels, 64, 9),
            nn.ReLU(),
            nn.Conv2d(64, 32, 1),
            nn.ReLU(),
            nn.Conv2d(32, output_channels, 5),
            nn.ReLU(),
        )
        self.transform = T.Compose(
            [T.ToTensor()]  # HWC uint8 ndarray -> CHW float scaled to [0, 1]
        )
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        # BUG FIX: the weights were never moved to `device`, while `enhance`
        # sends every crop `.to(self.device)` -- on a CUDA machine this
        # crashed with a CPU/GPU device mismatch at the first forward pass.
        self.to(self.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a (N, C, H, W) batch of patches."""
        return self.model(x)

    @torch.no_grad()
    def pre_process(self, x: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """Scale pixel values to [0, 1]; ndarrays are also converted to CHW tensors."""
        if torch.is_tensor(x):
            return x / 255.0
        else:
            return self.transform(x)

    @torch.no_grad()
    def post_process(self, x: torch.Tensor) -> torch.Tensor:
        """Map a [0, 1] prediction back to the [0, 255] pixel range."""
        return x.clip(0, 1) * 255.0

    @torch.no_grad()
    def enhance(self, image: np.ndarray, outscale: float = 2) -> "tuple[np.ndarray, None]":
        """Upscale ``image`` by ``outscale`` using SRCNN patch inference.

        Args:
            image: HWC uint8 image (as read by OpenCV).
            outscale: requested overall scaling factor; when it differs from
                the model's native ``self.scale`` the result is resized with
                OpenCV interpolation.

        Returns:
            ``(output, None)`` where ``output`` is the upscaled HWC uint8
            image. The second element is kept for API compatibility with
            other enhancer classes that return an image mode.
        """
        (h, w) = image.shape[:2]
        # Round the upscaled size so complete sliding windows tile it.
        scale_w = int((w - w % self.label_size + self.input_size) * self.scale)
        scale_h = int((h - h % self.label_size + self.input_size) * self.scale)
        # Bicubic upsampling first; the network only sharpens the result.
        scaled = cv2.resize(image, (scale_w, scale_h), interpolation=cv2.INTER_CUBIC)
        in_tensor = self.pre_process(scaled)  # (C, H, W) in [0, 1]
        out_tensor = torch.zeros_like(in_tensor)  # (C, H, W)
        # Slide a window left-to-right and top-to-bottom, advancing one
        # label patch per step so the predicted patches tile without overlap.
        for y in range(0, scale_h - self.input_size + 1, self.label_size):
            for x in range(0, scale_w - self.input_size + 1, self.label_size):
                crop = in_tensor[:, y : y + self.input_size, x : x + self.input_size]
                crop_inp = crop.unsqueeze(0).to(self.device)
                # BUG FIX: squeeze(0) instead of squeeze() -- a bare
                # squeeze() would also drop the channel axis when
                # output_channels == 1, breaking the assignment below.
                pred = self.forward(crop_inp).cpu().squeeze(0)
                out_tensor[
                    :,
                    y + self.pad : y + self.pad + self.label_size,
                    x + self.pad : x + self.pad + self.label_size,
                ] = pred
        out_tensor = self.post_process(out_tensor)
        output = out_tensor.permute(1, 2, 0).numpy()  # (C, H, W) -> (H, W, C)
        # NOTE(review): the crop is asymmetric (pad on top/left, 2*pad on
        # bottom/right) -- presumably to also drop the last, never-predicted
        # partial window; kept as-is to preserve the output geometry.
        output = output[self.pad : -self.pad * 2, self.pad : -self.pad * 2]
        output = np.clip(output, 0, 255).astype("uint8")
        # Resize when the requested factor differs from the model's native
        # factor (generalized from a hard-coded `2`; identical behavior for
        # the default scale=2).
        if outscale != self.scale:
            interpolation = (
                cv2.INTER_AREA if outscale < self.scale else cv2.INTER_LANCZOS4
            )
            h, w = output.shape[0:2]
            output = cv2.resize(
                output,
                (int(w * outscale / self.scale), int(h * outscale / self.scale)),
                interpolation=interpolation,
            )
        return output, None