DexterSptizu commited on
Commit
c2d9ca2
·
1 Parent(s): 4cb8735

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py.txt +62 -0
  2. requirements.txt.txt +0 -0
app.py.txt ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ from random import choice
3
+ from PIL import Image
4
+ import gradio as gr
5
+ from transformers import pipeline
6
+ import matplotlib.pyplot as plt
7
+
8
+ # Initialize the models
9
+ detector50 = pipeline(model="facebook/detr-resnet-50")
10
+ detector101 = pipeline(model="facebook/detr-resnet-101")
11
+
12
+ # Define colors and font dictionary for bounding boxes and labels
13
+ COLORS = ["#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
14
+ "#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
15
+ "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f"]
16
+
17
+ fdic = {
18
+ "family": "Impact",
19
+ "style": "italic",
20
+ "size": 15,
21
+ "color": "yellow",
22
+ "weight": "bold"
23
+ }
24
+
25
+ def get_figure(in_pil_img, in_results):
26
+ # Create a figure to display the image and annotations
27
+ plt.figure(figsize=(16, 10))
28
+ plt.imshow(in_pil_img)
29
+ ax = plt.gca()
30
+
31
+ # Add bounding boxes and labels to the image
32
+ for prediction in in_results:
33
+ selected_color = choice(COLORS)
34
+ x, y = prediction['box']['xmin'], prediction['box']['ymin']
35
+ w, h = prediction['box']['xmax'] - x, prediction['box']['ymax'] - y
36
+ ax.add_patch(plt.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=3))
37
+ ax.text(x, y, f"{prediction['label']}: {round(prediction['score']*100, 1)}%", fontdict=fdic)
38
+
39
+ plt.axis("off")
40
+ plt.tight_layout()
41
+
42
+ # Convert the figure to a PIL Image and return
43
+ buf = io.BytesIO()
44
+ plt.savefig(buf, format='png', bbox_inches='tight')
45
+ buf.seek(0)
46
+ return Image.open(buf)
47
+
48
+ def infer(model, in_pil_img):
49
+ # Perform inference using the specified model and input image
50
+ results = detector101(in_pil_img) if model == "detr-resnet-101" else detector50(in_pil_img)
51
+ return get_figure(in_pil_img, results)
52
+
53
+ # Define Gradio interface
54
+ with gr.Blocks() as demo:
55
+ gr.Markdown("## DETR Object Detection")
56
+ model = gr.Radio(["detr-resnet-50", "detr-resnet-101"], value="detr-resnet-50", label="Model name")
57
+ input_image = gr.Image(label="Input image", type="pil")
58
+ output_image = gr.Image(label="Output image")
59
+ send_btn = gr.Button("Infer")
60
+ send_btn.click(fn=infer, inputs=[model, input_image], outputs=output_image)
61
+
62
+ demo.launch()
requirements.txt.txt ADDED
File without changes