Víctor Sáez commited on
Commit
2e9147d
·
1 Parent(s): 6ecfb14

Add multilenguage support

Browse files
Files changed (2) hide show
  1. app.py +321 -35
  2. requirements.txt +0 -0
app.py CHANGED
@@ -3,72 +3,358 @@ import torch
3
  from PIL import Image, ImageDraw, ImageFont
4
  from transformers import DetrImageProcessor, DetrForObjectDetection
5
  from pathlib import Path
 
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- # Load DETR model and processor from Hugging Face
9
- model_name = "facebook/detr-resnet-50"
10
- processor = DetrImageProcessor.from_pretrained(model_name)
11
- model = DetrForObjectDetection.from_pretrained(model_name)
12
 
13
  # Load font
14
  font_path = Path("assets/fonts/arial.ttf")
15
  if not font_path.exists():
16
- # If the font file does not exist, use the default PIL font
17
  print(f"Font file {font_path} not found. Using default font.")
18
  font = ImageFont.load_default()
19
  else:
20
- font = ImageFont.truetype(str(font_path), size=100)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- print(f"CUDA is available: {torch.cuda.is_available()}")
23
 
24
- # Main function: takes an image and returns it with boxes and labels
25
- def detect_objects(image):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  inputs = processor(images=image, return_tensors="pt")
27
  outputs = model(**inputs)
28
 
29
- # Convert model output to usable detection results
30
  target_sizes = torch.tensor([image.size[::-1]])
31
  results = processor.post_process_object_detection(
32
- outputs, threshold=0.9, target_sizes=target_sizes
33
  )[0]
34
 
35
- # Draw bounding boxes and labels on a copy of the image
36
  image_with_boxes = image.copy()
37
  draw = ImageDraw.Draw(image_with_boxes)
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
 
40
  box = [round(x, 2) for x in box.tolist()]
41
- draw.rectangle(box, outline="red", width=3)
 
 
 
 
 
 
 
 
 
 
42
 
43
  # Prepare label text
44
- label_text = f"{model.config.id2label[label.item()]}: {round(score.item(), 2)}"
 
 
 
 
 
 
 
 
 
 
45
 
46
- # Measure text size
47
- text_bbox = draw.textbbox((0, 0), label_text, font=font)
48
- text_width = text_bbox[2] - text_bbox[0]
49
- text_height = text_bbox[3] - text_bbox[1]
 
 
 
 
50
 
51
- # Set background rectangle for text
52
- text_background = [
53
- box[0], box[1] - text_height,
54
- box[0] + text_width, box[1]
55
  ]
56
- draw.rectangle(text_background, fill="black") # Background
57
- draw.text((box[0], box[1] - text_height), label_text, fill="white", font=font)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- return image_with_boxes
 
 
 
 
 
60
 
 
61
 
62
- with gr.Blocks() as app:
63
- with gr.Row():
64
- gr.Markdown("## Object Detection App\nUpload an image to detect objects using Facebook's DETR model.")
65
- with gr.Row():
66
- input_image = gr.Image(type="pil", label="Input Image")
67
- output_image = gr.Image(label="Detected Objects")
68
- with gr.Row():
69
- button = gr.Button("Detect Objects")
70
 
71
- button.click(fn=detect_objects, inputs=input_image, outputs=output_image)
 
72
 
 
73
  if __name__ == "__main__":
74
- app.launch()
 
 
3
  from PIL import Image, ImageDraw, ImageFont
4
  from transformers import DetrImageProcessor, DetrForObjectDetection
5
  from pathlib import Path
6
+ import transformers
7
 
8
+ # Global variables to cache models
9
+ current_model = None
10
+ current_processor = None
11
+ current_model_name = None
12
+
13
+ # Available models with better selection
14
+ available_models = {
15
+ # DETR Models
16
+ "DETR ResNet-50": "facebook/detr-resnet-50",
17
+ "DETR ResNet-101": "facebook/detr-resnet-101",
18
+ "DETR DC5": "facebook/detr-resnet-50-dc5",
19
+ "DETR ResNet-50 Face Only": "esraakh/detr_fine_tune_face_detection_final"
20
+ }
21
+
22
+
23
+ def load_model(model_key):
24
+ """Load model and processor based on selected model key"""
25
+ global current_model, current_processor, current_model_name
26
+
27
+ model_name = available_models[model_key]
28
+
29
+ # Only load if it's a different model
30
+ if current_model_name != model_name:
31
+ print(f"Loading model: {model_name}")
32
+ current_processor = DetrImageProcessor.from_pretrained(model_name)
33
+ current_model = DetrForObjectDetection.from_pretrained(model_name)
34
+ current_model_name = model_name
35
+ print(f"Model loaded: {model_name}")
36
+ print(f"Available labels: {list(current_model.config.id2label.values())}")
37
+
38
+ return current_model, current_processor
39
 
 
 
 
 
