# SuperResolution / model.py
# Hu — "initial commit" (2f110b2), 2.73 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.transforms import transforms
import numpy as np
from PIL import Image
class SRCNNModel(nn.Module):
    """Three-layer SRCNN network operating on a single image channel.

    Layers (the attribute names double as ``state_dict`` keys, so they must
    stay ``conv1``/``conv2``/``conv3`` for saved checkpoints to load):

    * ``conv1``: 1 -> 64 channels, 9x9 kernel, padding 4, followed by ReLU
    * ``conv2``: 64 -> 32 channels, 1x1 kernel, no padding, followed by ReLU
    * ``conv3``: 32 -> 1 channel, 5x5 kernel, padding 2 (linear output)

    Every padding equals ``kernel_size // 2`` with stride 1, so the output
    keeps the input's spatial height and width.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this fixed order so that weight initialisation
        # draws from the RNG in the same sequence and state_dict keys match.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=1, padding=0)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=5, padding=2)

    def forward(self, x):
        """Run the three convolutions; the result has one channel and x's spatial size."""
        features = F.relu(self.conv1(x))
        mapped = F.relu(self.conv2(features))
        return self.conv3(mapped)
def _bicubic_upscale(img, scale_factor):
    """Return PIL image *img* resized by *scale_factor* on both axes with bicubic resampling.

    The target size is rounded to whole pixels so non-integer factors work;
    for integer factors it is exactly ``(width * s, height * s)``.

    Raises:
        ValueError: if the factor yields a zero or negative output dimension.
    """
    width, height = img.size
    target = (int(round(width * scale_factor)), int(round(height * scale_factor)))
    if target[0] < 1 or target[1] < 1:
        raise ValueError(
            f"scale_factor={scale_factor!r} gives an empty output size {target}"
        )
    # For PIL inputs torchvision's Resize delegates to PIL.Image.resize, so this
    # is equivalent while avoiding torchvision's deprecated integer
    # interpolation argument.
    return img.resize(target, Image.BICUBIC)


def pred_SRCNN(model, image, device, scale_factor=2):
    """Super-resolve a PIL image with SRCNN applied to its luminance channel.

    The image is split into Y/Cb/Cr, and all three bands are bicubic-upscaled
    by ``scale_factor``. The upscaled Y band is refined by ``model`` and then
    merged back with the bicubic Cb/Cr bands.

    Args:
        model: SRCNN model mapping a 1-channel tensor to a same-sized 1-channel
            tensor. It is moved to ``device`` and put into eval mode (this
            mutates the passed model).
        image: low-resolution PIL image; it is converted to YCbCr internally.
        device: ``torch.device`` (or device string) to run inference on.
        scale_factor: upscaling factor; non-integer values are rounded to
            whole pixels.

    Returns:
        ``(out_final, image_bicubic, image)``: the SRCNN result in RGB, a plain
        bicubic upscale of ``image`` for comparison, and the input image.

    Raises:
        ValueError: if ``scale_factor`` produces an empty output size.
    """
    model.to(device)
    model.eval()

    y, cb, cr = image.convert('YCbCr').split()

    # Upscale every band to the target (high-resolution) size; SRCNN only
    # refines an already-interpolated Y band.
    y_up = _bicubic_upscale(y, scale_factor)
    cb_up = _bicubic_upscale(cb, scale_factor)
    cr_up = _bicubic_upscale(cr, scale_factor)

    # ToTensor maps the 8-bit 'L' band to a float (1, H, W) tensor in [0, 1];
    # unsqueeze adds the batch dimension.
    y_tensor = transforms.ToTensor()(y_up).unsqueeze(0).to(device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        y_pred = model(y_tensor)

    # Back to 8-bit pixel values. clip() returns a new array, so the result
    # must be reassigned; otherwise out-of-range values would wrap around in
    # the uint8 cast.
    y_pred = y_pred[0, 0].cpu().numpy() * 255.0
    y_pred = np.clip(y_pred, 0, 255)
    # A 2-D uint8 array is turned into an 'L' (single-channel) PIL image.
    y_pred_pil = Image.fromarray(y_pred.astype(np.uint8))

    out_final = Image.merge('YCbCr', [y_pred_pil, cb_up, cr_up]).convert('RGB')
    image_bicubic = _bicubic_upscale(image, scale_factor)
    return out_final, image_bicubic, image
def main():
    """Load trained SRCNN weights, super-resolve ``LR_image.png`` and show results.

    Three windows are opened: the low-resolution input, the SRCNN output and a
    plain bicubic upscale for comparison.
    """
    print("Loading SRCNN model...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SRCNNModel().to(device)
    # map_location lets a checkpoint saved from a GPU run load on a CPU-only
    # machine. NOTE: torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load('SRCNNmodel_trained.pt', map_location=device))
    model.eval()
    print("SRCNN model loaded!")

    image_path = "LR_image.png"
    # pred_SRCNN expects an opened PIL image (it has no image_path parameter).
    # convert('RGB') loads the pixels, so the file can be closed right away, and
    # normalizes the mode (e.g. palette or RGBA PNGs) before the YCbCr split.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    out_final, image_bicubic, image = pred_SRCNN(model=model, image=image, device=device)
    image.show()
    out_final.show()
    image_bicubic.show()


if __name__ == "__main__":
    main()