edm-research commited on
Commit
f211596
·
1 Parent(s): 3cd95d0

Initial commit

Browse files
Files changed (1) hide show
  1. app.py +70 -0
app.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import spaces
4
+ import torch
5
+
6
+ from PIL import Image
7
+ from transformers import pipeline
8
+ import matplotlib.pyplot as plt
9
+ import io
10
+
11
+ model_pipeline = pipeline("object-detection", model="edm-research/detr-resnet-50-dc5-fashionpedia-finetuned")
12
+
13
+
14
+ COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
15
+ [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
16
+
17
+
18
+ def get_output_figure(pil_img, results, threshold):
19
+ plt.figure(figsize=(16, 10))
20
+ plt.imshow(pil_img)
21
+ ax = plt.gca()
22
+ colors = COLORS * 100
23
+
24
+ for result in results:
25
+ score = result['score']
26
+ label = result['label']
27
+ box = list(result['box'].values())
28
+ if score > threshold:
29
+ c = COLORS[hash(label) % len(COLORS)]
30
+ ax.add_patch(plt.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1], fill=False, color=c, linewidth=3))
31
+ text = f'{label}: {score:0.2f}'
32
+ ax.text(box[0], box[1], text, fontsize=15,
33
+ bbox=dict(facecolor='yellow', alpha=0.5))
34
+ plt.axis('off')
35
+
36
+ return plt.gcf()
37
+
38
+ @spaces.GPU
39
+ def detect(image):
40
+ results = model_pipeline(image)
41
+ print(results)
42
+
43
+ output_figure = get_output_figure(image, results, threshold=0.7)
44
+
45
+ buf = io.BytesIO()
46
+ output_figure.savefig(buf, bbox_inches='tight')
47
+ buf.seek(0)
48
+ output_pil_img = Image.open(buf)
49
+
50
+ return output_pil_img
51
+
52
+ with gr.Blocks() as demo:
53
+ gr.Markdown("# Object detection with DETR fine tuned on detection-datasets/fashionpedia")
54
+ gr.Markdown(
55
+ """
56
+ This application uses a fine tuned DETR (DEtection TRansformers) to detect objects on images.
57
+ This version was trained using detection-datasets/fashionpedia dataset.
58
+ You can load an image and see the predictions for the objects detected.
59
+ """
60
+ )
61
+
62
+ gr.Interface(
63
+ fn=detect,
64
+ inputs=gr.Image(label="Input image", type="pil"),
65
+ outputs=[
66
+ gr.Image(label="Output prediction", type="pil")
67
+ ]
68
+ )
69
+
70
+ demo.launch(show_error=True)