File size: 1,956 Bytes
33ac1eb
 
 
 
 
 
 
 
 
 
 
 
840ed11
c373d06
 
33ac1eb
 
 
c373d06
 
 
 
 
 
 
33ac1eb
 
 
 
 
 
60caba2
33ac1eb
 
 
840ed11
 
 
 
 
33ac1eb
840ed11
 
 
33ac1eb
840ed11
 
 
33ac1eb
840ed11
33ac1eb
840ed11
33ac1eb
840ed11
 
33ac1eb
feb4f18
33ac1eb
277674a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import numpy as np
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import SRCNNModel, pred_SRCNN
from PIL import Image


# Text shown by the Gradio interface: the page title, a description shown
# above the demo, and an HTML "article" footer listing the sources used.
title = "Super Resolution with CNN"
description = """

Your low resolution image will be reconstructed to high resolution with a scale of 2 with a convolutional neural network!<br>

Detailed training and dataset can be found on my [github repo](https://github.com/susuhu/super-resolution).<br>

"""

# HTML footer. The emoji are stored as real UTF-8 characters (not
# code-page-mangled bytes), and every <p> is properly closed so the
# browser renders one paragraph per line.
article = """
<div style='margin:20px auto;'>
<p>Sources:</p>
<p>📜 <a href="https://arxiv.org/abs/1501.00092">Image Super-Resolution Using Deep Convolutional Networks</a></p>
<p>📦 Dataset <a href="https://github.com/eugenesiow/super-image-data">this GitHub repo</a></p>
</div>
"""

# Build the SRCNN network once at start-up, restore its trained weights and
# put it in inference mode. A CUDA device is used when available, else CPU.
print("Loading  SRCNN model...")
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model = SRCNNModel()
model = model.to(device)
# map_location remaps any GPU-saved tensors onto the chosen device, so the
# checkpoint also loads on CPU-only hosts.
# NOTE: torch.load unpickles the file -- only point it at trusted checkpoints.
model.load_state_dict(
    torch.load('SRCNNmodel_trained.pt', map_location=device)
)
model.eval()
print("SRCNN model loaded!")

# def image_grid(imgs, rows, cols):
#     '''
#     imgs:list of PILImage
#     '''
#     assert len(imgs) == rows*cols

#     w, h = imgs[0].size
#     grid = Image.new('RGB', size=(cols*w, rows*h))
#     grid_w, grid_h = grid.size
    
#     for i, img in enumerate(imgs):
#         grid.paste(img, box=(i%cols*w, i//cols*h))
#     return grid

def sepia(image):
    """Upscale an uploaded image with SRCNN and, for comparison, bicubically.

    Despite its name, no sepia filter is applied; the name is kept because
    the Gradio interface is wired to it.

    Args:
        image: The uploaded picture as a NumPy array (as delivered by the
            Gradio image input), or ``None`` when nothing was uploaded.

    Returns:
        A ``(conv_net_output, bicubic_output)`` pair taken from
        ``pred_SRCNN``, or ``(None, None)`` when no image was provided.
    """
    if image is None:
        # Submitting without an upload would otherwise crash inside PIL;
        # leave both output panels empty instead.
        return None, None

    # Let PIL infer the mode from the array, then normalise to RGB.
    # Unlike forcing mode='RGB' on the raw buffer (deprecated in newer Pillow),
    # this converts grayscale / RGBA uploads correctly instead of garbling them.
    pil_image = Image.fromarray(image).convert('RGB')
    # pred_SRCNN returns three values; the third is not needed by the UI.
    out_final, image_bicubic, _ = pred_SRCNN(model=model, image=pil_image, device=device)
    return out_final, image_bicubic

# Web UI: one image upload in, two images out (SRCNN result and the bicubic
# baseline), plus the title/description/article text and clickable examples.
# NOTE: gr.inputs / gr.outputs are Gradio's legacy component namespaces
# (deprecated in Gradio 3, removed in 4); switch to gr.Image once the pinned
# Gradio version allows it.
demo = gr.Interface(
    fn=sepia,
    inputs=gr.inputs.Image(label="Upload image"),
    outputs=[
        gr.outputs.Image(label="Conv net"),
        gr.outputs.Image(label="Bicubic interpolation"),
    ],
    title=title,
    description=description,
    article=article,
    # Example paths are relative to the working directory; presumably these
    # images ship next to this script -- confirm when deploying elsewhere.
    examples=[['LR_image.png'], ['barbara.png']],
)

demo.launch()