# cf / app.py
# siyangyuan's picture
# Update app.py
# 8b48b0b
import os
from enum import IntEnum
from pathlib import Path
from tempfile import mktemp
from typing import IO, Dict, Type
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from gradio import Interface, inputs, outputs
# Device string handed to the upscaler when the model is built.
DEVICE = "cpu"

# Checkpoints are looked up in a "weights" directory next to this file.
WEIGHTS_PATH = Path(__file__).parent / "weights"


def _find_weight_files(directory):
    """Map base names to paths for regular files whose extension ends in "pth"."""
    found = {}
    for filename in os.listdir(directory):
        basename, ext = os.path.splitext(filename)
        candidate = directory / (basename + ext)
        if candidate.is_file() and ext.endswith("pth"):
            found[basename] = candidate
    return found


# The (misspelled) name is kept as-is because other code in this file uses it.
AVALIABLE_WEIGHTS = _find_weight_files(WEIGHTS_PATH)
# Supported upscaling factors; each member's value is the integer factor.
ScaleMode = IntEnum(
    "ScaleMode",
    {"up2x": 2, "up3x": 3, "up4x": 4},
    module=__name__,
)
# Tiling strategies, numbered from 0 in declaration order.
TileMode = IntEnum(
    "TileMode",
    ["full", "half", "quarter", "ninth", "sixteenth"],
    start=0,
    module=__name__,
)
class SEBlock(nn.Module):
    """Channel gate: scales every channel of ``x`` by a sigmoid weight.

    The weight is produced by two 1x1 convolutions (with a ReLU in between)
    applied to a per-channel descriptor of shape (N, C, 1, 1).
    """

    def __init__(self, in_channels, reduction=8, bias=False):
        super().__init__()
        squeezed = in_channels // reduction
        self.conv1 = nn.Conv2d(
            in_channels, squeezed, kernel_size=1, stride=1, padding=0, bias=bias
        )
        self.conv2 = nn.Conv2d(
            squeezed, in_channels, kernel_size=1, stride=1, padding=0, bias=bias
        )

    def forward(self, x):
        """Gate ``x`` using its own spatial mean as the channel descriptor."""
        if x.dtype == torch.float16:
            # Average in float32, then cast back to half precision.
            pooled = x.float().mean(dim=(2, 3), keepdim=True).half()
        else:
            pooled = x.mean(dim=(2, 3), keepdim=True)
        return self.forward_mean(x, pooled)

    def forward_mean(self, x, x0):
        """Gate ``x`` using an externally supplied channel descriptor ``x0``."""
        gate = torch.sigmoid(self.conv2(F.relu(self.conv1(x0), inplace=True)))
        return x * gate
class UNetConv(nn.Module):
    """Two unpadded 3x3 conv + LeakyReLU(0.1) layers, optionally followed by an SEBlock.

    Because neither convolution pads, the output is 4 pixels smaller than the
    input along both spatial axes.
    """

    def __init__(self, in_channels, mid_channels, out_channels, se):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=0),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=1, padding=0),
            nn.LeakyReLU(0.1, inplace=True),
        ]
        self.conv = nn.Sequential(*layers)
        # ``seblock`` stays None when no channel gating is requested.
        self.seblock = SEBlock(out_channels, reduction=8, bias=True) if se else None

    def forward(self, x):
        features = self.conv(x)
        if self.seblock is None:
            return features
        return self.seblock(features)
