File size: 1,782 Bytes
0d5c9d2
ceac432
 
 
93c8609
 
3aa67c0
17cfe57
d2b0313
6887d0a
4ce0baf
6887d0a
 
 
ceac432
5e08d25
dc5de93
4d85bc3
d570bef
93c8609
 
dc5de93
227e514
dc5de93
5e08d25
6887d0a
ceac432
93c8609
 
 
 
9f0dff5
 
425da69
9f0dff5
 
425da69
9f0dff5
4ce0baf
ceac432
0d5c9d2
 
 
ceac432
 
17cfe57
5e08d25
17cfe57
 
 
 
3aa67c0
17cfe57
 
3aa67c0
17cfe57
 
 
5e08d25
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import sys
import cv2
import numpy as np
import torch
import ESRGAN.architecture as esrgan
import ESRGAN_plus.architecture as esrgan_plus
from run_cmd import run_cmd
from ESRGANer import ESRGANer

def is_cuda():
    """Return True if PyTorch reports an available CUDA device, else False.

    Thin wrapper over ``torch.cuda.is_available()``, which already returns a
    ``bool``; its result is passed through unchanged.
    """
    return torch.cuda.is_available()

# Command-line contract, as read below:
#   argv[1]: path of the image to upscale; the result overwrites this same file.
#   argv[2]: model preset name -- "Anime", "Photo", or anything else for the
#            general-purpose default model.
if len(sys.argv) < 3:
    sys.exit(f"usage: {sys.argv[0]} IMAGE_PATH MODEL_TYPE")

model_type = sys.argv[2]

# Preset name -> checkpoint file; unknown names fall back to the default.
# A single lookup guarantees exactly one path is chosen (previously an
# ``if``/``if``/``else`` chain let the fallback overwrite the Anime path).
MODEL_PATHS = {
    "Anime": "models/4x-AnimeSharp.pth",
    "Photo": "models/4x_Valar_v1.pth",
}
DEFAULT_MODEL_PATH = "models/4x_NMKD-Siax_200k.pth"
model_path = MODEL_PATHS.get(model_type, DEFAULT_MODEL_PATH)

# The input image is read from and then overwritten at this same path.
OUTPUT_PATH = sys.argv[1]
use_cuda = is_cuda()
device = torch.device('cuda' if use_cuda else 'cpu')

# Both architectures are built with identical hyper-parameters; only the
# "Photo" preset uses the ESRGAN+ variant of RRDB_Net.
arch = esrgan_plus if model_type == "Photo" else esrgan
model = arch.RRDB_Net(
    3, 3, 64, 23,
    gc=32, upscale=4, norm_type=None, act_type='leakyrelu',
    mode='CNA', res_scale=1, upsample_mode='upconv',
)

if use_cuda:
    print("Using GPU \U0001F976")
else:
    print("Using CPU \U0001F612")

# Load the checkpoint onto the CPU regardless of where it was saved; the
# model is moved to ``device`` below. NOTE: torch.load unpickles the file,
# so only bundled/trusted checkpoints should ever be placed under models/.
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
model.load_state_dict(state_dict, strict=True)

model.eval()

# Inference only: stop autograd from tracking the weights.
for param in model.parameters():
    param.requires_grad = False
model = model.to(device)

# Read image. IMREAD_COLOR yields an HxWx3 BGR array; it is scaled to [0, 1],
# reordered to RGB, and transposed to a 1x3xHxW float tensor.
img = cv2.imread(OUTPUT_PATH, cv2.IMREAD_COLOR)
if img is None:
    # cv2.imread reports a missing/unreadable file by returning None.
    sys.exit(f"Could not read image: {OUTPUT_PATH}")
img = img * 1.0 / 255
img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
img_LR = img.unsqueeze(0).to(device)

upsampler = ESRGANer(model=model)
output = upsampler.enhance(img_LR)

# Back to an HxWx3 BGR uint8 array for OpenCV. Assumes enhance() returns a
# 1x3xHxW (or 3xHxW) tensor in roughly [0, 1] -- TODO confirm in ESRGANer.
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
output = (output * 255.0).round().astype(np.uint8)
# cv2.imwrite reports failure via its return value rather than raising.
if not cv2.imwrite(OUTPUT_PATH, output, [int(cv2.IMWRITE_PNG_COMPRESSION), 5]):
    sys.exit(f"Could not write image: {OUTPUT_PATH}")