40
 
41
  # Load font
42
  font_path = Path("assets/fonts/arial.ttf")
43
  if not font_path.exists():
 
44
  print(f"Font file {font_path} not found. Using default font.")
45
  font = ImageFont.load_default()
46
  else:
47
+ font = ImageFont.truetype(str(font_path), size=100) # Reduced font size
48
+
49
+ # Set up translations for the app
50
+ translations = {
51
+ "English": {
52
+ "title": "## Enhanced Object Detection App\nUpload an image to detect objects using various DETR models.",
53
+ "input_label": "Input Image",
54
+ "output_label": "Detected Objects",
55
+ "dropdown_label": "Label Language",
56
+ "dropdown_detection_model_label": "Detection Model",
57
+ "threshold_label": "Detection Threshold",
58
+ "button": "Detect Objects",
59
+ "info_label": "Detection Info",
60
+ "model_fast": "General Objects (fast)",
61
+ "model_precision": "General Objects (high precision)",
62
+ "model_small": "Small Objects/Details (slow)",
63
+ "model_faces": "Face Detection (people only)"
64
+ },
65
+ "Spanish": {
66
+ "title": "## Aplicación Mejorada de Detección de Objetos\nSube una imagen para detectar objetos usando varios modelos DETR.",
67
+ "input_label": "Imagen de entrada",
68
+ "output_label": "Objetos detectados",
69
+ "dropdown_label": "Idioma de las etiquetas",
70
+ "dropdown_detection_model_label": "Modelo de detección",
71
+ "threshold_label": "Umbral de detección",
72
+ "button": "Detectar objetos",
73
+ "info_label": "Información de detección",
74
+ "model_fast": "Objetos generales (rápido)",
75
+ "model_precision": "Objetos generales (precisión alta)",
76
+ "model_small": "Objetos pequeños/detalles (lento)",
77
+ "model_faces": "Detección de caras (solo personas)"
78
+ },
79
+ "French": {
80
+ "title": "## Application Améliorée de Détection d'Objets\nTéléchargez une image pour détecter des objets avec divers modèles DETR.",
81
+ "input_label": "Image d'entrée",
82
+ "output_label": "Objets détectés",
83
+ "dropdown_label": "Langue des étiquettes",
84
+ "dropdown_detection_model_label": "Modèle de détection",
85
+ "threshold_label": "Seuil de détection",
86
+ "button": "Détecter les objets",
87
+ "info_label": "Information de détection",
88
+ "model_fast": "Objets généraux (rapide)",
89
+ "model_precision": "Objets généraux (haute précision)",
90
+ "model_small": "Petits objets/détails (lent)",
91
+ "model_faces": "Détection de visages (personnes uniquement)"
92
+ }
93
+ }
94
+
95
+
96
+ def t(language, key):
97
+ return translations.get(language, translations["English"]).get(key, key)
98
+
99
+
100
+ def get_translated_model_choices(language):
101
+ """Get model choices translated to the selected language"""
102
+ model_mapping = {
103
+ "DETR ResNet-50": "model_fast",
104
+ "DETR ResNet-101": "model_precision",
105
+ "DETR DC5": "model_small",
106
+ "DETR ResNet-50 Face Only": "model_faces"
107
+ }
108
+
109
+ translated_choices = []
110
+ for model_key in available_models.keys():
111
+ if model_key in model_mapping:
112
+ translation_key = model_mapping[model_key]
113
+ translated_name = t(language, translation_key)
114
+ else:
115
+ translated_name = model_key # Fallback to original name
116
+ translated_choices.append(translated_name)
117
+
118
+ return translated_choices
119
+
120
+
121
+ def get_model_key_from_translation(translated_name, language):
122
+ """Get the original model key from translated name"""
123
+ model_mapping = {
124
+ "DETR ResNet-50": "model_fast",
125
+ "DETR ResNet-101": "model_precision",
126
+ "DETR DC5": "model_small",
127
+ "DETR ResNet-50 Face Only": "model_faces"
128
+ }
129
+
130
+ # Reverse lookup
131
+ for model_key, translation_key in model_mapping.items():
132
+ if t(language, translation_key) == translated_name:
133
+ return model_key
134
+
135
+ # If not found, try direct match
136
+ if translated_name in available_models:
137
+ return translated_name
138
+
139
+ # Default fallback
140
+ return "DETR ResNet-50"
141
+
142
+
143
+ def get_helsinki_model(language_label):
144
+ """Returns the Helsinki-NLP model name for translating from English to the selected language."""
145
+ lang_map = {
146
+ "Spanish": "es",
147
+ "French": "fr",
148
+ "English": "en"
149
+ }
150
+ target = lang_map.get(language_label)
151
+ if not target or target == "en":
152
+ return None
153
+ return f"Helsinki-NLP/opus-mt-en-{target}"
154
+
155
+
156
+ # add cache for translations
157
+ translation_cache = {}
158
 
 
159
 
