# SuperResolution / app.py
# Author: Hu — commit a58728b ("update model"), 2.15 kB
import numpy as np
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import SRCNNModel, pred_SRCNN
from PIL import Image
# Text and examples shown by the Gradio interface defined below.
title = "Super Resolution with CNN"
description = """
Your low-resolution image will be reconstructed at high resolution (scale factor 2) by a convolutional neural network!<br>
Details on training and the dataset can be found in my [github repo](https://github.com/susuhu/super-resolution).<br>
"""
# HTML shown below the interface. The emoji are written as escape sequences
# (\U0001F4DC = scroll, \U0001F4E6 = package) so the text cannot be
# corrupted again by a wrong file encoding.
article = """
<div style='margin:20px auto;'>
<p>Sources:</p>
<p>\U0001F4DC <a href="https://arxiv.org/abs/1501.00092">Image Super-Resolution Using Deep Convolutional Networks</a></p>
<p>\U0001F4E6 Dataset from <a href="https://github.com/eugenesiow/super-image-data">this GitHub repo</a></p>
</div>
"""
# One inner list per example, matching the interface's single image input.
# NOTE(review): paths are presumably relative to the app's working directory.
examples = [
    ["LR_image.png"],
    ["barbara.png"],
]
# Build the SRCNN network once at import time on the GPU when one is
# available, load the trained weights, and put it in evaluation mode so every
# call to `sepia` reuses the same model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SRCNNModel().to(device)
# `device` is already a torch.device, so it is passed straight through as
# map_location; this remaps GPU-saved tensors onto the CPU when CUDA is absent.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(
    torch.load("SRCNNmodel_trained.pt", map_location=device)
)
model.eval()
# def image_grid(imgs, rows, cols):
# '''
# imgs:list of PILImage
# '''
# assert len(imgs) == rows*cols
# w, h = imgs[0].size
# grid = Image.new('RGB', size=(cols*w, rows*h))
# grid_w, grid_h = grid.size
# for i, img in enumerate(imgs):
# grid.paste(img, box=(i%cols*w, i//cols*h))
# return grid
def sepia(image):
    """Super-resolve one uploaded image and return it with a bicubic baseline.

    The function name is a leftover from Gradio's "sepia" example. It is kept
    unchanged because the interface below registers it as ``fn``.

    Args:
        image: Image array supplied by the Gradio image input, or ``None`` when
            the user submits without uploading an image.

    Returns:
        ``(out_final, image_bicubic)`` as returned by ``pred_SRCNN``. These are
        the two values shown in the "Convolutional neural network" and
        "Bicubic interpolation" outputs, in that order. ``(None, None)`` is
        returned when no image was given, which leaves both outputs empty.
    """
    if image is None:
        return None, None
    # Let Pillow infer the mode from the array shape, then normalize to RGB.
    # Forcing mode="RGB" would misread RGBA (H, W, 4) data and fail on
    # grayscale (H, W) data. For ordinary RGB input this is equivalent.
    pil_image = Image.fromarray(image).convert("RGB")
    with torch.no_grad():
        # The third value returned by pred_SRCNN is not displayed.
        out_final, image_bicubic, _ = pred_SRCNN(
            model=model, image=pil_image, device=device
        )
    return out_final, image_bicubic
# Gradio UI: one image input, and two image outputs that receive the
# (CNN result, bicubic result) pair returned by `sepia`, in that order.
# NOTE(review): `gr.inputs` / `gr.outputs` are deprecated in Gradio 3.x and
# removed in 4.x; migrate to `gr.Image(...)` once the pinned version allows.
demo = gr.Interface(
    fn=sepia,
    inputs=gr.inputs.Image(label="Upload image"),
    outputs=[
        gr.outputs.Image(label="Convolutional neural network"),
        gr.outputs.Image(label="Bicubic interpolation"),
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
)
demo.launch()