# ESRGAN-MANGA / inference.py
# Last commit d570bef by 0x90e: "UI improvements and easier photo download"
import sys
import os.path
import cv2
import numpy as np
import torch
import architecture as arch
def is_cuda():
    """Return True if PyTorch can use a CUDA device on this machine, else False."""
    return torch.cuda.is_available()
# Command line: inference.py <input_image> <output_image> <model_type>
if len(sys.argv) < 4:
    sys.exit("usage: inference.py <input_image> <output_image> <model_type>")

model_type = sys.argv[3]
# "Anime" selects the AnimeSharp weights; every other value falls back to
# UniScaleV2_Sharp. Paths are relative to the current working directory.
if model_type == "Anime":
    model_path = "models/4x-AnimeSharp.pth"
else:
    model_path = "models/4x-UniScaleV2_Sharp.pth"

img_path = sys.argv[1]
# Despite the name, this is the full path of the output image file that is
# handed to cv2.imwrite below, not a directory.
output_dir = sys.argv[2]

# Query CUDA once so device selection and the log message cannot disagree.
use_cuda = is_cuda()
device = torch.device('cuda' if use_cuda else 'cpu')

# 4x RRDB generator. Because load_state_dict uses strict=True, loading fails
# if these hyperparameters do not match the checkpoint's layers.
model = arch.RRDB_Net(
    3, 3, 64, 23,
    gc=32, upscale=4, norm_type=None, act_type='leakyrelu',
    mode='CNA', res_scale=1, upsample_mode='upconv',
)
if use_cuda:
    print("Using GPU 🥶")
else:
    print("Using CPU 😒")
# map_location lets a checkpoint saved from GPU tensors load on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
model.eval()
# Inference only: no parameter needs gradients.
for param in model.parameters():
    param.requires_grad = False
model = model.to(device)

# Read image. cv2.imread returns None instead of raising when the file is
# missing or cannot be decoded, so check explicitly.
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
if img is None:
    sys.exit("Could not read image: " + img_path)
# uint8 BGR (H, W, C) -> float RGB (C, H, W) in [0, 1], plus a batch dimension.
img = img * 1.0 / 255
img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
img_LR = img.unsqueeze(0).to(device)

print('Start upscaling...')
with torch.no_grad():
    # Drop only the batch dimension; a bare squeeze() would also drop any
    # other size-1 axis.
    output = model(img_LR).squeeze(0).float().cpu().clamp_(0, 1).numpy()
# float RGB (C, H, W) in [0, 1] -> uint8 BGR (H, W, C) for OpenCV.
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
output = (output * 255.0).round().astype(np.uint8)
print('Finished upscaling, saving image.')
# PNG compression level 9 = smallest file, slowest encode. cv2.imwrite
# signals some failures by returning False rather than raising.
if not cv2.imwrite(output_dir, output, [int(cv2.IMWRITE_PNG_COMPRESSION), 9]):
    sys.exit("Could not write image: " + output_dir)