160
+ def translate_label(language_label, label):
161
+ """Translates the given label to the target language."""
162
+ # Check cache first
163
+ cache_key = f"{language_label}_{label}"
164
+ if cache_key in translation_cache:
165
+ return translation_cache[cache_key]
166
+
167
+ model_name = get_helsinki_model(language_label)
168
+ if not model_name:
169
+ return label
170
+
171
+ try:
172
+ translator = transformers.pipeline("translation", model=model_name)
173
+ result = translator(label, max_length=40)
174
+ translated = result[0]['translation_text']
175
+ # Cache the result
176
+ translation_cache[cache_key] = translated
177
+ return translated
178
+ except Exception as e:
179
+ print(f"Translation error (429 or other): {e}")
180
+ return label # Return original if translation fails
181
+
182
+
183
+ def detect_objects(image, language_selector, translated_model_selector, threshold):
184
+ """Enhanced object detection with adjustable threshold and better info"""
185
+ # Get the actual model key from the translated name
186
+ model_selector = get_model_key_from_translation(translated_model_selector, language_selector)
187
+
188
+ print(f"Processing image. Language: {language_selector}, Model: {model_selector}, Threshold: {threshold}")
189
+
190
+ # Load the selected model
191
+ model, processor = load_model(model_selector)
192
+
193
+ # Process the image
194
  inputs = processor(images=image, return_tensors="pt")
195
  outputs = model(**inputs)
196
 
197
+ # Convert model output to usable detection results with custom threshold
198
  target_sizes = torch.tensor([image.size[::-1]])
199
  results = processor.post_process_object_detection(
200
+ outputs, threshold=threshold, target_sizes=target_sizes
201
  )[0]
202
 
203
+ # Create a copy of the image for drawing
204
  image_with_boxes = image.copy()
205
  draw = ImageDraw.Draw(image_with_boxes)
206
 
207
+ # Detection info
208
+ detection_info = f"Detected {len(results['scores'])} objects with threshold {threshold}\n"
209
+ detection_info += f"Model: {translated_model_selector} ({model_selector})\n\n"
210
+
211
+ # Colors for different confidence levels
212
+ colors = {
213
+ 'high': 'red', # > 0.8
214
+ 'medium': 'orange', # 0.5-0.8
215
+ 'low': 'yellow' # < 0.5
216
+ }
217
+
218
+ detected_objects = []
219
+
220
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
221
+ confidence = score.item()
222
  box = [round(x, 2) for x in box.tolist()]
223
+
224
+ # Choose color based on confidence
225
+ if confidence > 0.8:
226
+ color = colors['high']
227
+ elif confidence > 0.5:
228
+ color = colors['medium']
229
+ else:
230
+ color = colors['low']
231
+
232
+ # Draw bounding box
233
+ draw.rectangle(box, outline=color, width=3)
234
 
235
  # Prepare label text
236
+ label_text = model.config.id2label[label.item()]
237
+ translated_label = translate_label(language_selector, label_text)
238
+ display_text = f"{translated_label}: {round(confidence, 3)}"
239
+
240
+ # Store detection info
241
+ detected_objects.append({
242
+ 'label': label_text,
243
+ 'translated': translated_label,
244
+ 'confidence': confidence,
245
+ 'box': box
246
+ })
247
 
