Spaces:
Runtime error
Runtime error
File size: 1,782 Bytes
0d5c9d2 ceac432 93c8609 3aa67c0 17cfe57 d2b0313 6887d0a 4ce0baf 6887d0a ceac432 5e08d25 dc5de93 4d85bc3 d570bef 93c8609 dc5de93 227e514 dc5de93 5e08d25 6887d0a ceac432 93c8609 9f0dff5 425da69 9f0dff5 425da69 9f0dff5 4ce0baf ceac432 0d5c9d2 ceac432 17cfe57 5e08d25 17cfe57 3aa67c0 17cfe57 3aa67c0 17cfe57 5e08d25 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import sys
import cv2
import numpy as np
import torch
import ESRGAN.architecture as esrgan
import ESRGAN_plus.architecture as esrgan_plus
from run_cmd import run_cmd
from ESRGANer import ESRGANer
def is_cuda():
    """Report whether PyTorch can currently use a CUDA device.

    Returns:
        bool: ``True`` if ``torch.cuda.is_available()`` is truthy, otherwise ``False``.
    """
    return bool(torch.cuda.is_available())
# --- Command-line arguments ---------------------------------------------------
# Usage: <script> OUTPUT_PATH MODEL_TYPE
#   OUTPUT_PATH  image that is read, upscaled, and written back to the same path.
#   MODEL_TYPE   "Anime" or "Photo"; any other value selects 4x_NMKD-Siax_200k.
if len(sys.argv) < 3:
    sys.exit("usage: <script> OUTPUT_PATH MODEL_TYPE")
model_type = sys.argv[2]
# A single if/elif/else chain is required here: with two independent `if`
# statements the trailing `else` would override the "Anime" selection.
if model_type == "Anime":
    model_path = "models/4x-AnimeSharp.pth"
elif model_type == "Photo":
    model_path = "models/4x_Valar_v1.pth"
else:
    model_path = "models/4x_NMKD-Siax_200k.pth"
OUTPUT_PATH = sys.argv[1]
# --- Model construction and weight loading ------------------------------------
use_cuda = is_cuda()
device = torch.device('cuda' if use_cuda else 'cpu')
# "Photo" weights go into the ESRGAN_plus architecture; every other model type
# uses the original ESRGAN architecture. Both use identical hyperparameters.
arch = esrgan if model_type != "Photo" else esrgan_plus
model = arch.RRDB_Net(
    3, 3, 64, 23,
    gc=32,
    upscale=4,
    norm_type=None,
    act_type='leakyrelu',
    mode='CNA',
    res_scale=1,
    upsample_mode='upconv',
)
# NOTE(review): the two status strings below look mis-encoded (likely emoji in
# the original file); they are kept exactly as they appear in this copy.
if use_cuda:
    print("Using GPU ๐ฅถ")
    # None is torch.load's default: tensors are restored where they were saved.
    map_location = None
else:
    print("Using CPU ๐")
    # Remap any CUDA-saved tensors so loading works on a CPU-only machine.
    map_location = torch.device('cpu')
model.load_state_dict(torch.load(model_path, map_location=map_location), strict=True)
# Inference only: switch to eval mode and freeze every parameter.
model.eval()
model.requires_grad_(False)
model = model.to(device)
# --- Inference ---------------------------------------------------------------
# Read the input as a 3-channel BGR image. cv2.imread reports failure (missing
# file, unreadable format) by returning None rather than raising.
img = cv2.imread(OUTPUT_PATH, cv2.IMREAD_COLOR)
if img is None:
    sys.exit(f"error: could not read image {OUTPUT_PATH!r}")
# Scale to [0, 1], reorder channels BGR -> RGB, and lay out as CHW.
img = img * 1.0 / 255
img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
# Add a leading batch dimension: (1, 3, H, W).
img_LR = img.unsqueeze(0)
img_LR = img_LR.to(device)
upsampler = ESRGANer(model=model)
output = upsampler.enhance(img_LR)
# Convert back to an HWC BGR uint8 array for OpenCV. Assumes enhance() returns
# a (1, 3, H', W') tensor -- TODO confirm against ESRGANer.
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
# Values are clamped to [0, 1] above, so the rounded result fits in uint8.
output = (output * 255.0).round().astype(np.uint8)
# The upscaled image overwrites the input file. The compression flag is
# PNG-specific; cv2.imwrite returns False when writing fails.
if not cv2.imwrite(OUTPUT_PATH, output, [int(cv2.IMWRITE_PNG_COMPRESSION), 5]):
    sys.exit(f"error: could not write image {OUTPUT_PATH!r}")