ayoubkirouane commited on
Commit
b9b6976
·
1 Parent(s): 0efe418

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -0
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import pipeline
3
+
4
+ from PIL import Image
5
+
6
+ import matplotlib.pyplot as plt
7
+
8
+ from random import choice
9
+ import io
10
+
11
+ detector50 = pipeline(model="TuningAI/DETR-BASE_Marine")
12
+
13
+ import gradio as gr
14
+
15
+ COLORS = ["#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
16
+ "#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
17
+ "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f"]
18
+
19
+ fdic = {
20
+ "family" : "Impact",
21
+ "style" : "italic",
22
+ "size" : 10,
23
+ "color" : "red",
24
+ "weight" : "bold"
25
+ }
26
+
27
+
28
+ def get_figure(in_pil_img, in_results):
29
+ plt.figure(figsize=(16, 10))
30
+ plt.imshow(in_pil_img)
31
+ ax = plt.gca()
32
+
33
+ for prediction in in_results:
34
+ selected_color = choice(COLORS)
35
+
36
+ x, y = prediction['box']['xmin'], prediction['box']['ymin'],
37
+ w, h = prediction['box']['xmax'] - prediction['box']['xmin'], prediction['box']['ymax'] - prediction['box']['ymin']
38
+
39
+ ax.add_patch(plt.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=3))
40
+ ax.text(x, y, f"{prediction['label']}: {round(prediction['score']*100, 1)}%", fontdict=fdic)
41
+
42
+ plt.axis("off")
43
+
44
+ return plt.gcf()
45
+
46
+
47
+ def infer(in_pil_img):
48
+ results = detector50(in_pil_img)
49
+
50
+ figure = get_figure(in_pil_img, results)
51
+
52
+ buf = io.BytesIO()
53
+ figure.savefig(buf, bbox_inches='tight')
54
+ buf.seek(0)
55
+ output_pil_img = Image.open(buf)
56
+
57
+ return output_pil_img
58
+
59
+
60
+ with gr.Blocks(title="DETR Object Detection") as demo:
61
+ with gr.Row():
62
+ input_image = gr.Image(label="Input image", type="pil")
63
+ output_image = gr.Image(label="Output image with predicted instances", type="pil")
64
+ send_btn = gr.Button("Infer")
65
+ send_btn.click(fn=infer, inputs=input_image, outputs=[output_image])
66
+ demo.launch(debug=True)