class UNet1(nn.Module):
    """One-level U-Net: encode, one stride-2 down/up round trip, skip-add, decode.

    ``forward`` runs the whole network.  ``forward_a`` / ``forward_b`` expose
    the two halves on either side of ``conv2``'s SE block so a caller can
    supply its own channel descriptor to that block.
    """

    def __init__(self, in_channels, out_channels, deconv):
        super().__init__()
        self.conv1 = UNetConv(in_channels, 32, 64, se=False)
        self.conv1_down = nn.Conv2d(64, 64, kernel_size=2, stride=2, padding=0)
        self.conv2 = UNetConv(64, 128, 64, se=True)
        self.conv2_up = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2, padding=0)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0)
        if deconv:
            # Stride-2 transposed convolution as the output layer.
            self.conv_bottom = nn.ConvTranspose2d(
                64, out_channels, kernel_size=4, stride=2, padding=3
            )
        else:
            self.conv_bottom = nn.Conv2d(
                64, out_channels, kernel_size=3, stride=1, padding=0
            )
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal init for (transposed) convs; small normal init for linears."""
        for layer in self.modules():
            if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(
                    layer.weight, mode="fan_out", nonlinearity="relu"
                )
            elif isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        skip, deep = self.forward_a(x)
        # forward_a stops before conv2's SE block; apply it here.
        deep = self.conv2.seblock(deep)
        return self.forward_b(skip, deep)

    def forward_a(self, x):
        """Return (conv1 output, conv2 output before its SE block)."""
        skip = self.conv1(x)
        deep = F.leaky_relu(self.conv1_down(skip), 0.1, inplace=True)
        deep = self.conv2.conv(deep)
        return skip, deep

    def forward_b(self, x1, x2):
        """Finish the network from ``x1`` (skip) and ``x2`` (SE-gated conv2 output)."""
        upsampled = F.leaky_relu(self.conv2_up(x2), 0.1, inplace=True)
        # Negative padding crops 4 px per side so the skip can be added to the upsampled path.
        cropped_skip = F.pad(x1, (-4, -4, -4, -4))
        merged = F.leaky_relu(self.conv3(cropped_skip + upsampled), 0.1, inplace=True)
        return self.conv_bottom(merged)
class UNet1x3(nn.Module):
    """Variant of the one-level U-Net whose transposed output layer uses stride 3.

    ``forward`` runs the whole network.  ``forward_a`` / ``forward_b`` expose
    the two halves on either side of ``conv2``'s SE block so a caller can
    supply its own channel descriptor to that block.
    """

    def __init__(self, in_channels, out_channels, deconv):
        super().__init__()
        self.conv1 = UNetConv(in_channels, 32, 64, se=False)
        self.conv1_down = nn.Conv2d(64, 64, kernel_size=2, stride=2, padding=0)
        self.conv2 = UNetConv(64, 128, 64, se=True)
        self.conv2_up = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2, padding=0)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0)
        if deconv:
            # Stride-3 transposed convolution as the output layer.
            self.conv_bottom = nn.ConvTranspose2d(
                64, out_channels, kernel_size=5, stride=3, padding=2
            )
        else:
            self.conv_bottom = nn.Conv2d(
                64, out_channels, kernel_size=3, stride=1, padding=0
            )
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal init for (transposed) convs; small normal init for linears."""
        for layer in self.modules():
            if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(
                    layer.weight, mode="fan_out", nonlinearity="relu"
                )
            elif isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        skip, deep = self.forward_a(x)
        # forward_a stops before conv2's SE block; apply it here.
        deep = self.conv2.seblock(deep)
        return self.forward_b(skip, deep)

    def forward_a(self, x):
        """Return (conv1 output, conv2 output before its SE block)."""
        skip = self.conv1(x)
        deep = F.leaky_relu(self.conv1_down(skip), 0.1, inplace=True)
        deep = self.conv2.conv(deep)
        return skip, deep

    def forward_b(self, x1, x2):
        """Finish the network from ``x1`` (skip) and ``x2`` (SE-gated conv2 output)."""
        upsampled = F.leaky_relu(self.conv2_up(x2), 0.1, inplace=True)
        # Negative padding crops 4 px per side so the skip can be added to the upsampled path.
        cropped_skip = F.pad(x1, (-4, -4, -4, -4))
        merged = F.leaky_relu(self.conv3(cropped_skip + upsampled), 0.1, inplace=True)
        return self.conv_bottom(merged)
