Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -4,12 +4,16 @@ from transformers import DetrImageProcessor, DetrForObjectDetection
|
|
4 |
from color import Color
|
5 |
from color_wheel import ColorWheel
|
6 |
from PIL import ImageDraw, ImageFont
|
|
|
7 |
|
8 |
-
|
9 |
-
model
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
-
|
12 |
-
def process_image(image, margin):
|
13 |
if image is None:
|
14 |
yield [None, None, None]
|
15 |
return
|
@@ -30,6 +34,7 @@ def process_image(image, margin):
|
|
30 |
index = 0
|
31 |
gallery = []
|
32 |
labels = []
|
|
|
33 |
drawImage = image.copy()
|
34 |
draw = ImageDraw.Draw(drawImage)
|
35 |
for score, label, box in zip(results['scores'], results['labels'], results['boxes']):
|
@@ -43,8 +48,14 @@ def process_image(image, margin):
|
|
43 |
draw.rectangle([(box[0], box[1]), (box[2], box[3])], outline=colors[index], width=4)
|
44 |
gallery.append(image.crop((box[0], box[1], box[2], box[3])))
|
45 |
labels.append(model.config.id2label[label.item()])
|
|
|
|
|
46 |
index += 1
|
47 |
-
|
|
|
|
|
|
|
|
|
48 |
|
49 |
app = gr.Interface(
|
50 |
title='Object Detection for Image',
|
@@ -52,14 +63,17 @@ app = gr.Interface(
|
|
52 |
inputs=[
|
53 |
gr.Image(type='pil'),
|
54 |
gr.Slider(maximum=100, step=1, label='margin'),
|
|
|
55 |
],
|
56 |
outputs=[
|
57 |
gr.Image(label='boxes', type='pil'),
|
58 |
gr.Gallery(label='gallery', columns=8, height=140),
|
|
|
59 |
gr.Textbox(label='text'),
|
60 |
],
|
61 |
allow_flagging='never',
|
62 |
examples=[['examples/Wild.jpg', 0], ['examples/Football-Match.jpg', 0]],
|
|
|
63 |
#cache_examples=False
|
64 |
)
|
65 |
app.queue(concurrency_count=20)
|
|
|
4 |
from color import Color
|
5 |
from color_wheel import ColorWheel
|
6 |
from PIL import ImageDraw, ImageFont
|
7 |
+
import numpy as np
|
8 |
|
9 |
+
def process_image(image, margin, model):
|
10 |
+
if model=='detr-resnet-101':
|
11 |
+
processor = DetrImageProcessor.from_pretrained('facebook/detr-resnet-101')
|
12 |
+
model = DetrForObjectDetection.from_pretrained('facebook/detr-resnet-101')
|
13 |
+
else:
|
14 |
+
processor = DetrImageProcessor.from_pretrained('facebook/detr-resnet-50')
|
15 |
+
model = DetrForObjectDetection.from_pretrained('facebook/detr-resnet-50')
|
16 |
|
|
|
|
|
17 |
if image is None:
|
18 |
yield [None, None, None]
|
19 |
return
|
|
|
34 |
index = 0
|
35 |
gallery = []
|
36 |
labels = []
|
37 |
+
newlabel = {}
|
38 |
drawImage = image.copy()
|
39 |
draw = ImageDraw.Draw(drawImage)
|
40 |
for score, label, box in zip(results['scores'], results['labels'], results['boxes']):
|
|
|
48 |
draw.rectangle([(box[0], box[1]), (box[2], box[3])], outline=colors[index], width=4)
|
49 |
gallery.append(image.crop((box[0], box[1], box[2], box[3])))
|
50 |
labels.append(model.config.id2label[label.item()])
|
51 |
+
print(model.config.id2label[label.item()])
|
52 |
+
newlabel[model.config.id2label[label.item()]] = 1
|
53 |
index += 1
|
54 |
+
print('----------')
|
55 |
+
print(labels)
|
56 |
+
print(results['scores'])
|
57 |
+
print(newlabel)
|
58 |
+
yield [drawImage, gallery, newlabel, ','.join(labels)]
|
59 |
|
60 |
app = gr.Interface(
|
61 |
title='Object Detection for Image',
|
|
|
63 |
inputs=[
|
64 |
gr.Image(type='pil'),
|
65 |
gr.Slider(maximum=100, step=1, label='margin'),
|
66 |
+
gr.Radio(["detr-resnet-50", "detr-resnet-101"], value="detr-resnet-50", label="Select the model")
|
67 |
],
|
68 |
outputs=[
|
69 |
gr.Image(label='boxes', type='pil'),
|
70 |
gr.Gallery(label='gallery', columns=8, height=140),
|
71 |
+
gr.Label(label='scores'),
|
72 |
gr.Textbox(label='text'),
|
73 |
],
|
74 |
allow_flagging='never',
|
75 |
examples=[['examples/Wild.jpg', 0], ['examples/Football-Match.jpg', 0]],
|
76 |
+
# theme='HaleyCH/HaleyCH_Theme',
|
77 |
#cache_examples=False
|
78 |
)
|
79 |
app.queue(concurrency_count=20)
|