Vineedhar commited on
Commit
8709b1e
·
verified ·
1 Parent(s): da74f70

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -1
app.py CHANGED
@@ -95,7 +95,126 @@ with gr.Blocks(title="DETR Object Detection by orYx Models") as demo:
95
  output_image = gr.Image(label="Output image with predicted instances", type="pil")
96
 
97
  gr.Examples(['https://huggingface.co/spaces/orYx-models/object-detection-facebook-ResNets/blob/main/traffic.jpg',
98
- 'https://huggingface.co/spaces/orYx-models/object-detection-facebook-ResNets/blob/main/flyover.jpg'], inputs=input_image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
  gr.HTML("""<br/>""")
101
  gr.HTML("""<h4 style="color:navy;">3. Then, click "Infer" button to predict object instances. It will take about 10 seconds (on cpu)</h4>""")
 
95
  output_image = gr.Image(label="Output image with predicted instances", type="pil")
96
 
97
  gr.Examples(['https://huggingface.co/spaces/orYx-models/object-detection-facebook-ResNets/blob/main/traffic.jpg',
98
+ 'https://huggingface.co/spaces/orYx-models/object-detection-facebook-ResNets/blob/main/flyover.jpg'
99
+
100
+ import torch
101
+ from transformers import pipeline
102
+
103
+ from PIL import Image
104
+
105
+ import matplotlib.pyplot as plt
106
+ import matplotlib.patches as patches
107
+
108
+ from random import choice
109
+ import io
110
+
111
+ detector50 = pipeline(model="facebook/detr-resnet-50")
112
+
113
+ detector101 = pipeline(model="facebook/detr-resnet-101")
114
+
115
+
116
+ import gradio as gr
117
+
118
+ COLORS = ["#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
119
+ "#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
120
+ "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f"]
121
+
122
+ fdic = {
123
+ "family" : "Impact",
124
+ "style" : "italic",
125
+ "size" : 15,
126
+ "color" : "yellow",
127
+ "weight" : "bold"
128
+ }
129
+
130
+
131
+ def get_figure(in_pil_img, in_results):
132
+ plt.figure(figsize=(16, 10))
133
+ plt.imshow(in_pil_img)
134
+ #pyplot.gcf()
135
+ ax = plt.gca()
136
+
137
+ for prediction in in_results:
138
+ selected_color = choice(COLORS)
139
+
140
+ x, y = prediction['box']['xmin'], prediction['box']['ymin'],
141
+ w, h = prediction['box']['xmax'] - prediction['box']['xmin'], prediction['box']['ymax'] - prediction['box']['ymin']
142
+
143
+ ax.add_patch(plt.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=3))
144
+ ax.text(x, y, f"{prediction['label']}: {round(prediction['score']*100, 1)}%", fontdict=fdic)
145
+
146
+ plt.axis("off")
147
+
148
+ return plt.gcf()
149
+
150
+
151
+ def infer(model, in_pil_img):
152
+
153
+ results = None
154
+ if model == "detr-resnet-101":
155
+ results = detector101(in_pil_img)
156
+ else:
157
+ results = detector50(in_pil_img)
158
+
159
+ figure = get_figure(in_pil_img, results)
160
+
161
+ buf = io.BytesIO()
162
+ figure.savefig(buf, bbox_inches='tight')
163
+ buf.seek(0)
164
+ output_pil_img = Image.open(buf)
165
+
166
+ return output_pil_img
167
+
168
+
169
+ with gr.Blocks(title= "DETR Object Detection by orYx Models") as demo:
170
+ gr.HTML("""
171
+ <style>
172
+ .logo {
173
+ position: absolute;
174
+ top: 10px;
175
+ right: 10px;
176
+ width: 100px; /* Adjust the width of the logo as needed */
177
+ height: auto;
178
+ }
179
+ </style>
180
+ <div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">DETR Object Detection</div>
181
+ <img class="logo" src="https://huggingface.co/spaces/orYx-models/object-detection-facebook-ResNets/blob/main/oryx_logo%20(2).png" alt="Logo">
182
+ <h4 style="color:navy;">1. Select a model.</h4>
183
+ """)
184
+
185
+ model = gr.Radio(["detr-resnet-50", "detr-resnet-101"], value="detr-resnet-50", label="Model name")
186
+
187
+ gr.HTML("""<br/>""")
188
+ gr.HTML("""<h4 style="color:navy;">Please upload an image by clicking on the canvas. </h4>""")
189
+
190
+ with gr.Row():
191
+ input_image = gr.Image(label="Input image", type="pil")
192
+ output_image = gr.Image(label="Output image with predicted instances", type="pil")
193
+
194
+ gr.Examples(['https://huggingface.co/spaces/orYx-models/object-detection-facebook-ResNets/blob/main/traffic.jpg',
195
+ 'https://huggingface.co/spaces/orYx-models/object-detection-facebook-ResNets/blob/main/flyover.jpg',
196
+ https://huggingface.co/spaces/orYx-models/object-detection-facebook-ResNets/resolve/main/trees_traffic.jpg'
197
+ 'https://huggingface.co/spaces/orYx-models/object-detection-facebook-ResNets/resolve/main/Saudi_traffic.jpg'], inputs=input_image)
198
+
199
+ gr.HTML("""<br/>""")
200
+ gr.HTML("""<h4 style="color:navy;">3. Then, click "Infer" button to predict object instances. It will take about 10 seconds (on cpu)</h4>""")
201
+
202
+ send_btn = gr.Button("Infer")
203
+ send_btn.click(fn=infer, inputs=[model, input_image], outputs=[output_image])
204
+
205
+ gr.HTML("""<br/>""")
206
+ gr.HTML("""<h4 style="color:navy;">Reference</h4>""")
207
+ gr.HTML("""<ul>""")
208
+ gr.HTML("""<li><a href="https://colab.research.google.com/github/facebookresearch/detr/blob/colab/notebooks/detr_attention.ipynb" target="_blank">Hands-on tutorial for DETR</a>""")
209
+ gr.HTML("""</ul>""")
210
+
211
+
212
+ #demo.queue()
213
+ demo.launch(debug=True)
214
+
215
+
216
+ ### EOF ###
217
+ ], inputs=input_image)
218
 
219
  gr.HTML("""<br/>""")
220
  gr.HTML("""<h4 style="color:navy;">3. Then, click "Infer" button to predict object instances. It will take about 10 seconds (on cpu)</h4>""")