Hu commited on
Commit
aa0b24b
·
1 Parent(s): 0723193

move model inside app.py

Browse files
Files changed (2) hide show
  1. app.py +69 -1
  2. model.py +0 -80
app.py CHANGED
@@ -3,7 +3,6 @@ import gradio as gr
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
- from model import SRCNNModel, pred_SRCNN
7
  from PIL import Image
8
 
9
 
@@ -28,6 +27,75 @@ examples = [
28
  ["barbara.png"],
29
  ]
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  # load model
32
  # print("Loading SRCNN model...")
33
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
 
6
  from PIL import Image
7
 
8
 
 
27
  ["barbara.png"],
28
  ]
29
 
30
+
31
+ class SRCNNModel(nn.Module):
32
+ def __init__(self):
33
+ super(SRCNNModel, self).__init__()
34
+ self.conv1 = nn.Conv2d(1, 64, 9, padding=4)
35
+ self.conv2 = nn.Conv2d(64, 32, 1, padding=0)
36
+ self.conv3 = nn.Conv2d(32, 1, 5, padding=2)
37
+
38
+ def forward(self, x):
39
+ out = F.relu(self.conv1(x))
40
+ out = F.relu(self.conv2(out))
41
+ out = self.conv3(out)
42
+ return out
43
+
44
+
45
+ def pred_SRCNN(model, image, device, scale_factor=2):
46
+ """
47
+ model: SRCNN model
48
+ image: low resolution image PILLOW image
49
+ scale_factor: scale factor for resolution
50
+ device: cuda or cpu
51
+ """
52
+ model.to(device)
53
+ model.eval()
54
+
55
+ # open image
56
+ # image = Image.open(image_path)
57
+ # split channels
58
+ y, cb, cr = image.convert("YCbCr").split()
59
+ # size will be used in image transform
60
+ original_size = y.size
61
+
62
+ # bicubic interpolate it to the original size
63
+ y_bicubic = transforms.Resize(
64
+ (original_size[1] * scale_factor, original_size[0] * scale_factor),
65
+ interpolation=Image.BICUBIC,
66
+ )(y)
67
+ cb_bicubic = transforms.Resize(
68
+ (original_size[1] * scale_factor, original_size[0] * scale_factor),
69
+ interpolation=Image.BICUBIC,
70
+ )(cb)
71
+ cr_bicubic = transforms.Resize(
72
+ (original_size[1] * scale_factor, original_size[0] * scale_factor),
73
+ interpolation=Image.BICUBIC,
74
+ )(cr)
75
+ # turn it into tensor and add batch dimension
76
+ y_bicubic = transforms.ToTensor()(y_bicubic).to(device).unsqueeze(0)
77
+ # get the y channel SRCNN prediction
78
+ y_pred = model(y_bicubic)
79
+ # convert it to numpy image
80
+ y_pred = y_pred[0].cpu().detach().numpy()
81
+
82
+ # convert it into regular image pixel values
83
+ y_pred = y_pred * 255
84
+ y_pred.clip(0, 255)
85
+ # conver y channel from array to PIL image format for merging
86
+ y_pred_PIL = Image.fromarray(np.uint8(y_pred[0]), mode="L")
87
+ # merge the SRCNN y channel with cb cr channels
88
+ out_final = Image.merge("YCbCr", [y_pred_PIL, cb_bicubic, cr_bicubic]).convert(
89
+ "RGB"
90
+ )
91
+
92
+ image_bicubic = transforms.Resize(
93
+ (original_size[1] * scale_factor, original_size[0] * scale_factor),
94
+ interpolation=Image.BICUBIC,
95
+ )(image)
96
+ return out_final, image_bicubic, image
97
+
98
+
99
  # load model
100
  # print("Loading SRCNN model...")
101
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.py DELETED
@@ -1,80 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import torchvision
5
- from torchvision.transforms import transforms
6
- import numpy as np
7
- from PIL import Image
8
-
9
- class SRCNNModel(nn.Module):
10
- def __init__(self):
11
- super(SRCNNModel, self).__init__()
12
- self.conv1=nn.Conv2d(1,64,9,padding=4)
13
- self.conv2=nn.Conv2d(64,32,1,padding=0)
14
- self.conv3=nn.Conv2d(32,1,5,padding=2)
15
-
16
- def forward(self,x):
17
- out = F.relu(self.conv1(x))
18
- out = F.relu(self.conv2(out))
19
- out = self.conv3(out)
20
- return out
21
-
22
- def pred_SRCNN(model,image,device,scale_factor=2):
23
- """
24
- model: SRCNN model
25
- image: low resolution image PILLOW image
26
- scale_factor: scale factor for resolution
27
- device: cuda or cpu
28
- """
29
- model.to(device)
30
- model.eval()
31
-
32
- # open image
33
- # image = Image.open(image_path)
34
- # split channels
35
- y, cb, cr= image.convert('YCbCr').split()
36
- # size will be used in image transform
37
- original_size = y.size
38
-
39
- # bicubic interpolate it to the original size
40
- y_bicubic = transforms.Resize((original_size[1]*scale_factor,original_size[0]*scale_factor),interpolation=Image.BICUBIC)(y)
41
- cb_bicubic = transforms.Resize((original_size[1]*scale_factor,original_size[0]*scale_factor),interpolation=Image.BICUBIC)(cb)
42
- cr_bicubic = transforms.Resize((original_size[1]*scale_factor,original_size[0]*scale_factor),interpolation=Image.BICUBIC)(cr)
43
- # turn it into tensor and add batch dimension
44
- y_bicubic = transforms.ToTensor()(y_bicubic).to(device).unsqueeze(0)
45
- # get the y channel SRCNN prediction
46
- y_pred = model(y_bicubic)
47
- # convert it to numpy image
48
- y_pred = y_pred[0].cpu().detach().numpy()
49
-
50
- # convert it into regular image pixel values
51
- y_pred = y_pred*255
52
- y_pred.clip(0,255)
53
- # conver y channel from array to PIL image format for merging
54
- y_pred_PIL = Image.fromarray(np.uint8(y_pred[0]),mode='L')
55
- # merge the SRCNN y channel with cb cr channels
56
- out_final = Image.merge('YCbCr',[y_pred_PIL,cb_bicubic,cr_bicubic]).convert('RGB')
57
-
58
- image_bicubic = transforms.Resize((original_size[1]*scale_factor,original_size[0]*scale_factor),interpolation=Image.BICUBIC)(image)
59
- return out_final,image_bicubic,image
60
-
61
-
62
- # def main():
63
- # print("Loading SRCNN model...")
64
- # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
65
-
66
- # model = SRCNNModel().to(device)
67
- # model.load_state_dict(torch.load('SRCNNmodel_trained.pt'))
68
- # model.eval()
69
- # print("SRCNN model loaded!")
70
-
71
- # image_path = "LR_image.png"
72
-
73
- # out_final,image_bicubic,image = pred_SRCNN(model=model,image_path=image_path,device=device)
74
- # image.show()
75
- # out_final.show()
76
- # image_bicubic.show()
77
-
78
-
79
- # if __name__=="__main__":
80
- # main()