|
from typing import Union
|
|
import cv2
|
|
import torch
|
|
import numpy as np
|
|
from torch import nn
|
|
from torchvision import transforms as T
|
|
|
|
|
|
class SRCNN(nn.Module):
    """Three-layer SRCNN super-resolution network with tiled inference.

    The network is ``Conv 9x9 -> ReLU -> Conv 1x1 -> ReLU -> Conv 5x5 -> ReLU``
    with no padding, so one forward pass shrinks every spatial dimension by
    12 pixels (8 from the 9x9 kernel, 4 from the 5x5 kernel). With the default
    sizes a 33x33 input patch therefore yields a 21x21 output patch.

    Args:
        input_channels: Number of channels of the input patches.
        output_channels: Number of channels produced by the last convolution.
        input_size: Side length of the square patches ``enhance`` feeds to the
            network.
        label_size: Side length of the patch the network returns for an
            ``input_size`` patch (``input_size - 12`` for this architecture).
        scale: Factor applied by the bicubic pre-upscale in ``enhance``.
        device: Preferred inference device, stored on ``self.device``. When
            ``None``, CUDA is chosen if available, otherwise the CPU. The
            constructor does not move the weights.
    """

    def __init__(
        self,
        input_channels=3,
        output_channels=3,
        input_size=33,
        label_size=21,
        scale=2,
        device=None,
    ):
        super().__init__()
        self.input_size = input_size
        self.label_size = label_size
        # Border lost on each side of a patch by the un-padded convolutions.
        self.pad = (self.input_size - self.label_size) // 2
        self.scale = scale
        self.model = nn.Sequential(
            nn.Conv2d(input_channels, 64, 9),
            nn.ReLU(),
            nn.Conv2d(64, 32, 1),
            nn.ReLU(),
            nn.Conv2d(32, output_channels, 5),
            nn.ReLU(),
        )
        self.transform = T.Compose([T.ToTensor()])

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the convolution stack on ``x`` and return its output."""
        return self.model(x)

    @torch.no_grad()
    def pre_process(self, x: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """Convert an image to the tensor representation the network consumes.

        Tensors are divided by 255; anything else is passed through
        ``self.transform`` (torchvision's ``ToTensor``).
        """
        if torch.is_tensor(x):
            return x / 255.0
        return self.transform(x)

    @torch.no_grad()
    def post_process(self, x: torch.Tensor) -> torch.Tensor:
        """Clamp ``x`` to ``[0, 1]`` and rescale it to ``[0, 255]``."""
        return x.clip(0, 1) * 255.0

    @torch.no_grad()
    def enhance(
        self, image: np.ndarray, outscale: float = 2
    ) -> "tuple[np.ndarray, None]":
        """Upscale ``image`` by running the network over overlapping tiles.

        The image is first resized with bicubic interpolation, then split into
        ``input_size`` patches taken every ``label_size`` pixels; each network
        output is written into the matching interior region of the result.

        Note that the resize target is derived from
        ``w - w % label_size + input_size`` (and likewise for the height)
        rather than from ``w`` itself, so the returned image is generally not
        exactly ``scale`` times the input size.

        Args:
            image: Input image array; only its first two axes (height, width)
                are inspected here.
            outscale: Desired final scale. When it differs from ``self.scale``
                the network output is resized by ``outscale / self.scale``.

        Returns:
            A ``(output, None)`` tuple where ``output`` is a ``uint8`` array.
        """
        h, w = image.shape[:2]
        scale_w = int((w - w % self.label_size + self.input_size) * self.scale)
        scale_h = int((h - h % self.label_size + self.input_size) * self.scale)

        scaled = cv2.resize(image, (scale_w, scale_h), interpolation=cv2.INTER_CUBIC)

        in_tensor = self.pre_process(scaled)
        out_tensor = torch.zeros_like(in_tensor)

        # The constructor only records ``self.device`` and never moves the
        # weights, so tiles must follow the weights' actual device; sending
        # them to ``self.device`` blindly raises a device-mismatch error.
        first_param = next(self.parameters(), None)
        device = self.device if first_param is None else first_param.device

        for y in range(0, scale_h - self.input_size + 1, self.label_size):
            for x in range(0, scale_w - self.input_size + 1, self.label_size):
                crop = in_tensor[:, y : y + self.input_size, x : x + self.input_size]
                crop_inp = crop.unsqueeze(0).to(device)
                # Drop only the batch axis so single-channel outputs keep
                # their channel dimension.
                pred = self.forward(crop_inp).squeeze(0).cpu()

                top = y + self.pad
                left = x + self.pad
                out_tensor[
                    :,
                    top : top + self.label_size,
                    left : left + self.label_size,
                ] = pred

        out_tensor = self.post_process(out_tensor)
        output = out_tensor.permute(1, 2, 0).numpy()
        # ``-0`` as a slice end would produce an empty array, so fall back to
        # ``None`` (keep everything) when there is no border to trim.
        trim_end = (-2 * self.pad) or None
        output = output[self.pad : trim_end, self.pad : trim_end]
        output = np.clip(output, 0, 255).astype("uint8")

        # Compare against the model's own scale instead of a hard-coded 2 so
        # models built with another ``scale`` are resized correctly.
        if outscale != self.scale:
            interpolation = (
                cv2.INTER_AREA if outscale < self.scale else cv2.INTER_LANCZOS4
            )
            h, w = output.shape[0:2]
            output = cv2.resize(
                output,
                (int(w * outscale / self.scale), int(h * outscale / self.scale)),
                interpolation=interpolation,
            )

        return output, None
|
|
|