Víctor Sáez commited on
Commit
c2f455f
·
1 Parent(s): 45307ff
Files changed (1) hide show
  1. app.py +20 -99
app.py CHANGED
@@ -2,49 +2,41 @@ import gradio as gr
2
  import torch
3
  from PIL import Image, ImageDraw, ImageFont
4
  from transformers import DetrImageProcessor, DetrForObjectDetection
5
- import transformers
6
 
7
- # Global variables to cache models
 
 
 
 
 
 
8
  current_model = None
9
  current_processor = None
10
  current_model_name = None
11
 
12
- # Available models with better selection
13
  available_models = {
14
- # DETR Models
15
  "DETR ResNet-50": "facebook/detr-resnet-50",
16
  "DETR ResNet-101": "facebook/detr-resnet-101",
17
  "DETR DC5": "facebook/detr-resnet-50-dc5",
18
  "DETR ResNet-50 Face Only": "esraakh/detr_fine_tune_face_detection_final"
19
  }
20
 
21
-
22
  def load_model(model_key):
23
- """Load model and processor based on selected model key"""
24
  global current_model, current_processor, current_model_name
25
-
26
  model_name = available_models[model_key]
27
-
28
- # Only load if it's a different model
29
  if current_model_name != model_name:
30
  print(f"Loading model: {model_name}")
31
  current_processor = DetrImageProcessor.from_pretrained(model_name)
32
  current_model = DetrForObjectDetection.from_pretrained(model_name)
33
  current_model_name = model_name
34
- print(f"Model loaded: {model_name}")
35
- print(f"Available labels: {list(current_model.config.id2label.values())}")
36
-
37
  return current_model, current_processor
38
 
39
-
40
- # Fixed font loading - this was the main issue
41
  def get_font(size=12):
42
  try:
43
  return ImageFont.truetype("arial.ttf", size=size)
44
  except:
45
  return ImageFont.load_default()
46
 
