polejowska commited on
Commit
9d30896
·
1 Parent(s): 3ee85cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -13
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import pathlib
2
- from constants import MODELS_REPO, MODELS_NAMES, MODELS_DETAILS
3
 
4
  import gradio as gr
5
  import torch
@@ -25,12 +24,12 @@ def make_prediction(img, feature_extractor, model):
25
  )
26
 
27
 
28
- def detect_objects(model_name, image_input, threshold, display_mask=False, img_input_mask=None):
29
- feature_extractor = AutoFeatureExtractor.from_pretrained(MODELS_REPO[model_name])
30
-
31
- if "DETR" in model_name:
32
- model = DetrForObjectDetection.from_pretrained(MODELS_REPO[model_name])
33
- model_details_json = gr.Markdown(MODELS_DETAILS[model_name])
34
 
35
  (
36
  processed_outputs,
@@ -57,7 +56,6 @@ def detect_objects(model_name, image_input, threshold, display_mask=False, img_i
57
  viz_img,
58
  decoder_attention_map_img,
59
  encoder_attention_map_img,
60
- # model_details_json
61
  )
62
 
63
 
@@ -88,14 +86,14 @@ with gr.Blocks(css=css) as app:
88
  )
89
  with gr.Column():
90
  with gr.Row():
91
- options = gr.Dropdown(
92
  value=MODELS_NAMES[0],
93
  choices=MODELS_NAMES,
94
  label="Select an object detection model",
95
  show_label=True,
96
  )
97
  with gr.Row():
98
- heads = gr.Dropdown(
99
  value=8,
100
  choices=[2, 4, 8, 16],
101
  label="The number of attention heads in encoder and decoder",
@@ -137,13 +135,11 @@ with gr.Blocks(css=css) as app:
137
 
138
  detect_button.click(
139
  detect_objects,
140
- inputs=[options, img_input, slider_input, display_mask, img_input_mask],
141
  outputs=[
142
  img_output_from_upload,
143
  decoder_att_map_output,
144
  encoder_att_map_output,
145
- # cross_att_map_output,
146
- # model_details,
147
  ],
148
  queue=True,
149
  )
 
1
  import pathlib
 
2
 
3
  import gradio as gr
4
  import torch
 
24
  )
25
 
26
 
27
+ def detect_objects(model_name, attention_heads_num, image_input, threshold, display_mask=False, img_input_mask=None):
28
+ model_repo = f"polejowska/detr-r50-cd45rb-all-{str(attention_heads_num)}ah"
29
+
30
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_repo)
31
+
32
+ model = DetrForObjectDetection.from_pretrained(model_repo)
33
 
34
  (
35
  processed_outputs,
 
56
  viz_img,
57
  decoder_attention_map_img,
58
  encoder_attention_map_img,
 
59
  )
60
 
61
 
 
86
  )
87
  with gr.Column():
88
  with gr.Row():
89
+ model_name = gr.Dropdown(
90
  value=MODELS_NAMES[0],
91
  choices=MODELS_NAMES,
92
  label="Select an object detection model",
93
  show_label=True,
94
  )
95
  with gr.Row():
96
+ attention_heads_num = gr.Dropdown(
97
  value=8,
98
  choices=[2, 4, 8, 16],
99
  label="The number of attention heads in encoder and decoder",
 
135
 
136
  detect_button.click(
137
  detect_objects,
138
+ inputs=[model_name, attention_heads_num, img_input, slider_input, display_mask, img_input_mask],
139
  outputs=[
140
  img_output_from_upload,
141
  decoder_att_map_output,
142
  encoder_att_map_output,
 
 
143
  ],
144
  queue=True,
145
  )