248
+ # Calculate text position and size
249
+ try:
250
+ text_bbox = draw.textbbox((0, 0), display_text, font=font)
251
+ text_width = text_bbox[2] - text_bbox[0]
252
+ text_height = text_bbox[3] - text_bbox[1]
253
+ except:
254
+ # Fallback for older PIL versions
255
+ text_width, text_height = draw.textsize(display_text, font=font)
256
 
257
+ # Draw text background
258
+ text_bg = [
259
+ box[0], box[1] - text_height - 4,
260
+ box[0] + text_width + 4, box[1]
261
  ]
262
+ draw.rectangle(text_bg, fill="black")
263
+ draw.text((box[0] + 2, box[1] - text_height - 2), display_text, fill="white", font=font)
264
+
265
+ # Create detailed detection info
266
+ if detected_objects:
267
+ detection_info += "Objects found:\n"
268
+ for obj in sorted(detected_objects, key=lambda x: x['confidence'], reverse=True):
269
+ detection_info += f"- {obj['translated']} ({obj['label']}): {obj['confidence']:.3f}\n"
270
+ else:
271
+ detection_info += "No objects detected. Try lowering the threshold."
272
+
273
+ return image_with_boxes, detection_info
274
+
275
+
276
+ def build_app():
277
+ with gr.Blocks(theme=gr.themes.Soft()) as app:
278
+ with gr.Row():
279
+ title = gr.Markdown(t("English", "title"))
280
+
281
+ with gr.Row():
282
+ with gr.Column(scale=1):
283
+ language_selector = gr.Dropdown(
284
+ choices=["English", "Spanish", "French"],
285
+ value="English",
286
+ label=t("English", "dropdown_label")
287
+ )
288
+ with gr.Column(scale=1):
289
+ model_selector = gr.Dropdown(
290
+ choices=get_translated_model_choices("English"),
291
+ value=t("English", "model_fast"), # Default to translated "fast" option
292
+ label=t("English", "dropdown_detection_model_label")
293
+ )
294
+ with gr.Column(scale=1):
295
+ threshold_slider = gr.Slider(
296
+ minimum=0.1,
297
+ maximum=0.95,
298
+ value=0.5, # Lowered default threshold
299
+ step=0.05,
300
+ label=t("English", "threshold_label")
301
+ )
302
+
303
+ with gr.Row():
304
+ with gr.Column(scale=1):
305
+ input_image = gr.Image(type="pil", label=t("English", "input_label"))
306
+ button = gr.Button(t("English", "button"), variant="primary")
307
+ with gr.Column(scale=1):
308
+ output_image = gr.Image(label=t("English", "output_label"))
309
+ detection_info = gr.Textbox(
310
+ label=t("English", "info_label"),
311
+ lines=10,
312
+ max_lines=15
313
+ )
314
+
315
+ # Function to update interface when language changes
316
+ def update_interface(selected_language):
317
+ translated_choices = get_translated_model_choices(selected_language)
318
+ default_model = t(selected_language, "model_fast")
319
+
320
+ return [
321
+ gr.update(value=t(selected_language, "title")),
322
+ gr.update(label=t(selected_language, "dropdown_label")),
323
+ gr.update(
324
+ choices=translated_choices,
325
+ value=default_model,
326
+ label=t(selected_language, "dropdown_detection_model_label")
327
+ ),
328
+ gr.update(label=t(selected_language, "threshold_label")),
329
+ gr.update(label=t(selected_language, "input_label")),
330
+ gr.update(value=t(selected_language, "button")),
331
+ gr.update(label=t(selected_language, "output_label")),
332
+ gr.update(label=t(selected_language, "info_label"))
333
+ ]
334
+
335
+ # Connect language change event
336
+ language_selector.change(
337
+ fn=update_interface,
338
+ inputs=language_selector,
339
+ outputs=[title, language_selector, model_selector, threshold_slider,
340
+ input_image, button, output_image, detection_info],
341
+ queue=False
342
+ )
343
 
344
+ # Connect detection button click event
345
+ button.click(
346
+ fn=detect_objects,
347
+ inputs=[input_image, language_selector, model_selector, threshold_slider],
348
+ outputs=[output_image, detection_info]
349
+ )
350
 
351
+ return app
352
 
 
 
 
 
 
 
 
 
353
 
354
+ # Initialize with default model
355
+ load_model("DETR ResNet-50")
356
 
357
+ # Launch the application
358
  if __name__ == "__main__":
359
+ app = build_app()
360
+ app.launch()
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