# Hugging Face Space status at capture time: "Runtime error"
import gradio as gr
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image

from models.network_swinir import SwinIR as net
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _load_model(checkpoint_path, target_device):
    """Build SwinIR-M for x8 classical super-resolution and load its weights.

    The network hyper-parameters must match the checkpoint exactly, otherwise
    ``load_state_dict(strict=True)`` fails on missing, unexpected or
    mis-shaped keys. The checkpoint file name encodes part of the setup:
    "s64" = training patch size 64, "w8" = window size 8, "x8" = upscale 8.
    The remaining values are the SwinIR-M settings of the upstream
    ``main_test_swinir.py`` script for the ``classical_sr`` task.
    NOTE(review): confirm against models/network_swinir.py if the local copy
    of the network differs from upstream.

    Args:
        checkpoint_path: path to the ``.pth`` checkpoint file.
        target_device: device the model and weights are placed on.

    Returns:
        The SwinIR model on ``target_device``, switched to eval mode.
    """
    swinir = net(
        upscale=8,
        in_chans=3,
        img_size=64,
        window_size=8,
        img_range=1.0,
        depths=[6, 6, 6, 6, 6, 6],
        embed_dim=180,
        num_heads=[6, 6, 6, 6, 6, 6],
        mlp_ratio=2,
        upsampler='pixelshuffle',
        resi_connection='1conv',
    ).to(target_device)

    checkpoint = torch.load(checkpoint_path, map_location=target_device)
    # Released SwinIR checkpoints nest the weights under 'params'; a bare
    # state dict is accepted as well. Loading inside a function lets the
    # checkpoint dict be freed once its tensors are copied into the model.
    if isinstance(checkpoint, dict) and 'params' in checkpoint:
        checkpoint = checkpoint['params']
    swinir.load_state_dict(checkpoint, strict=True)
    swinir.eval()
    return swinir


# Load pretrained model
model = _load_model('001_classicalSR_DF2K_s64w8_SwinIR-M_x8.pth', device)
def process_img(input_image: Image.Image) -> Image.Image:
    """Downscale an image 4x, then super-resolve it with the SwinIR model.

    The 4x downscale simulates a low-resolution source. The output size is the
    downscaled size multiplied by the model's upscale factor.

    Args:
        input_image: PIL image supplied by the Gradio input component.

    Returns:
        The super-resolved image as an RGB PIL image.
    """
    # The network expects exactly 3 channels; uploads may be RGBA (PNG with
    # alpha), grayscale or palette images.
    rgb = input_image.convert('RGB')

    # Resize to low resolution, never below 1 px (PIL rejects zero-sized images).
    low_res = rgb.resize((max(1, rgb.width // 4), max(1, rgb.height // 4)))

    # (1, 3, H, W) float tensor with values in [0, 1].
    lr = transforms.ToTensor()(low_res).unsqueeze(0).to(device)
    _, _, h, w = lr.shape

    # Window attention needs H and W to be multiples of the window size.
    # Replicate padding works for any image size (reflect padding raises when
    # the pad is larger than the image); the padded border is cropped below.
    # Falls back to 8 (the "w8" of the checkpoint) if the model does not
    # expose `window_size`.
    window_size = getattr(model, 'window_size', 8)
    lr = F.pad(lr, (0, (-w) % window_size, 0, (-h) % window_size), mode='replicate')

    with torch.no_grad():
        sr = model(lr)

    # Derive the upscale factor from the actual output, then drop the
    # upscaled padding.
    scale = sr.shape[-2] // lr.shape[-2]
    sr = sr[..., : h * scale, : w * scale]

    # ToPILImage converts floats via mul(255).byte(); clamp first so values
    # outside [0, 1] saturate instead of overflowing the uint8 cast.
    return transforms.ToPILImage()(sr.squeeze(0).clamp(0, 1).cpu())
# Web UI: one uploaded image in (as a PIL image), the upscaled image out.
# `gr.inputs.*` was deprecated in Gradio 3 and removed in Gradio 4; the
# top-level `gr.Image` component works in both versions.
iface = gr.Interface(
    fn=process_img,
    inputs=gr.Image(type="pil"),
    outputs="image",
    title="SwinIR upscaling",
)
iface.launch()