Spaces:
Running
Running
import numpy as np | |
import gradio as gr | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from model import SRCNNModel, pred_SRCNN | |
from PIL import Image | |
# Static text and example inputs handed to gr.Interface further down.

# Page heading.
title = "Super Resolution with CNN"

# Short blurb passed as the interface description (Markdown with inline HTML).
description = """
Your low resolution image will be reconstructed to high resolution with a scale of 2 with a convolutional neural network!<br>
Detailed training and dataset can be found on my [github repo](https://github.com/susuhu/super-resolution).<br>
"""

# HTML snippet passed as the interface article: links to the paper and dataset.
article = """
<div style='margin:20px auto;'>
<p>Sources:</p>
<p>📜 <a href="https://arxiv.org/abs/1501.00092">Image Super-Resolution Using Deep Convolutional Networks</a></p>
<p>📦 Dataset <a href="https://github.com/eugenesiow/super-image-data">this GitHub repo</a></p>
</div>
"""

# Example inputs: one row per example, one value per input component.
examples = [
    ["LR_image.png"],
    ["barbara.png"],
]
# Model setup: runs once at import time so every request reuses the same
# network. Uses the GPU when one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the network on the chosen device, then restore the trained weights.
# map_location remaps the checkpoint's tensors onto that same device.
model = SRCNNModel()
model = model.to(device)
model.load_state_dict(
    torch.load("SRCNNmodel_trained.pt", map_location=device)
)

# Switch to evaluation mode: the model is only used for inference here.
model.eval()
# Unused helper, kept commented out for reference: pastes a list of
# equally sized PIL images into a single rows x cols grid image.
# def image_grid(imgs, rows, cols):
#     '''
#     imgs: list of PIL images
#     '''
#     assert len(imgs) == rows * cols
#     w, h = imgs[0].size
#     grid = Image.new('RGB', size=(cols * w, rows * h))
#     grid_w, grid_h = grid.size
#     for i, img in enumerate(imgs):
#         grid.paste(img, box=(i % cols * w, i // cols * h))
#     return grid
def sepia(image):
    """Upscale an uploaded image with SRCNN and return it with a baseline.

    Despite its name, this function applies no sepia filter: it forwards the
    image to ``pred_SRCNN`` and returns the network output together with the
    bicubic result that ``pred_SRCNN`` also produces.

    Args:
        image: The uploaded picture as a NumPy array, as delivered by the
            Gradio image input. It is interpreted as RGB data.

    Returns:
        A ``(out_final, image_bicubic)`` tuple, in the order expected by the
        two output components of the interface.
        NOTE(review): presumably both are PIL images; confirm in
        ``model.pred_SRCNN``.

    Raises:
        ValueError: If no image was supplied (``image`` is ``None``).
    """
    # Gradio passes None when the form is submitted without an upload; fail
    # with a readable message instead of an attribute error inside Pillow.
    if image is None:
        raise ValueError("No image provided: please upload an image first.")

    # Gradio hands the upload over as a NumPy array; pred_SRCNN is given a
    # PIL image instead.
    pil_image = Image.fromarray(image, mode="RGB")

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        # The third return value is not needed here and is discarded.
        out_final, image_bicubic, _ = pred_SRCNN(
            model=model, image=pil_image, device=device
        )

    return out_final, image_bicubic
# Web UI: one image input, two image outputs. The outputs are listed in the
# same order as the tuple returned by `sepia` (network result first, bicubic
# baseline second).
# NOTE(review): gr.inputs / gr.outputs are Gradio's legacy component
# namespaces (deprecated in 3.x and, as far as I know, removed in 4.x).
# Confirm the pinned Gradio version before migrating to gr.Image.
demo = gr.Interface(
    fn=sepia,
    inputs=gr.inputs.Image(label="Upload image"),
    outputs=[
        gr.outputs.Image(label="Convolutional neural network"),
        gr.outputs.Image(label="Bicubic interpolation"),
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
)

# Start the Gradio server for this app.
demo.launch()