Spaces:
Sleeping
Sleeping
Commit
·
9d30896
1
Parent(s):
3ee85cf
Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
import pathlib
|
2 |
-
from constants import MODELS_REPO, MODELS_NAMES, MODELS_DETAILS
|
3 |
|
4 |
import gradio as gr
|
5 |
import torch
|
@@ -25,12 +24,12 @@ def make_prediction(img, feature_extractor, model):
|
|
25 |
)
|
26 |
|
27 |
|
28 |
-
def detect_objects(model_name, image_input, threshold, display_mask=False, img_input_mask=None):
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
|
35 |
(
|
36 |
processed_outputs,
|
@@ -57,7 +56,6 @@ def detect_objects(model_name, image_input, threshold, display_mask=False, img_i
|
|
57 |
viz_img,
|
58 |
decoder_attention_map_img,
|
59 |
encoder_attention_map_img,
|
60 |
-
# model_details_json
|
61 |
)
|
62 |
|
63 |
|
@@ -88,14 +86,14 @@ with gr.Blocks(css=css) as app:
|
|
88 |
)
|
89 |
with gr.Column():
|
90 |
with gr.Row():
|
91 |
-
|
92 |
value=MODELS_NAMES[0],
|
93 |
choices=MODELS_NAMES,
|
94 |
label="Select an object detection model",
|
95 |
show_label=True,
|
96 |
)
|
97 |
with gr.Row():
|
98 |
-
|
99 |
value=8,
|
100 |
choices=[2, 4, 8, 16],
|
101 |
label="The number of attention heads in encoder and decoder",
|
@@ -137,13 +135,11 @@ with gr.Blocks(css=css) as app:
|
|
137 |
|
138 |
detect_button.click(
|
139 |
detect_objects,
|
140 |
-
inputs=[
|
141 |
outputs=[
|
142 |
img_output_from_upload,
|
143 |
decoder_att_map_output,
|
144 |
encoder_att_map_output,
|
145 |
-
# cross_att_map_output,
|
146 |
-
# model_details,
|
147 |
],
|
148 |
queue=True,
|
149 |
)
|
|
|
1 |
import pathlib
|
|
|
2 |
|
3 |
import gradio as gr
|
4 |
import torch
|
|
|
24 |
)
|
25 |
|
26 |
|
27 |
+
def detect_objects(model_name, attention_heads_num, image_input, threshold, display_mask=False, img_input_mask=None):
|
28 |
+
model_repo = f"polejowska/detr-r50-cd45rb-all-{str(attention_heads_num)}ah"
|
29 |
+
|
30 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(model_repo)
|
31 |
+
|
32 |
+
model = DetrForObjectDetection.from_pretrained(model_repo)
|
33 |
|
34 |
(
|
35 |
processed_outputs,
|
|
|
56 |
viz_img,
|
57 |
decoder_attention_map_img,
|
58 |
encoder_attention_map_img,
|
|
|
59 |
)
|
60 |
|
61 |
|
|
|
86 |
)
|
87 |
with gr.Column():
|
88 |
with gr.Row():
|
89 |
+
model_name = gr.Dropdown(
|
90 |
value=MODELS_NAMES[0],
|
91 |
choices=MODELS_NAMES,
|
92 |
label="Select an object detection model",
|
93 |
show_label=True,
|
94 |
)
|
95 |
with gr.Row():
|
96 |
+
attention_heads_num = gr.Dropdown(
|
97 |
value=8,
|
98 |
choices=[2, 4, 8, 16],
|
99 |
label="The number of attention heads in encoder and decoder",
|
|
|
135 |
|
136 |
detect_button.click(
|
137 |
detect_objects,
|
138 |
+
inputs=[model_name, attention_heads_num, img_input, slider_input, display_mask, img_input_mask],
|
139 |
outputs=[
|
140 |
img_output_from_upload,
|
141 |
decoder_att_map_output,
|
142 |
encoder_att_map_output,
|
|
|
|
|
143 |
],
|
144 |
queue=True,
|
145 |
)
|