47
- # Set up translations for the app
48
  translations = {
49
  "English": {
50
  "title": "## Enhanced Object Detection App\nUpload an image to detect objects using various DETR models.",
@@ -90,127 +82,75 @@ translations = {
90
  }
91
  }
92
 
93
-
94
  def t(language, key):
95
  return translations.get(language, translations["English"]).get(key, key)
96
 
97
-
98
  def get_translated_model_choices(language):
99
- """Get model choices translated to the selected language"""
100
  model_mapping = {
101
  "DETR ResNet-50": "model_fast",
102
  "DETR ResNet-101": "model_precision",
103
  "DETR DC5": "model_small",
104
  "DETR ResNet-50 Face Only": "model_faces"
105
  }
106
-
107
  translated_choices = []
108
  for model_key in available_models.keys():
109
  if model_key in model_mapping:
110
  translation_key = model_mapping[model_key]
111
  translated_name = t(language, translation_key)
112
  else:
113
- translated_name = model_key # Fallback to original name
114
  translated_choices.append(translated_name)
115
-
116
  return translated_choices
117
 
118
-
119
  def get_model_key_from_translation(translated_name, language):
120
- """Get the original model key from translated name"""
121
  model_mapping = {
122
  "DETR ResNet-50": "model_fast",
123
  "DETR ResNet-101": "model_precision",
124
  "DETR DC5": "model_small",
125
  "DETR ResNet-50 Face Only": "model_faces"
126
  }
127
-
128
- # Reverse lookup
129
  for model_key, translation_key in model_mapping.items():
130
  if t(language, translation_key) == translated_name:
131
  return model_key
132
-
133
- # If not found, try direct match
134
  if translated_name in available_models:
135
  return translated_name
136
-
137
- # Default fallback
138
  return "DETR ResNet-50"
139
 
140
-
141
- def get_helsinki_model(language_label):
142
- """Returns the Helsinki-NLP model name for translating from English to the selected language."""
143
- lang_map = {
144
- "Spanish": "es",
145
- "French": "fr",
146
- "English": "en"
147
- }
148
- target = lang_map.get(language_label)
149
- if not target or target == "en":
150
- return None
151
- return f"Helsinki-NLP/opus-mt-en-{target}"
152
-
153
-
154
- # add cache for translations
155
  translation_cache = {}
156
 
157
-
158
  def translate_label(language_label, label):
159
- """Translates the given label to the target language."""
160
- # Check cache first
161
  cache_key = f"{language_label}_{label}"
162
  if cache_key in translation_cache:
163
  return translation_cache[cache_key]
164
-
165
- model_name = get_helsinki_model(language_label)
166
- if not model_name:
167
- return label
168
-
169
- try:
170
- translator = transformers.pipeline("translation", model=model_name)
171
- result = translator(label, max_length=40)
172
- translated = result[0]['translation_text']
173
- # Cache the result
174
- translation_cache[cache_key] = translated
175
- return translated
176
- except Exception as e:
177
- print(f"Translation error (429 or other): {e}")
178
- return label # Return original if translation fails
179
-
180
 
181
  def detect_objects(image, language_selector, translated_model_selector, threshold):
182
- """Enhanced object detection with adjustable threshold and better info"""
183
  try:
184
  if image is None:
185
- return None, "Por favor, sube una imagen antes de detectar objetos."
186
-
187
  model_selector = get_model_key_from_translation(translated_model_selector, language_selector)
188
- print(f"Processing image. Language: {language_selector}, Model: {model_selector}, Threshold: {threshold}")
189
-
190
  model, processor = load_model(model_selector)
191
-
192
  inputs = processor(images=image, return_tensors="pt")
193
  outputs = model(**inputs)
194
-
195
  target_sizes = torch.tensor([image.size[::-1]])
196
  results = processor.post_process_object_detection(
197
  outputs, threshold=threshold, target_sizes=target_sizes
198
  )[0]
199
-
200
  image_with_boxes = image.copy()
201
  draw = ImageDraw.Draw(image_with_boxes)
202
-
203
  detection_info = f"Detected {len(results['scores'])} objects with threshold {threshold}\n"
204
  detection_info += f"Model: {translated_model_selector} ({model_selector})\n\n"
205
-
206
  colors = {
207
- 'high': 'red', # > 0.8
208
- 'medium': 'orange', # 0.5-0.8
209
- 'low': 'yellow' # < 0.5
210
  }
211
-
212
  detected_objects = []
213
-
214
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
215
  confidence = score.item()
216
  box = [round(x, 2) for x in box.tolist()]
@@ -220,7 +160,6 @@ def detect_objects(image, language_selector, translated_model_selector, threshol
220
  color = colors['medium']
221
  else:
222
  color = colors['low']
223
-
224
  draw.rectangle(box, outline=color, width=3)
225
  label_text = model.config.id2label[label.item()]
226
  translated_label = translate_label(language_selector, label_text)
@@ -231,7 +170,6 @@ def detect_objects(image, language_selector, translated_model_selector, threshol
231
  'confidence': confidence,
232
  'box': box
233
  })
234
-
235
  try:
236
  image_width = image.size[0]
237
  font_size = max(image_width // 40, 12)
@@ -243,21 +181,18 @@ def detect_objects(image, language_selector, translated_model_selector, threshol
243
  font = get_font(12)
244
  text_width = 50
245
  text_height = 20
246
-
247
  text_bg = [
248
  box[0], box[1] - text_height - 4,
249
  box[0] + text_width + 4, box[1]
250
  ]
251
  draw.rectangle(text_bg, fill="black")
252
  draw.text((box[0] + 2, box[1] - text_height - 2), display_text, fill="white", font=font)
253
-
254
  if detected_objects:
255
  detection_info += "Objects found:\n"
256
  for obj in sorted(detected_objects, key=lambda x: x['confidence'], reverse=True):
257
  detection_info += f"- {obj['translated']} ({obj['label']}): {obj['confidence']:.3f}\n"
258
  else:
259
  detection_info += "No objects detected. Try lowering the threshold."
260
-
261
  return image_with_boxes, detection_info
262
  except Exception as e:
263
  import traceback
@@ -265,12 +200,10 @@ def detect_objects(image, language_selector, translated_model_selector, threshol
265
  traceback.print_exc()
266
  return None, f"Error detecting objects: {e}"
267
 
268
-
269
  def build_app():
270
  with gr.Blocks(theme=gr.themes.Soft()) as app:
271
  with gr.Row():
272
  title = gr.Markdown(t("English", "title"))
273
-
274
  with gr.Row():
275
  with gr.Column(scale=1):
276
  language_selector = gr.Dropdown(
@@ -281,18 +214,17 @@ def build_app():
281
  with gr.Column(scale=1):
282
  model_selector = gr.Dropdown(
283
  choices=get_translated_model_choices("English"),
284
- value=t("English", "model_fast"), # Default to translated "fast" option
285
  label=t("English", "dropdown_detection_model_label")
286
  )
287
  with gr.Column(scale=1):
288
  threshold_slider = gr.Slider(
289
  minimum=0.1,
290
  maximum=0.95,
291
- value=0.5, # Lowered default threshold
292
  step=0.05,
293
  label=t("English", "threshold_label")
294
  )
295
-
296
  with gr.Row():
297
  with gr.Column(scale=1):
298
  input_image = gr.Image(type="pil", label=t("English", "input_label"))
@@ -304,12 +236,9 @@ def build_app():
304
  lines=10,
305
  max_lines=15
306
  )
307
-
308
- # Function to update interface when language changes
309
  def update_interface(selected_language):
310
  translated_choices = get_translated_model_choices(selected_language)
311
  default_model = t(selected_language, "model_fast")
312
-
313
  return [
314
  gr.update(value=t(selected_language, "title")),
315
  gr.update(label=t(selected_language, "dropdown_label")),
@@ -324,8 +253,6 @@ def build_app():
324
  gr.update(label=t(selected_language, "output_label")),
325
  gr.update(label=t(selected_language, "info_label"))
326
  ]
327
-
328
- # Connect language change event
329
  language_selector.change(
330
  fn=update_interface,
331
  inputs=language_selector,
@@ -333,21 +260,15 @@ def build_app():
333
  input_image, button, output_image, detection_info],
334
  queue=False
335
  )
336
-
337
- # Connect detection button click event
338
  button.click(
339
  fn=detect_objects,
340
  inputs=[input_image, language_selector, model_selector, threshold_slider],
341
  outputs=[output_image, detection_info]
342
  )
343
-
344
  return app
345
 
346
-
347
- # Initialize with default model
348
  load_model("DETR ResNet-50")
349
 
350
- # Launch the application
351
  if __name__ == "__main__":
352
  app = build_app()
353
  app.launch()
 
2
  import torch
3
  from PIL import Image, ImageDraw, ImageFont
4
  from transformers import DetrImageProcessor, DetrForObjectDetection
 
5
 
6
+ # Only import pipeline if translation is enabled
7
+ ENABLE_TRANSLATION = False # Cambia a True solo si puedes cargar modelos Helsinki localmente
8
+
9
+ if ENABLE_TRANSLATION:
10
+ from transformers import pipeline
11
+
12
+ # Global variables
13
  current_model = None
14
  current_processor = None
15
  current_model_name = None
16
 
 
17
  available_models = {
 
18
  "DETR ResNet-50": "facebook/detr-resnet-50",
19
  "DETR ResNet-101": "facebook/detr-resnet-101",
20
  "DETR DC5": "facebook/detr-resnet-50-dc5",
21
  "DETR ResNet-50 Face Only": "esraakh/detr_fine_tune_face_detection_final"
22
  }
23
 
 
24
  def load_model(model_key):
 
25
  global current_model, current_processor, current_model_name
 
26
  model_name = available_models[model_key]
 
 
27
  if current_model_name != model_name:
28
  print(f"Loading model: {model_name}")
29
  current_processor = DetrImageProcessor.from_pretrained(model_name)
30
  current_model = DetrForObjectDetection.from_pretrained(model_name)
31
  current_model_name = model_name
 
 
 
32
  return current_model, current_processor
33
 
 
 
34
  def get_font(size=12):
35
  try:
36
  return ImageFont.truetype("arial.ttf", size=size)
37
  except:
38
  return ImageFont.load_default()
39
 
 
40
  translations = {
41
  "English": {
42
  "title": "## Enhanced Object Detection App\nUpload an image to detect objects using various DETR models.",
 
82
  }
83
  }
84
 
 
85
  def t(language, key):
86
  return translations.get(language, translations["English"]).get(key, key)
87
 
 
88
  def get_translated_model_choices(language):
 
89
  model_mapping = {
90
  "DETR ResNet-50": "model_fast",
91
  "DETR ResNet-101": "model_precision",
92
  "DETR DC5": "model_small",
93
  "DETR ResNet-50 Face Only": "model_faces"
94
  }
 
95
  translated_choices = []
96
  for model_key in available_models.keys():
97
  if model_key in model_mapping:
98
  translation_key = model_mapping[model_key]
99
  translated_name = t(language, translation_key)
100
  else:
101
+ translated_name = model_key
102
  translated_choices.append(translated_name)
 
103
  return translated_choices
104
 
 
105
  def get_model_key_from_translation(translated_name, language):
 
106
  model_mapping = {
107
  "DETR ResNet-50": "model_fast",
108
  "DETR ResNet-101": "model_precision",
109
  "DETR DC5": "model_small",
110
  "DETR ResNet-50 Face Only": "model_faces"
111
  }
 
 
112
  for model_key, translation_key in model_mapping.items():
113
  if t(language, translation_key) == translated_name:
114
  return model_key
 
 
115
  if translated_name in available_models:
116
  return translated_name
 
 
117
  return "DETR ResNet-50"
118
 
119
+ # Translation logic (only if ENABLE_TRANSLATION and model is local)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  translation_cache = {}
121
 
 
122
  def translate_label(language_label, label):
123
+ if language_label == "English" or not ENABLE_TRANSLATION:
124
+ return label
125
  cache_key = f"{language_label}_{label}"
126
  if cache_key in translation_cache:
127
  return translation_cache[cache_key]
128
+ # Dummy fallback in Spaces, or if not preloaded, just warn
129
+ translation_cache[cache_key] = f"{label} (no translation)"
130
+ return translation_cache[cache_key]
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  def detect_objects(image, language_selector, translated_model_selector, threshold):
 
133
  try:
134
  if image is None:
135
+ return None, "Please upload an image before detecting objects."
 
136
  model_selector = get_model_key_from_translation(translated_model_selector, language_selector)
 
 
137
  model, processor = load_model(model_selector)
 
138
  inputs = processor(images=image, return_tensors="pt")
139
  outputs = model(**inputs)
 
140
  target_sizes = torch.tensor([image.size[::-1]])
141
  results = processor.post_process_object_detection(
142
  outputs, threshold=threshold, target_sizes=target_sizes
143
  )[0]
 
144
  image_with_boxes = image.copy()
145
  draw = ImageDraw.Draw(image_with_boxes)
 
146
  detection_info = f"Detected {len(results['scores'])} objects with threshold {threshold}\n"
147
  detection_info += f"Model: {translated_model_selector} ({model_selector})\n\n"
 
148
  colors = {
149
+ 'high': 'red',
150
+ 'medium': 'orange',
151
+ 'low': 'yellow'
152
  }
 
153
  detected_objects = []
 
154
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
155
  confidence = score.item()
156
  box = [round(x, 2) for x in box.tolist()]
 
160
  color = colors['medium']
161
  else:
162
  color = colors['low']
 
163
  draw.rectangle(box, outline=color, width=3)
164
  label_text = model.config.id2label[label.item()]
165
  translated_label = translate_label(language_selector, label_text)
 
170
  'confidence': confidence,
171
  'box': box
172
  })
 
173
  try:
174
  image_width = image.size[0]
175
  font_size = max(image_width // 40, 12)
 
181
  font = get_font(12)
182
  text_width = 50
183
  text_height = 20
 
184
  text_bg = [
185
  box[0], box[1] - text_height - 4,
186
  box[0] + text_width + 4, box[1]
187
  ]
188
  draw.rectangle(text_bg, fill="black")
189
  draw.text((box[0] + 2, box[1] - text_height - 2), display_text, fill="white", font=font)
 
190
  if detected_objects:
191
  detection_info += "Objects found:\n"
192
  for obj in sorted(detected_objects, key=lambda x: x['confidence'], reverse=True):
193
  detection_info += f"- {obj['translated']} ({obj['label']}): {obj['confidence']:.3f}\n"
194
  else:
195
  detection_info += "No objects detected. Try lowering the threshold."
 
196
  return image_with_boxes, detection_info
197
  except Exception as e:
198
  import traceback
 
200
  traceback.print_exc()
201
  return None, f"Error detecting objects: {e}"
202
 
 
203
  def build_app():
204
  with gr.Blocks(theme=gr.themes.Soft()) as app:
205
  with gr.Row():
206
  title = gr.Markdown(t("English", "title"))
 
207
  with gr.Row():
208
  with gr.Column(scale=1):
209
  language_selector = gr.Dropdown(
 
214
  with gr.Column(scale=1):
215
  model_selector = gr.Dropdown(
216
  choices=get_translated_model_choices("English"),
217
+ value=t("English", "model_fast"),
218
  label=t("English", "dropdown_detection_model_label")
219
  )
220
  with gr.Column(scale=1):
221
  threshold_slider = gr.Slider(
222
  minimum=0.1,
223
  maximum=0.95,
224
+ value=0.5,
225
  step=0.05,
226
  label=t("English", "threshold_label")
227
  )
 
228
  with gr.Row():
229
  with gr.Column(scale=1):
230
  input_image = gr.Image(type="pil", label=t("English", "input_label"))
 
236
  lines=10,
237
  max_lines=15
238
  )
 
 
239
  def update_interface(selected_language):
240
  translated_choices = get_translated_model_choices(selected_language)
241
  default_model = t(selected_language, "model_fast")
 
242
  return [
243
  gr.update(value=t(selected_language, "title")),
244
  gr.update(label=t(selected_language, "dropdown_label")),
 
253
  gr.update(label=t(selected_language, "output_label")),
254
  gr.update(label=t(selected_language, "info_label"))
255
  ]
 
 
256
  language_selector.change(
257
  fn=update_interface,
258
  inputs=language_selector,
 
260
  input_image, button, output_image, detection_info],
261
  queue=False
262
  )
 
 
263
  button.click(
264
  fn=detect_objects,
265
  inputs=[input_image, language_selector, model_selector, threshold_slider],
266
  outputs=[output_image, detection_info]
267
  )
 
268
  return app
269
 
 
 
270
  load_model("DETR ResNet-50")
271
 
 
272
  if __name__ == "__main__":
273
  app = build_app()
274
  app.launch()