Spaces:
Running
Running
from PIL import Image, ImageDraw | |
import torch | |
from torchvision import transforms | |
import torch.nn.functional as F | |
import gradio as gr | |
# import sys | |
# sys.path.insert(0, './') | |
from test import create_letr, get_lines_and_draw | |
from models.preprocessing import * | |
from models.misc import nested_tensor_from_tensor_list | |
# Build both LETR variants from their checkpoint paths once, at import time,
# so every `predict` call reuses the same model objects (ResNet-50 first,
# then ResNet-101).
model, model101 = (
    create_letr('resnet50/checkpoint0024.pth'),
    create_letr('resnet101/checkpoint0024.pth'),
)
# PREPARE PREPROCESSING
# Per-channel RGB mean / std used to normalize every input image.
_NORM_MEAN = (0.538, 0.494, 0.453)
_NORM_STD = (0.257, 0.263, 0.273)


def _make_transform(size):
    """Return the ToTensor -> Normalize -> Resize pipeline for one input size.

    The pipelines differ only in the ``Resize`` target, so they are all built
    from this single definition. Step order is significant and kept as
    ToTensor, Normalize, then Resize. Fresh lists are created on every call
    so no two pipelines share a mutable mean/std object.
    """
    return Compose([
        ToTensor(),
        Normalize(list(_NORM_MEAN), list(_NORM_STD)),
        Resize([size]),
    ])


# Pipelines chosen by `predict` for the '256' (default), '512' and '1100'
# size options.
normalize = _make_transform(256)
normalize_512 = _make_transform(512)
normalize_1100 = _make_transform(1100)
def predict(inp, size, model_name):
    """Run LETR line-segment detection on one uploaded image.

    Args:
        inp: Pixel array from the Gradio image input. It is cast to uint8;
            presumably H x W x C (or H x W for grayscale) — confirm against
            the Gradio version in use.
        size: Resolution choice from the radio: '1100' or '512' select those
            pipelines; anything else (including '256' or None when nothing
            is selected) uses the 256 pipeline.
        model_name: 'resnet101' selects the ResNet-101 model; any other
            value (including None) uses ResNet-50.

    Returns:
        A ``(image, text)`` tuple: the PIL image that was passed to
        ``get_lines_and_draw`` (which presumably draws the detected lines on
        it in place — confirm in test.py), and ``str()`` of the lines that
        helper returned.

    Raises:
        ValueError: If no image was provided.
    """
    if inp is None:
        raise ValueError("Please upload an image.")

    # Let Pillow infer the mode from the array's shape, then convert.
    # Forcing mode='RGB' in fromarray reinterprets the raw buffer, which
    # fails on grayscale input and scrambles RGBA input.
    image = Image.fromarray(inp.astype('uint8')).convert('RGB')
    orig_size = torch.as_tensor([image.height, image.width])

    transform = {'1100': normalize_1100, '512': normalize_512}.get(size, normalize)
    net = model101 if model_name == 'resnet101' else model

    inputs = nested_tensor_from_tensor_list([transform(image)])
    with torch.no_grad():  # inference only; no autograd graph needed
        outputs = net(inputs)[0]
    lines = get_lines_and_draw(image, outputs, orig_size)
    return image, str(lines)
# Gradio UI: an image plus two radio selectors feed `predict`, whose two
# return values populate the image and textbox outputs in order.
inputs = [
    gr.inputs.Image(),
    # Defaults mirror predict's fallbacks (256 / resnet50), so the UI shows
    # the option that actually runs when the user picks nothing.
    gr.inputs.Radio(["256", "512", "1100"], default="256", label="Input size"),
    gr.inputs.Radio(["resnet50", "resnet101"], default="resnet50", label="Backbone"),
]
outputs = [
    # predict returns a PIL image, so declare the output type accordingly.
    gr.outputs.Image(label='Image with Lines', type='pil'),
    gr.outputs.Textbox(label='Lines points List'),
]
gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    examples=[
        ["demo.png", '256', "resnet50"],
        ["tappeto-per-calibrazione.jpg", '256', "resnet50"],
    ],
    title="LETR: Line Segment Detection Using Transformers without Edges",
    description="It is an end-to-end line segment detection algorithm using Transformers [published on CVPR 2021](https://github.com/mlpc-ucsd/LETR).",
).launch()