import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from PIL import Image
from torchvision.transforms import transforms


class SRCNNModel(nn.Module):
    """Three-layer SRCNN network operating on a single image channel.

    Kernel sizes are 9 -> 1 -> 5 (1 -> 64 -> 32 -> 1 channels). The padding
    on each convolution preserves spatial size, so the output has the same
    height and width as the input.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, 9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, 1, padding=0)
        self.conv3 = nn.Conv2d(32, 1, 5, padding=2)

    def forward(self, x):
        """Map an (N, 1, H, W) tensor to an (N, 1, H, W) reconstruction."""
        out = F.relu(self.conv1(x))
        out = F.relu(self.conv2(out))
        return self.conv3(out)


def _upscaled_size(size, scale_factor):
    """Return PIL-style (width, height) of *size* scaled by *scale_factor*, rounded to ints."""
    width, height = size
    return (
        max(1, int(round(width * scale_factor))),
        max(1, int(round(height * scale_factor))),
    )


def pred_SRCNN(model, image, device, scale_factor=2):
    """Super-resolve a PIL image with SRCNN applied to its luma (Y) channel.

    The image is converted to YCbCr. All three channels are upscaled with
    bicubic interpolation, and only the Y channel is refined by ``model``.
    The refined Y is then merged back with the bicubic Cb/Cr channels.

    Args:
        model: SRCNN model (e.g. ``SRCNNModel``); it is moved to ``device``
            and put in eval mode.
        image: low-resolution PIL image.
        device: torch device to run inference on (``cuda`` or ``cpu``).
        scale_factor: positive upscaling factor (int or float); output sizes
            are rounded to the nearest pixel.

    Returns:
        A tuple ``(out_final, image_bicubic, image)``:

        - ``out_final``: RGB SRCNN result.
        - ``image_bicubic``: plain bicubic upscale of ``image``, for comparison.
        - ``image``: the unchanged input.

    Raises:
        ValueError: if ``scale_factor`` is not positive.
    """
    if scale_factor <= 0:
        raise ValueError(f"scale_factor must be positive, got {scale_factor!r}")

    model.to(device)
    model.eval()

    y, cb, cr = image.convert('YCbCr').split()
    target_size = _upscaled_size(y.size, scale_factor)

    # SRCNN refines an already-upsampled image, so every channel is enlarged
    # with bicubic interpolation first; only Y goes through the network.
    y_bicubic = y.resize(target_size, Image.BICUBIC)
    cb_bicubic = cb.resize(target_size, Image.BICUBIC)
    cr_bicubic = cr.resize(target_size, Image.BICUBIC)

    # ToTensor scales the 8-bit Y channel to [0, 1]; unsqueeze adds the batch dim.
    y_tensor = transforms.ToTensor()(y_bicubic).unsqueeze(0).to(device)
    with torch.no_grad():  # inference only: no autograd graph needed
        y_pred = model(y_tensor)

    # Back to 8-bit pixel values. Round, then clip BEFORE the uint8 cast so
    # out-of-range predictions saturate instead of wrapping around.
    y_pred = y_pred[0, 0].cpu().numpy() * 255.0
    y_pred = np.clip(np.rint(y_pred), 0, 255).astype(np.uint8)
    y_pred_pil = Image.fromarray(y_pred)  # 2-D uint8 array -> mode 'L'

    out_final = Image.merge('YCbCr', [y_pred_pil, cb_bicubic, cr_bicubic]).convert('RGB')
    image_bicubic = image.resize(target_size, Image.BICUBIC)
    return out_final, image_bicubic, image


def main():
    """Load trained SRCNN weights, super-resolve ``LR_image.png`` and show the results."""
    print("Loading SRCNN model...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SRCNNModel().to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load('SRCNNmodel_trained.pt', map_location=device))
    model.eval()
    print("SRCNN model loaded!")

    image_path = "LR_image.png"
    # pred_SRCNN takes an opened PIL image, not a path. copy() loads the
    # pixels so the file handle can be closed right away.
    with Image.open(image_path) as src:
        image = src.copy()

    out_final, image_bicubic, image = pred_SRCNN(model=model, image=image, device=device)
    image.show()
    out_final.show()
    image_bicubic.show()


if __name__ == "__main__":
    main()