class UNet2(nn.Module):
    """Two-level U-Net with SE blocks at the end of conv2, conv3 and conv4.

    ``forward`` runs the whole network.  ``forward_a`` .. ``forward_d`` are the
    segments between those three SE blocks; each returns the features *before*
    the next SE block so a caller can gate them with its own channel
    descriptor.
    """

    def __init__(self, in_channels, out_channels, deconv):
        super().__init__()
        self.conv1 = UNetConv(in_channels, 32, 64, se=False)
        self.conv1_down = nn.Conv2d(64, 64, kernel_size=2, stride=2, padding=0)
        self.conv2 = UNetConv(64, 64, 128, se=True)
        self.conv2_down = nn.Conv2d(128, 128, kernel_size=2, stride=2, padding=0)
        self.conv3 = UNetConv(128, 256, 128, se=True)
        self.conv3_up = nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2, padding=0)
        self.conv4 = UNetConv(128, 64, 64, se=True)
        self.conv4_up = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2, padding=0)
        self.conv5 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0)
        if deconv:
            self.conv_bottom = nn.ConvTranspose2d(
                64, out_channels, kernel_size=4, stride=2, padding=3
            )
        else:
            self.conv_bottom = nn.Conv2d(
                64, out_channels, kernel_size=3, stride=1, padding=0
            )
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal init for (transposed) convs; small normal init for linears."""
        for layer in self.modules():
            if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(
                    layer.weight, mode="fan_out", nonlinearity="relu"
                )
            elif isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        # Chain the four segments, applying each SE block in between.
        level1, level2 = self.forward_a(x)
        level2 = self.conv2.seblock(level2)
        level3 = self.conv3.seblock(self.forward_b(level2))
        decoded = self.conv4.seblock(self.forward_c(level2, level3))
        return self.forward_d(level1, decoded)

    def forward_a(self, x):
        """Return (conv1 output, conv2 output before its SE block)."""
        level1 = self.conv1(x)
        level2 = F.leaky_relu(self.conv1_down(level1), 0.1, inplace=True)
        level2 = self.conv2.conv(level2)
        return level1, level2

    def forward_b(self, x2):
        """From SE-gated conv2 output to conv3 output before its SE block."""
        level3 = F.leaky_relu(self.conv2_down(x2), 0.1, inplace=True)
        return self.conv3.conv(level3)

    def forward_c(self, x2, x3):
        """From SE-gated conv3 output (plus conv2 skip) to conv4 output before its SE block."""
        upsampled = F.leaky_relu(self.conv3_up(x3), 0.1, inplace=True)
        # Crop 4 px per side so the skip can be added to the upsampled path.
        cropped_skip = F.pad(x2, (-4, -4, -4, -4))
        return self.conv4.conv(cropped_skip + upsampled)

    def forward_d(self, x1, x4):
        """From SE-gated conv4 output (plus conv1 skip) to the network output."""
        upsampled = F.leaky_relu(self.conv4_up(x4), 0.1, inplace=True)
        # Crop 16 px per side so the skip can be added to the upsampled path.
        cropped_skip = F.pad(x1, (-16, -16, -16, -16))
        merged = F.leaky_relu(self.conv5(cropped_skip + upsampled), 0.1, inplace=True)
        return self.conv_bottom(merged)
class UpCunet2x(nn.Module):
    """2x upscaler: a UNet1 stage followed by a residual UNet2 stage.

    Author's note (translated): "perfect tiling, lossless throughout".  In the
    tiled modes every SE block receives the channel mean averaged over all
    tiles rather than a per-tile mean, which is why the network is run
    segment by segment across all tiles.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()
        self.unet1 = UNet1(in_channels, out_channels, deconv=True)
        self.unet2 = UNet2(in_channels, out_channels, deconv=False)

    @staticmethod
    def _tile_size(length, parts):
        """Even tile length such that ``parts`` tiles cover ``length`` pixels."""
        step = 2 * parts
        return ((length - 1) // step * step + step) // parts

    @staticmethod
    def _spatial_mean(t):
        """Per-channel spatial mean, computed in float32 for half-precision input."""
        if t.dtype == torch.float16:
            return torch.mean(t.float(), dim=(2, 3), keepdim=True).half()
        return torch.mean(t, dim=(2, 3), keepdim=True)

    def forward(self, x, tile_mode):
        """Upscale ``x`` (N, C, H, W) by a factor of 2.

        Args:
            x: input batch.
            tile_mode: 0 = no tiling, 1 = halve the longer edge,
                2 / 3 / 4 = split both edges into 2 / 3 / 4 parts.

        Raises:
            ValueError: if ``tile_mode`` is not one of 0-4.
        """
        if tile_mode not in (0, 1, 2, 3, 4):
            raise ValueError(f"unsupported tile_mode: {tile_mode!r}")
        h0, w0 = x.shape[2:]
        if tile_mode == 0:
            return self._forward_whole(x)
        if tile_mode == 1:
            parts_h, parts_w = (1, 2) if w0 >= h0 else (2, 1)
        else:
            parts_h = parts_w = int(tile_mode)
        return self._forward_tiled(
            x, self._tile_size(h0, parts_h), self._tile_size(w0, parts_w)
        )

    def _forward_whole(self, x):
        """Run the untiled path on the full image."""
        h0, w0 = x.shape[2:]
        # Height and width are padded up to the next even number.
        ph = ((h0 - 1) // 2 + 1) * 2
        pw = ((w0 - 1) // 2 + 1) * 2
        x = F.pad(x, (18, 18 + pw - w0, 18, 18 + ph - h0), "reflect")
        x = self.unet1.forward(x)
        x0 = self.unet2.forward(x)
        x1 = F.pad(x, (-20, -20, -20, -20))
        x = torch.add(x0, x1)
        if w0 != pw or h0 != ph:
            x = x[:, :, : h0 * 2, : w0 * 2]
        return x

    def _forward_tiled(self, x, tile_h, tile_w):
        """Run the network tile by tile with SE means shared across all tiles."""
        n, c, h0, w0 = x.shape
        ph = ((h0 - 1) // tile_h + 1) * tile_h
        pw = ((w0 - 1) // tile_w + 1) * tile_w
        x = F.pad(x, (18, 18 + pw - w0, 18, 18 + ph - h0), "reflect")
        border = 36  # each tile carries 18 px of context on every side
        origins = [
            (i, j) for i in range(0, ph, tile_h) for j in range(0, pw, tile_w)
        ]
        n_patch = len(origins)

        def new_mean(channels):
            return torch.zeros((n, channels, 1, 1), dtype=x.dtype, device=x.device)

        tiles = {}

        # Pass 1: unet1 up to (but excluding) its SE block.
        se_mean0 = new_mean(64)
        for i, j in origins:
            patch = x[:, :, i: i + tile_h + border, j: j + tile_w + border]
            skip, feat = self.unet1.forward_a(patch)
            se_mean0 += self._spatial_mean(feat)
            tiles[i, j] = (skip, feat)
        se_mean0 /= n_patch

        # Pass 2: finish unet1, start unet2 up to conv2's SE block.
        se_mean1 = new_mean(128)
        for key in origins:
            skip, feat = tiles[key]
            feat = self.unet1.conv2.seblock.forward_mean(feat, se_mean0)
            opt_unet1 = self.unet1.forward_b(skip, feat)
            tmp_x1, tmp_x2 = self.unet2.forward_a(opt_unet1)
            se_mean1 += self._spatial_mean(tmp_x2)
            tiles[key] = (opt_unet1, tmp_x1, tmp_x2)
        se_mean1 /= n_patch

        # Pass 3: unet2 up to conv3's SE block.
        se_mean0 = new_mean(128)
        for key in origins:
            opt_unet1, tmp_x1, tmp_x2 = tiles[key]
            tmp_x2 = self.unet2.conv2.seblock.forward_mean(tmp_x2, se_mean1)
            tmp_x3 = self.unet2.forward_b(tmp_x2)
            se_mean0 += self._spatial_mean(tmp_x3)
            tiles[key] = (opt_unet1, tmp_x1, tmp_x2, tmp_x3)
        se_mean0 /= n_patch

        # Pass 4: unet2 up to conv4's SE block.
        se_mean1 = new_mean(64)
        for key in origins:
            opt_unet1, tmp_x1, tmp_x2, tmp_x3 = tiles[key]
            tmp_x3 = self.unet2.conv3.seblock.forward_mean(tmp_x3, se_mean0)
            tmp_x4 = self.unet2.forward_c(tmp_x2, tmp_x3)
            se_mean1 += self._spatial_mean(tmp_x4)
            tiles[key] = (opt_unet1, tmp_x1, tmp_x4)
        se_mean1 /= n_patch

        # Pass 5: finish unet2 and add the cropped unet1 output as a residual.
        outputs = {}
        for key in origins:
            opt_unet1, tmp_x1, tmp_x4 = tiles[key]
            tmp_x4 = self.unet2.conv4.seblock.forward_mean(tmp_x4, se_mean1)
            x0 = self.unet2.forward_d(tmp_x1, tmp_x4)  # final unet2 output
            x1 = F.pad(opt_unet1, (-20, -20, -20, -20))
            outputs[key] = torch.add(x0, x1)
        del tiles
        torch.cuda.empty_cache()

        # Stitch the upscaled tiles back together.
        res = torch.zeros((n, c, ph * 2, pw * 2), dtype=x.dtype, device=x.device)
        for i, j in origins:
            res[
                :, :, i * 2: (i + tile_h) * 2, j * 2: (j + tile_w) * 2
            ] = outputs[i, j]
        del outputs
        torch.cuda.empty_cache()
        if w0 != pw or h0 != ph:
            # Drop the rows/columns that only exist because of the size padding.
            res = res[:, :, : h0 * 2, : w0 * 2]
        return res
class UpCunet3x(nn.Module):
    """3x upscaler: a UNet1x3 stage followed by a residual UNet2 stage.

    Author's note (translated): "perfect tiling, lossless throughout".  In the
    tiled modes every SE block receives the channel mean averaged over all
    tiles rather than a per-tile mean, which is why the network is run
    segment by segment across all tiles.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()
        self.unet1 = UNet1x3(in_channels, out_channels, deconv=True)
        self.unet2 = UNet2(in_channels, out_channels, deconv=False)

    @staticmethod
    def _tile_size(length, parts):
        """Tile length (a multiple of 4) such that ``parts`` tiles cover ``length``."""
        step = 4 * parts
        return ((length - 1) // step * step + step) // parts

    @staticmethod
    def _spatial_mean(t):
        """Per-channel spatial mean, computed in float32 for half-precision input."""
        if t.dtype == torch.float16:
            return torch.mean(t.float(), dim=(2, 3), keepdim=True).half()
        return torch.mean(t, dim=(2, 3), keepdim=True)

    def forward(self, x, tile_mode):
        """Upscale ``x`` (N, C, H, W) by a factor of 3.

        Args:
            x: input batch.
            tile_mode: 0 = no tiling, 1 = halve the longer edge,
                2 / 3 / 4 = split both edges into 2 / 3 / 4 parts.

        Raises:
            ValueError: if ``tile_mode`` is not one of 0-4.
        """
        if tile_mode not in (0, 1, 2, 3, 4):
            raise ValueError(f"unsupported tile_mode: {tile_mode!r}")
        h0, w0 = x.shape[2:]
        if tile_mode == 0:
            return self._forward_whole(x)
        if tile_mode == 1:
            parts_h, parts_w = (1, 2) if w0 >= h0 else (2, 1)
        else:
            parts_h = parts_w = int(tile_mode)
        return self._forward_tiled(
            x, self._tile_size(h0, parts_h), self._tile_size(w0, parts_w)
        )

    def _forward_whole(self, x):
        """Run the untiled path on the full image."""
        h0, w0 = x.shape[2:]
        # Height and width are padded up to the next multiple of 4.
        ph = ((h0 - 1) // 4 + 1) * 4
        pw = ((w0 - 1) // 4 + 1) * 4
        x = F.pad(x, (14, 14 + pw - w0, 14, 14 + ph - h0), "reflect")
        x = self.unet1.forward(x)
        x0 = self.unet2.forward(x)
        x1 = F.pad(x, (-20, -20, -20, -20))
        x = torch.add(x0, x1)
        if w0 != pw or h0 != ph:
            x = x[:, :, : h0 * 3, : w0 * 3]
        return x

    def _forward_tiled(self, x, tile_h, tile_w):
        """Run the network tile by tile with SE means shared across all tiles."""
        n, c, h0, w0 = x.shape
        ph = ((h0 - 1) // tile_h + 1) * tile_h
        pw = ((w0 - 1) // tile_w + 1) * tile_w
        x = F.pad(x, (14, 14 + pw - w0, 14, 14 + ph - h0), "reflect")
        border = 28  # each tile carries 14 px of context on every side
        origins = [
            (i, j) for i in range(0, ph, tile_h) for j in range(0, pw, tile_w)
        ]
        n_patch = len(origins)

        def new_mean(channels):
            return torch.zeros((n, channels, 1, 1), dtype=x.dtype, device=x.device)

        tiles = {}

        # Pass 1: unet1 up to (but excluding) its SE block.
        se_mean0 = new_mean(64)
        for i, j in origins:
            patch = x[:, :, i: i + tile_h + border, j: j + tile_w + border]
            skip, feat = self.unet1.forward_a(patch)
            se_mean0 += self._spatial_mean(feat)
            tiles[i, j] = (skip, feat)
        se_mean0 /= n_patch

        # Pass 2: finish unet1, start unet2 up to conv2's SE block.
        se_mean1 = new_mean(128)
        for key in origins:
            skip, feat = tiles[key]
            feat = self.unet1.conv2.seblock.forward_mean(feat, se_mean0)
            opt_unet1 = self.unet1.forward_b(skip, feat)
            tmp_x1, tmp_x2 = self.unet2.forward_a(opt_unet1)
            se_mean1 += self._spatial_mean(tmp_x2)
            tiles[key] = (opt_unet1, tmp_x1, tmp_x2)
        se_mean1 /= n_patch

        # Pass 3: unet2 up to conv3's SE block.
        se_mean0 = new_mean(128)
        for key in origins:
            opt_unet1, tmp_x1, tmp_x2 = tiles[key]
            tmp_x2 = self.unet2.conv2.seblock.forward_mean(tmp_x2, se_mean1)
            tmp_x3 = self.unet2.forward_b(tmp_x2)
            se_mean0 += self._spatial_mean(tmp_x3)
            tiles[key] = (opt_unet1, tmp_x1, tmp_x2, tmp_x3)
        se_mean0 /= n_patch

        # Pass 4: unet2 up to conv4's SE block.
        se_mean1 = new_mean(64)
        for key in origins:
            opt_unet1, tmp_x1, tmp_x2, tmp_x3 = tiles[key]
            tmp_x3 = self.unet2.conv3.seblock.forward_mean(tmp_x3, se_mean0)
            tmp_x4 = self.unet2.forward_c(tmp_x2, tmp_x3)
            se_mean1 += self._spatial_mean(tmp_x4)
            tiles[key] = (opt_unet1, tmp_x1, tmp_x4)
        se_mean1 /= n_patch

        # Pass 5: finish unet2 and add the cropped unet1 output as a residual.
        outputs = {}
        for key in origins:
            opt_unet1, tmp_x1, tmp_x4 = tiles[key]
            tmp_x4 = self.unet2.conv4.seblock.forward_mean(tmp_x4, se_mean1)
            x0 = self.unet2.forward_d(tmp_x1, tmp_x4)  # final unet2 output
            x1 = F.pad(opt_unet1, (-20, -20, -20, -20))
            outputs[key] = torch.add(x0, x1)
        del tiles
        torch.cuda.empty_cache()

        # Stitch the upscaled tiles back together.
        res = torch.zeros((n, c, ph * 3, pw * 3), dtype=x.dtype, device=x.device)
        for i, j in origins:
            res[
                :, :, i * 3: (i + tile_h) * 3, j * 3: (j + tile_w) * 3
            ] = outputs[i, j]
        del outputs
        torch.cuda.empty_cache()
        if w0 != pw or h0 != ph:
            # Drop the rows/columns that only exist because of the size padding.
            res = res[:, :, : h0 * 3, : w0 * 3]
        return res
class UpCunet4x(nn.Module):
    """4x upscaler: UNet1 + residual UNet2 on 64 channels, then conv + PixelShuffle(2).

    The network output is added to a nearest-neighbour 4x enlargement of the
    input.  Author's note (translated): "perfect tiling, lossless throughout".
    In the tiled modes every SE block receives the channel mean averaged over
    all tiles rather than a per-tile mean.

    ``out_channels`` is accepted for signature parity with the other
    upscalers but is not used: the final convolution always emits 12 channels.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()
        self.unet1 = UNet1(in_channels, 64, deconv=True)
        self.unet2 = UNet2(64, 64, deconv=False)
        self.ps = nn.PixelShuffle(2)
        self.conv_final = nn.Conv2d(64, 12, 3, 1, padding=0, bias=True)

    @staticmethod
    def _tile_size(length, parts):
        """Even tile length such that ``parts`` tiles cover ``length`` pixels."""
        step = 2 * parts
        return ((length - 1) // step * step + step) // parts

    @staticmethod
    def _spatial_mean(t):
        """Per-channel spatial mean, computed in float32 for half-precision input."""
        if t.dtype == torch.float16:
            return torch.mean(t.float(), dim=(2, 3), keepdim=True).half()
        return torch.mean(t, dim=(2, 3), keepdim=True)

    def forward(self, x, tile_mode):
        """Upscale ``x`` (N, C, H, W) by a factor of 4.

        Args:
            x: input batch.
            tile_mode: 0 = no tiling, 1 = halve the longer edge,
                2 / 3 / 4 = split both edges into 2 / 3 / 4 parts.

        Raises:
            ValueError: if ``tile_mode`` is not one of 0-4.
        """
        if tile_mode not in (0, 1, 2, 3, 4):
            raise ValueError(f"unsupported tile_mode: {tile_mode!r}")
        h0, w0 = x.shape[2:]
        if tile_mode == 0:
            return self._forward_whole(x)
        if tile_mode == 1:
            parts_h, parts_w = (1, 2) if w0 >= h0 else (2, 1)
        else:
            parts_h = parts_w = int(tile_mode)
        return self._forward_tiled(
            x, self._tile_size(h0, parts_h), self._tile_size(w0, parts_w)
        )

    def _head(self, features):
        """Final conv, 1 px crop per side, then pixel shuffle."""
        out = self.conv_final(features)
        out = F.pad(out, (-1, -1, -1, -1))
        return self.ps(out)

    def _forward_whole(self, x):
        """Run the untiled path on the full image."""
        h0, w0 = x.shape[2:]
        x00 = x  # untouched input for the nearest-neighbour residual
        # Height and width are padded up to the next even number.
        ph = ((h0 - 1) // 2 + 1) * 2
        pw = ((w0 - 1) // 2 + 1) * 2
        x = F.pad(x, (19, 19 + pw - w0, 19, 19 + ph - h0), "reflect")
        x = self.unet1.forward(x)
        x0 = self.unet2.forward(x)
        x1 = F.pad(x, (-20, -20, -20, -20))
        x = self._head(torch.add(x0, x1))
        if w0 != pw or h0 != ph:
            x = x[:, :, : h0 * 4, : w0 * 4]
        x += F.interpolate(x00, scale_factor=4, mode="nearest")
        return x

    def _forward_tiled(self, x, tile_h, tile_w):
        """Run the network tile by tile with SE means shared across all tiles."""
        n, c, h0, w0 = x.shape
        x00 = x  # untouched input for the nearest-neighbour residual
        ph = ((h0 - 1) // tile_h + 1) * tile_h
        pw = ((w0 - 1) // tile_w + 1) * tile_w
        x = F.pad(x, (19, 19 + pw - w0, 19, 19 + ph - h0), "reflect")
        border = 38  # each tile carries 19 px of context on every side
        origins = [
            (i, j) for i in range(0, ph, tile_h) for j in range(0, pw, tile_w)
        ]
        n_patch = len(origins)

        def new_mean(channels):
            return torch.zeros((n, channels, 1, 1), dtype=x.dtype, device=x.device)

        tiles = {}

        # Pass 1: unet1 up to (but excluding) its SE block.
        se_mean0 = new_mean(64)
        for i, j in origins:
            patch = x[:, :, i: i + tile_h + border, j: j + tile_w + border]
            skip, feat = self.unet1.forward_a(patch)
            se_mean0 += self._spatial_mean(feat)
            tiles[i, j] = (skip, feat)
        se_mean0 /= n_patch

        # Pass 2: finish unet1, start unet2 up to conv2's SE block.
        se_mean1 = new_mean(128)
        for key in origins:
            skip, feat = tiles[key]
            feat = self.unet1.conv2.seblock.forward_mean(feat, se_mean0)
            opt_unet1 = self.unet1.forward_b(skip, feat)
            tmp_x1, tmp_x2 = self.unet2.forward_a(opt_unet1)
            se_mean1 += self._spatial_mean(tmp_x2)
            tiles[key] = (opt_unet1, tmp_x1, tmp_x2)
        se_mean1 /= n_patch

        # Pass 3: unet2 up to conv3's SE block.
        se_mean0 = new_mean(128)
        for key in origins:
            opt_unet1, tmp_x1, tmp_x2 = tiles[key]
            tmp_x2 = self.unet2.conv2.seblock.forward_mean(tmp_x2, se_mean1)
            tmp_x3 = self.unet2.forward_b(tmp_x2)
            se_mean0 += self._spatial_mean(tmp_x3)
            tiles[key] = (opt_unet1, tmp_x1, tmp_x2, tmp_x3)
        se_mean0 /= n_patch

        # Pass 4: unet2 up to conv4's SE block.
        se_mean1 = new_mean(64)
        for key in origins:
            opt_unet1, tmp_x1, tmp_x2, tmp_x3 = tiles[key]
            tmp_x3 = self.unet2.conv3.seblock.forward_mean(tmp_x3, se_mean0)
            tmp_x4 = self.unet2.forward_c(tmp_x2, tmp_x3)
            se_mean1 += self._spatial_mean(tmp_x4)
            tiles[key] = (opt_unet1, tmp_x1, tmp_x4)
        se_mean1 /= n_patch

        # Pass 5: finish unet2, add the unet1 residual, then apply the output head.
        outputs = {}
        for key in origins:
            opt_unet1, tmp_x1, tmp_x4 = tiles[key]
            tmp_x4 = self.unet2.conv4.seblock.forward_mean(tmp_x4, se_mean1)
            x0 = self.unet2.forward_d(tmp_x1, tmp_x4)  # final unet2 output
            x1 = F.pad(opt_unet1, (-20, -20, -20, -20))
            outputs[key] = self._head(torch.add(x0, x1))
        del tiles
        torch.cuda.empty_cache()

        # Stitch the upscaled tiles back together.
        res = torch.zeros((n, c, ph * 4, pw * 4), dtype=x.dtype, device=x.device)
        for i, j in origins:
            res[
                :, :, i * 4: (i + tile_h) * 4, j * 4: (j + tile_w) * 4
            ] = outputs[i, j]
        del outputs
        torch.cuda.empty_cache()
        if w0 != pw or h0 != ph:
            # Drop the rows/columns that only exist because of the size padding.
            res = res[:, :, : h0 * 4, : w0 * 4]
        res += F.interpolate(x00, scale_factor=4, mode="nearest")
        return res
# Registry of every nn.Module subclass visible at module level so far, keyed by class name.
models: Dict[str, Type[nn.Module]] = {}
# Iterate over a snapshot: the loop variable itself is added to globals().
for _candidate in tuple(globals().values()):
    if isinstance(_candidate, type) and issubclass(_candidate, nn.Module):
        models[_candidate.__name__] = _candidate
del _candidate
class RealWaifuUpScaler:
    """Wraps an ``UpCunet{scale}x`` model for upscaling H x W x C uint8 images."""

    def __init__(self, scale: int, weight_path: str, half: bool, device: str):
        """Build the model for ``scale`` and load its checkpoint.

        Args:
            scale: upscaling factor; selects the ``UpCunet{scale}x`` class.
            weight_path: path to the checkpoint file.
            half: run the model in float16 when true.
            device: torch device string the model is moved to.
        """
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        weight = torch.load(weight_path, map_location=device)
        self.model = models[f"UpCunet{scale}x"]()
        if half:
            self.model = self.model.half()
        self.model = self.model.to(device)
        self.model.load_state_dict(weight, strict=True)
        self.model.eval()
        self.half = half
        self.device = device

    def np2tensor(self, np_frame):
        """Convert an H x W x C uint8 array to a 1 x C x H x W tensor scaled to [0, 1]."""
        tensor = (
            torch.from_numpy(np.transpose(np_frame, (2, 0, 1)))
            .unsqueeze(0)
            .to(self.device)
        )
        tensor = tensor.half() if self.half else tensor.float()
        return tensor / 255

    def tensor2np(self, tensor):
        """Convert a 1 x C x H x W tensor in [0, 1] back to an H x W x C uint8 array."""
        # squeeze(0) removes only the batch axis; a bare squeeze() would also
        # drop a size-1 spatial axis and break the transpose below.
        scaled = (tensor.data.squeeze(0).float() * 255.0).round().clamp_(0, 255)
        return np.transpose(scaled.byte().cpu().numpy(), (1, 2, 0))

    def __call__(self, frame, tile_mode):
        """Upscale ``frame`` with the given tile mode, without tracking gradients."""
        with torch.no_grad():
            tensor = self.np2tensor(frame)
            result = self.tensor2np(self.model(tensor, tile_mode))
        return result
# --- Input widgets -----------------------------------------------------------
input_image = inputs.File(label="Input image")
half_precision = inputs.Checkbox(
    label="Half precision (NOT work for CPU)",
    default=False,
)
# One dropdown entry per discovered checkpoint, in alphabetical order.
model_weight = inputs.Dropdown(
    sorted(AVALIABLE_WEIGHTS.keys()),
    label="Choice model weight",
)
# One radio button per TileMode member, in declaration order.
tile_mode = inputs.Radio(
    list(map(lambda mode: mode.name, TileMode)),
    label="Output tile mode",
)

# --- Output widgets ----------------------------------------------------------
output_image = outputs.Image(label="Output image preview")
output_file = outputs.File(label="Output image file")
def main(file: IO[bytes], half: bool, weight: str, tile: str):
    """Upscale the uploaded image and return the result path twice.

    Args:
        file: uploaded file object; only its ``name`` (a path) is used.
        half: run the model in half precision.
        weight: key into ``AVALIABLE_WEIGHTS``; its prefix selects the scale.
        tile: name of a ``TileMode`` member.

    Returns:
        ``(path, path)``: the same temporary file path for the image preview
        and for the downloadable file.

    Raises:
        ValueError: if no usable weight is selected, the image cannot be
            read, or the result cannot be written.
    """
    # Local import: mkstemp creates the file atomically, unlike the
    # race-prone tempfile.mktemp.
    from tempfile import mkstemp

    if not weight:
        raise ValueError("No model weight selected")
    scale = next(
        (mode.value for mode in ScaleMode if weight.startswith(mode.name)), None
    )
    if scale is None:
        raise ValueError(f"Cannot infer the scale from weight name {weight!r}")

    image = cv2.imread(file.name)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise ValueError(f"Could not read an image from {file.name!r}")

    upscaler = RealWaifuUpScaler(
        scale, weight_path=str(AVALIABLE_WEIGHTS[weight]), half=half, device=DEVICE
    )
    # OpenCV works in BGR; the model is fed RGB.
    frame = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    result = cv2.cvtColor(upscaler(frame, TileMode[tile]), cv2.COLOR_RGB2BGR)

    # Keep the input's extension so the output is written in the same format.
    _, ext = os.path.splitext(file.name)
    handle, out_path = mkstemp(suffix=ext)
    os.close(handle)
    if not cv2.imwrite(out_path, result):
        raise ValueError(f"Could not write the result to {out_path!r}")
    return out_path, out_path
# Wire the widgets to `main` and start serving the demo.
interface = Interface(
    fn=main,
    inputs=[
        input_image,
        half_precision,
        model_weight,
        tile_mode,
    ],
    outputs=[
        output_image,
        output_file,
    ],
)
interface.launch()