from PIL import Image, ImageDraw
import torch
from torchvision import transforms
import torch.nn.functional as F
import gradio as gr
# import sys
# sys.path.insert(0, './')
from test import create_letr, get_lines_and_draw
from models.preprocessing import *
from models.misc import nested_tensor_from_tensor_list
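
# Build the two LETR variants from their checkpoints (ResNet-50 and
# ResNet-101 backbones); the UI lets the user choose between them.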
model = create_letr('resnet50/checkpoint0024.pth')
model101 = create_letr('resnet101/checkpoint0024.pth')
# PREPARE PREPROCESSING
# transform_test = transforms.Compose([
#     transforms.Resize((test_size)),
#     transforms.ToTensor(),
#     transforms.Normalize([0.538, 0.494, 0.453], [0.257, 0.263, 0.273]),
# ])
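# Three preprocessing pipelines (Compose/ToTensor/Normalize/Resize come from
# models.preprocessing); they differ only in the resize resolution selected
# in the UI: 256, 512, or 1100.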
normalize = Compose([
    ToTensor(),
    Normalize([0.538, 0.494, 0.453], [0.257, 0.263, 0.273]),
    Resize([256]),
])
normalize_512 = Compose([
    ToTensor(),
    Normalize([0.538, 0.494, 0.453], [0.257, 0.263, 0.273]),
    Resize([512]),
])
normalize_1100 = Compose([
    ToTensor(),
    Normalize([0.538, 0.494, 0.453], [0.257, 0.263, 0.273]),
    Resize([1100]),
])
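
# Convert the uploaded array to a PIL image, normalize it at the requested
# resolution, run the selected backbone, and draw the detected line segments.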
def predict(inp, size, model_name):
    image = Image.fromarray(inp.astype('uint8'), 'RGB')
    h, w = image.height, image.width
    orig_size = torch.as_tensor([int(h), int(w)])

    if size == '1100':
        img = normalize_1100(image)
    elif size == '512':
        img = normalize_512(image)
    else:
        img = normalize(image)

    inputs = nested_tensor_from_tensor_list([img])
    with torch.no_grad():
        if model_name == 'resnet101':
            outputs = model101(inputs)[0]
        else:
            outputs = model(inputs)[0]

    lines = get_lines_and_draw(image, outputs, orig_size)
    return image, str(lines)
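
# Gradio UI: an image upload plus two radio selectors (resize resolution and
# backbone); outputs are the annotated image and the raw line coordinates.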
inputs = [
    gr.inputs.Image(),
    gr.inputs.Radio(["256", "512", "1100"]),
    gr.inputs.Radio(["resnet50", "resnet101"]),
]
outputs = [
    gr.outputs.Image(label='Image with Lines', type='numpy'),
    gr.outputs.Textbox(label='Line points list')
]
gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    examples=[
        ["demo.png", '256', "resnet50"],
        ["tappeto-per-calibrazione.jpg", '256', "resnet50"]
    ],
    title="LETR: Line Segment Detection Using Transformers without Edges",
    description="An end-to-end line segment detection algorithm using Transformers, [published at CVPR 2021](https://github.com/mlpc-ucsd/LETR)."
).launch()
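
# Minimal usage sketch, without the web UI (assumptions: this is run in the
# Space repo where "demo.png" exists, and get_lines_and_draw() annotates the
# PIL image in place, as its name and the return value suggest):
#
#   import numpy as np
#   arr = np.array(Image.open("demo.png").convert("RGB"))
#   annotated, lines_str = predict(arr, "256", "resnet50")
#   annotated.save("demo_with_lines.png")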