File size: 2,725 Bytes
2f110b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.transforms import transforms
import numpy as np
from PIL import Image

class SRCNNModel(nn.Module):
  """Three-layer SRCNN: 9x9 conv -> 1x1 conv -> 5x5 conv on one channel.

  Every convolution is padded by half its kernel size, so the output
  tensor has the same height and width as the input.
  """

  def __init__(self):
    super().__init__()
    # 9x9 kernel, 1 -> 64 channels
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=9, padding=4)
    # 1x1 kernel, 64 -> 32 channels
    self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=1, padding=0)
    # 5x5 kernel, 32 -> 1 channel
    self.conv3 = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=5, padding=2)

  def forward(self, x):
    """Run x through the three convolutions, with ReLU after the first two."""
    features = F.relu(self.conv1(x))
    mapped = F.relu(self.conv2(features))
    # the last layer has no activation
    return self.conv3(mapped)

def pred_SRCNN(model,image,device,scale_factor=2):
    """
    Upscale a PIL image, running only its luminance channel through the model.

    The image is converted to YCbCr and each channel is enlarged with
    bicubic interpolation. The enlarged Y channel is passed through
    ``model``; the enlarged Cb and Cr channels are merged back with the
    predicted Y and the result is converted to RGB.

    model: SRCNN model; it is moved to ``device`` and put in eval mode,
           and is called with a (1, 1, H, W) float tensor
    image: low resolution image PILLOW image
    scale_factor: scale factor for resolution
    device: cuda or cpu

    Returns a tuple ``(out_final, image_bicubic, image)``: the model
    output as an RGB image, the plain bicubic upscale of the input, and
    the input image itself.
    """
    model.to(device)
    model.eval()

    # split channels
    y, cb, cr = image.convert('YCbCr').split()
    # PIL reports size as (width, height); Resize expects (height, width)
    width, height = y.size
    # round so a non-integer scale_factor still yields a valid pixel size
    target_size = (int(round(height*scale_factor)), int(round(width*scale_factor)))
    upscale = transforms.Resize(target_size, interpolation=Image.BICUBIC)

    # bicubic interpolate every channel to the target size
    y_bicubic = upscale(y)
    cb_bicubic = upscale(cb)
    cr_bicubic = upscale(cr)

    # turn it into tensor and add batch dimension
    y_input = transforms.ToTensor()(y_bicubic).to(device).unsqueeze(0)
    # inference only: skip building the autograd graph
    with torch.no_grad():
        y_pred = model(y_input)
    # drop the batch dimension and convert to a numpy array
    y_pred = y_pred[0].cpu().numpy()

    # convert it into regular image pixel values; ndarray.clip returns a
    # new array, so the result must be kept or out-of-range values would
    # wrap around in the uint8 cast below
    y_pred = (y_pred*255).clip(0,255)
    # convert y channel from array to PIL image format for merging
    y_pred_PIL = Image.fromarray(np.uint8(y_pred[0]),mode='L')
    # merge the SRCNN y channel with cb cr channels
    out_final = Image.merge('YCbCr',[y_pred_PIL,cb_bicubic,cr_bicubic]).convert('RGB')

    image_bicubic = upscale(image)
    return out_final,image_bicubic,image


def main():
    """Load the trained SRCNN weights, upscale ``LR_image.png`` and show the results.

    Opens three image viewers: the input image, the SRCNN output and the
    plain bicubic upscale.
    """
    print("Loading  SRCNN model...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = SRCNNModel().to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine
    model.load_state_dict(torch.load('SRCNNmodel_trained.pt', map_location=device))
    model.eval()
    print("SRCNN model loaded!")

    image_path = "LR_image.png"
    # pred_SRCNN takes an opened PIL image (keyword ``image``), not a path
    lr_image = Image.open(image_path)

    out_final,image_bicubic,image = pred_SRCNN(model=model,image=lr_image,device=device)
    image.show()
    out_final.show()
    image_bicubic.show()


if __name__=="__main__":
    main()