Spaces:
Running
Running
File size: 1,859 Bytes
33ac1eb 840ed11 33ac1eb 60caba2 33ac1eb 840ed11 33ac1eb 840ed11 33ac1eb 840ed11 33ac1eb 840ed11 33ac1eb 840ed11 33ac1eb 840ed11 33ac1eb 840ed11 33ac1eb 277674a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import numpy as np
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import SRCNNModel, pred_SRCNN
from PIL import Image
# Text passed to the Gradio interface as its title, description and article.
title = "Super Resolution with CNN"
description = """
Your low resolution image will be reconstructed to high resolution with a scale of 2 with a convolutional neural network!<br>
CNN output on the left, bicubic interpolation output on the right.<br>
Training and dataset can be found on my [github page](https://github.com/susuhu/super-resolution/blob/main/Super_Resolution.ipynb).<br>
"""
article = "Check out the original [paper](https://arxiv.org/abs/1501.00092) proposed by Dong *et al*."
# Load the trained SRCNN weights once at import time; sepia() below reuses
# this module-level `model` and `device` for every request.
print("Loading SRCNN model...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SRCNNModel().to(device)
# `device` is already a torch.device, so it is passed to map_location directly.
# map_location remaps stored tensors, so a checkpoint saved on GPU can load
# on a CPU-only host.
# NOTE(review): torch.load unpickles the file. That is acceptable for this
# bundled, trusted checkpoint; consider weights_only=True on torch >= 1.13.
model.load_state_dict(torch.load('SRCNNmodel_trained.pt', map_location=device))
# Put the network in evaluation mode before serving predictions.
model.eval()
print("SRCNN model loaded!")
# def image_grid(imgs, rows, cols):
# '''
# imgs:list of PILImage
# '''
# assert len(imgs) == rows*cols
# w, h = imgs[0].size
# grid = Image.new('RGB', size=(cols*w, rows*h))
# grid_w, grid_h = grid.size
# for i, img in enumerate(imgs):
# grid.paste(img, box=(i%cols*w, i//cols*h))
# return grid
def sepia(image):
    """Upscale an uploaded image with SRCNN and with bicubic interpolation.

    The name is kept because the ``gr.Interface`` call refers to it; no sepia
    filter is applied.

    Args:
        image: The uploaded image as a NumPy array (Gradio hands images to
            this function as arrays), or ``None`` if nothing was uploaded.

    Returns:
        A ``(cnn_output, bicubic_output)`` pair, in the order of the two
        output components of the interface. Both are ``None`` when no image
        was given.
    """
    if image is None:
        # Nothing uploaded: leave both outputs empty instead of failing
        # inside PIL with an opaque error.
        return None, None
    # Convert explicitly rather than passing mode='RGB' to fromarray. This way
    # grayscale or RGBA arrays are converted properly instead of having their
    # raw bytes reinterpreted as RGB, and the deprecated `mode` argument is
    # avoided.
    pil_image = Image.fromarray(image).convert("RGB")
    # pred_SRCNN also returns a third value that this app does not display.
    out_final, image_bicubic, _ = pred_SRCNN(model=model, image=pil_image, device=device)
    return out_final, image_bicubic
# NOTE(review): gr.inputs / gr.outputs are deprecated in Gradio 3.x and were
# removed in Gradio 4; migrate to gr.Image(...) if the Space upgrades Gradio.
demo = gr.Interface(
    fn=sepia,
    inputs=gr.inputs.Image(label="Upload image"),
    # Order matches the (cnn_output, bicubic_output) tuple returned by sepia.
    outputs=[
        gr.outputs.Image(label="Conv net"),
        gr.outputs.Image(label="Bicubic interpolation"),
    ],
    title=title,
    description=description,
    article=article,
    examples=[['LR_image.png'], ['barbara.png']],
)
demo.launch()