# ESRGAN-MANGA / test.py
# 0x90e's picture
# Better colab compat
# 6887d0a
# raw
# history blame
# 1.39 kB
import sys
import os.path
import glob
import cv2
import numpy as np
import torch
import architecture as arch
import multiprocessing
import util
def is_cuda():
    """Return True when PyTorch reports a usable CUDA device, else False.

    Only the actual CUDA availability is consulted. The hosting environment
    (Google Colab or not) is deliberately ignored: assuming CUDA on every
    non-Colab machine would select the 'cuda' device on hosts without a GPU,
    where loading the model onto that device cannot succeed. With this check
    such hosts fall back to the CPU instead.
    """
    return torch.cuda.is_available()
# Command-line script: upscale one grayscale image with the RRDB model.
# Usage: python test.py <input_image> <output_image>
model_path = '4x_eula_digimanga_bw_v2_nc1_307k.pth'

# Fail with a readable message instead of a bare IndexError on missing args.
if len(sys.argv) < 3:
    sys.exit('Usage: python test.py <input_image> <output_image>')
img_path = sys.argv[1]
# NOTE: despite its name, this value is passed to cv2.imwrite as the output
# *file* path, not a directory.
output_dir = sys.argv[2]

device = torch.device('cuda' if is_cuda() else 'cpu')
# presumably 1 input channel / 1 output channel (grayscale) with 4x upscale;
# verify against architecture.RRDB_Net's signature.
model = arch.RRDB_Net(1, 1, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', mode='CNA', res_scale=1, upsample_mode='upconv')
# NOTE(security): torch.load unpickles the checkpoint file; only load weights
# obtained from a trusted source.
model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
model.eval()
# Inference only: no parameter needs gradients.
for param in model.parameters():
    param.requires_grad = False
model = model.to(device)

base = os.path.splitext(os.path.basename(img_path))[0]

# read image
print(img_path)
img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
# cv2.imread signals failure by returning None rather than raising.
if img is None:
    sys.exit('Could not read image: {}'.format(img_path))
# Scale 8-bit pixel values into [0, 1].
img = img * 1.0 / 255
# Add a leading channel axis, then a batch axis: (H, W) -> (1, 1, H, W).
img = torch.from_numpy(img[np.newaxis, :, :]).float()
img_LR = img.unsqueeze(0)
img_LR = img_LR.to(device)

print('Start upscaling...')
with torch.no_grad():
    # Drop the batch axis, move to CPU and clamp into [0, 1].
    output = model(img_LR).squeeze(dim=0).float().cpu().clamp_(0, 1).numpy()
# Move the leading (channel) axis last, the layout cv2.imwrite expects.
output = np.transpose(output, (1, 2, 0))
# Values are already clamped to [0, 1], so the rounded result fits in uint8;
# cast explicitly instead of relying on the writer's implicit conversion.
output = (output * 255.0).round().astype(np.uint8)

print('Finished upscaling, saving image.')
print(output_dir)
# The JPEG quality flag only takes effect when the output extension is JPEG.
# cv2.imwrite returns False when the file could not be written.
if not cv2.imwrite(output_dir, output, [int(cv2.IMWRITE_JPEG_QUALITY), 90]):
    sys.exit('Failed to write image: {}'.format(output_dir))