# Spaces: Running
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from PIL import Image
from torchvision.transforms import transforms
class SRCNNModel(nn.Module):
    """Three-convolution network applied to a single-channel image.

    Channel progression is 1 -> 64 -> 32 -> 1 with kernel sizes 9, 1 and 5.
    Padding is chosen so the output keeps the input's height and width.
    ReLU follows the first two convolutions; the last convolution is linear.

    The attribute names ``conv1``..``conv3`` become the ``state_dict`` keys
    that saved checkpoints are loaded against, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        # padding = (kernel_size - 1) // 2 keeps H and W unchanged (stride 1)
        self.conv1 = nn.Conv2d(1, 64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=0)
        self.conv3 = nn.Conv2d(32, 1, kernel_size=5, padding=2)

    def forward(self, x):
        """Map ``x`` (presumably shaped (N, 1, H, W)) to a tensor of the same shape."""
        hidden = x
        for layer in (self.conv1, self.conv2):
            hidden = F.relu(layer(hidden))
        return self.conv3(hidden)
def _bicubic_upscale(img, scale_factor):
    """Return PIL image *img* bicubically resized by *scale_factor* on both axes."""
    width, height = img.size
    new_size = (int(round(width * scale_factor)), int(round(height * scale_factor)))
    return img.resize(new_size, resample=Image.BICUBIC)


def pred_SRCNN(model, image, device, scale_factor=2):
    """Super-resolve a PIL image by running an SRCNN model on its luminance.

    The image is converted to YCbCr and every channel is upscaled bicubically.
    Only the Y channel is passed through *model*; Cb and Cr keep their bicubic
    upscale. The three channels are then merged and converted back to RGB.

    Args:
        model: torch module that maps a (1, 1, H, W) float tensor with values
            in [0, 1] to a tensor of the same shape (e.g. ``SRCNNModel``).
            It is moved to *device* and switched to eval mode.
        image: low-resolution PIL image.
        device: torch device (or device string) to run the model on.
        scale_factor: upscaling factor for width and height. Non-integer
            values are accepted; the target size is rounded to whole pixels.

    Returns:
        Tuple ``(out_final, image_bicubic, image)``: the SRCNN result as an
        RGB image, a plain bicubic upscale of *image* for comparison, and the
        input *image* itself.

    Raises:
        ValueError: if *scale_factor* is not positive.
    """
    if scale_factor <= 0:
        raise ValueError(f"scale_factor must be positive, got {scale_factor!r}")

    model.to(device)
    model.eval()

    # Split into luminance (Y) and chroma (Cb, Cr) and upscale each channel.
    y, cb, cr = image.convert('YCbCr').split()
    y_bicubic = _bicubic_upscale(y, scale_factor)
    cb_bicubic = _bicubic_upscale(cb, scale_factor)
    cr_bicubic = _bicubic_upscale(cr, scale_factor)

    # ToTensor turns the 8-bit 'L' image into a (1, H, W) float tensor in
    # [0, 1]; unsqueeze adds the batch dimension the model expects.
    y_tensor = transforms.ToTensor()(y_bicubic).unsqueeze(0).to(device)

    # Inference only: no_grad skips building an autograd graph.
    with torch.no_grad():
        y_pred = model(y_tensor)

    # Back to 8-bit pixels. np.clip returns a new array (it is not in-place),
    # and it must be applied before the uint8 cast, otherwise out-of-range
    # values wrap around (256 -> 0, -3 -> 253). Round instead of truncating
    # to avoid a systematic downward bias.
    y_pixels = y_pred[0, 0].cpu().numpy() * 255.0
    y_pixels = np.clip(y_pixels, 0, 255).round().astype(np.uint8)
    # A 2-D uint8 array is converted to a mode 'L' image.
    y_pred_PIL = Image.fromarray(y_pixels)

    # Recombine the SRCNN luminance with the bicubic chroma channels.
    out_final = Image.merge('YCbCr', [y_pred_PIL, cb_bicubic, cr_bicubic]).convert('RGB')
    image_bicubic = _bicubic_upscale(image, scale_factor)
    return out_final, image_bicubic, image
def main():
    """Load trained SRCNN weights, upscale ``LR_image.png`` and show the results.

    Displays the original low-resolution image, the SRCNN output and a plain
    bicubic upscale so they can be compared.
    """
    print("Loading SRCNN model...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SRCNNModel().to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load('SRCNNmodel_trained.pt', map_location=device))
    model.eval()
    print("SRCNN model loaded!")

    image_path = "LR_image.png"
    # pred_SRCNN takes an opened PIL image, not a path (it has no image_path
    # parameter). convert() reads the pixels into a new image, so the file can
    # be closed when the with-block exits; RGB gives a well-defined input mode.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    out_final, image_bicubic, image = pred_SRCNN(model=model, image=image, device=device)
    image.show()
    out_final.show()
    image_bicubic.show()
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()