File size: 1,552 Bytes
0d5c9d2
 
ceac432
 
 
0d5c9d2
d2b0313
6887d0a
4ce0baf
6887d0a
 
 
ceac432
dc5de93
 
4d85bc3
dc5de93
 
 
 
0d5c9d2
1c6d57c
6887d0a
ceac432
dc5de93
9f0dff5
 
425da69
9f0dff5
 
425da69
9f0dff5
4ce0baf
ceac432
0d5c9d2
 
 
ceac432
 
0d5c9d2
ceac432
0d5c9d2
d4772c3
4798451
2c0c47f
dc5de93
0d5c9d2
 
ceac432
d4772c3
612f527
dc5de93
 
 
d4772c3
4d85bc3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import sys
import os.path
import cv2
import numpy as np
import torch
import architecture as arch

def is_cuda() -> bool:
    """Return True when PyTorch reports a usable CUDA device, else False."""
    return torch.cuda.is_available()

# Command line: <script> <img_path> <output_path> <model_type>
# Check the argument count up front so a missing argument yields a usage
# message instead of a bare IndexError.
if len(sys.argv) < 4:
    sys.exit("Usage: {} <img_path> <output_path> <model_type>".format(sys.argv[0]))

img_path = sys.argv[1]
# NOTE: despite the name, this value is handed to cv2.imwrite as the output
# *file* path, not a directory.
output_dir = sys.argv[2]
model_type = sys.argv[3]

# "Anime" selects the anime weights; every other value selects UniScale.
if model_type == "Anime":
    model_path = "4x-AnimeSharp.pth"
else:
    model_path = "4x-UniScaleV2_Sharp.pth"

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if is_cuda() else 'cpu')

# Build the network; the arguments are passed to arch.RRDB_Net unchanged.
model = arch.RRDB_Net(3, 3, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', mode='CNA', res_scale=1, upsample_mode='upconv')

# Status message. The emoji are written as escapes so the source file stays
# ASCII and cannot be corrupted by a wrong file encoding.
if device.type == 'cuda':
    print("Using GPU \U0001F976")
else:
    print("Using CPU \U0001F612")

# Always deserialize onto the CPU: the freshly built model still lives there,
# and this avoids failures when the checkpoint was saved from a CUDA device
# that does not exist on this machine. The model is moved to `device` below.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from
# trusted sources.
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
model.load_state_dict(state_dict, strict=True)

# Inference only: switch to eval mode and disable gradients on every parameter.
model.eval()
for param in model.parameters():
    param.requires_grad = False
model = model.to(device)

# Input file name without directory or extension (not used further below).
base = os.path.splitext(os.path.basename(img_path))[0]

# read image
print(img_path)
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
# cv2.imread signals failure by returning None rather than raising; stop here
# with a clear message instead of failing later on the arithmetic below.
if img is None:
    sys.exit("Could not read image: {}".format(img_path))

# Scale pixel values to [0, 1], reverse the channel order of the last axis
# (OpenCV loads BGR) and move channels first, then add a leading batch axis.
img = img * 1.0 / 255
img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
img_LR = img.unsqueeze(0).to(device)

print('Start upscaling...')
with torch.no_grad():
    # Drop the batch axis again and clamp the result to [0, 1] on the CPU.
    output = model(img_LR).squeeze(0).float().cpu().clamp_(0, 1).numpy()

# Back to channels-last with the channel order reversed again for OpenCV,
# then to 8-bit values. The data is already clamped, so the cast cannot wrap.
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
output = (output * 255.0).round().astype(np.uint8)
print('Finished upscaling, saving image.')

# cv2.imwrite returns False when the file could not be written.
if not cv2.imwrite(output_dir, output, [int(cv2.IMWRITE_PNG_COMPRESSION), 5]):
    sys.exit("Could not write image: {}".format(output_dir))