# ESRGAN-MANGA / inference_manga.py
# Last commit a7d5db1 by 0x90e ("Fixes."); file size 1.38 kB.
import sys
import os.path
import cv2
import numpy as np
import torch
import architecture as arch
def is_cuda():
    """Return True when PyTorch reports a usable CUDA device, else False."""
    # torch.cuda.is_available() already returns a bool; no branching needed.
    return torch.cuda.is_available()
# Pretrained weights. The relative path is resolved against the current working
# directory, so run the script from the folder containing the .pth file.
model_path = '4x_eula_digimanga_bw_v2_nc1_307k.pth'
if len(sys.argv) < 3:
    # Fail with a usage message instead of an IndexError traceback.
    sys.exit('usage: {} <input_image> <output_image>'.format(sys.argv[0]))
img_path = sys.argv[1]
# Despite its name, this is the full path of the output image file (it is
# passed directly to cv2.imwrite below), not a directory.
output_dir = sys.argv[2]
device = torch.device('cuda' if is_cuda() else 'cpu')
# RRDB generator configured for 4x upscaling. The leading positional args are
# presumably (in_nc, out_nc, nf, nb) = (1, 1, 64, 23), i.e. single-channel
# grayscale in/out -- confirm against architecture.py. These hyperparameters
# must match the checkpoint, otherwise strict loading below raises.
model = arch.RRDB_Net(
    1, 1, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu',
    mode='CNA', res_scale=1, upsample_mode='upconv')
# map_location=device covers both the CUDA and CPU-only cases with one call:
# weights saved from a GPU still load on a machine without CUDA.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
model.eval()
# Inference only: freeze all weights.
for param in model.parameters():
    param.requires_grad = False
model = model.to(device)
# Read the input page as a single-channel 8-bit image.
print(img_path)
img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
if img is None:
    # cv2.imread signals a missing or unreadable file by returning None rather
    # than raising; without this check the failure shows up later as a
    # confusing TypeError on the arithmetic below.
    sys.exit('Could not read image: {}'.format(img_path))
# Scale 0..255 to 0..1, then add channel and batch axes: (H, W) -> (1, 1, H, W).
img = img * 1.0 / 255
img = torch.from_numpy(img[np.newaxis, :, :]).float()
img_LR = img.unsqueeze(0)
img_LR = img_LR.to(device)
print('Start upscaling...')
with torch.no_grad():
    # Expected (1, 1, 4H, 4W) on device -> (1, 4H, 4W) float array in [0, 1]
    # on the CPU.
    output = model(img_LR).squeeze(dim=0).float().cpu().clamp_(0, 1).numpy()
# Channel-last layout for OpenCV: (C, H, W) -> (H, W, C).
output = np.transpose(output, (1, 2, 0))
# Quantize explicitly to 8-bit; values are already clamped to [0, 1], so the
# rounded result fits in uint8 without wrap-around.
output = (output * 255.0).round().astype(np.uint8)
print('Finished upscaling, saving image.')
print(output_dir)
# The JPEG quality flag only affects JPEG outputs; other formats ignore it.
# cv2.imwrite returns False (rather than raising) when the file cannot be
# written, so surface that as an error instead of exiting "successfully".
if not cv2.imwrite(output_dir, output, [int(cv2.IMWRITE_JPEG_QUALITY), 90]):
    sys.exit('Failed to write image: {}'.format(output_dir))