hoololi commited on
Commit
9bdeee9
·
verified ·
1 Parent(s): 7a60255

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +272 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline, AutoImageProcessor, AutoModelForObjectDetection
3
+ from PIL import Image, ImageDraw, ImageFont
4
+ import torch
5
+ import spaces
6
+ import numpy as np
7
+
8
+ # Modèles disponibles sur Hugging Face Hub
9
+ AVAILABLE_MODELS = {
10
+ "DETR ResNet-50": "facebook/detr-resnet-50",
11
+ "DETR ResNet-101": "facebook/detr-resnet-101",
12
+ "Conditional DETR": "microsoft/conditional-detr-resnet-50",
13
+ "Table Transformer": "microsoft/table-transformer-detection",
14
+ "YOLOS Tiny": "hustvl/yolos-tiny",
15
+ "YOLOS Small": "hustvl/yolos-small",
16
+ "YOLOS Base": "hustvl/yolos-base",
17
+ "RT-DETR": "PekingU/rtdetr_r50vd_coco_o365",
18
+ "OWL-ViT": "google/owlvit-base-patch32"
19
+ }
20
+
21
+ # Cache pour éviter de recharger les modèles
22
+ model_cache = {}
23
+
24
+ def load_model(model_name):
25
+ """Charge un modèle avec cache"""
26
+ if model_name not in model_cache:
27
+ print(f"Chargement du modèle: {model_name}")
28
+
29
+ if "owlvit" in model_name:
30
+ # OWL-ViT est un modèle de détection zero-shot
31
+ model_cache[model_name] = pipeline(
32
+ "zero-shot-object-detection",
33
+ model=model_name,
34
+ device=0 if torch.cuda.is_available() else -1
35
+ )
36
+ else:
37
+ # Autres modèles de détection standard
38
+ model_cache[model_name] = pipeline(
39
+ "object-detection",
40
+ model=model_name,
41
+ device=0 if torch.cuda.is_available() else -1
42
+ )
43
+
44
+ return model_cache[model_name]
45
+
46
+ @spaces.GPU
47
+ def detect_objects(image, model_choice, confidence_threshold, custom_classes=""):
48
+ """Détection d'objets avec modèles transformers"""
49
+
50
+ if image is None:
51
+ return None, "❌ Veuillez uploader une image"
52
+
53
+ try:
54
+ # Charger le modèle sélectionné
55
+ model_id = AVAILABLE_MODELS[model_choice]
56
+ detector = load_model(model_id)
57
+
58
+ # Traitement spécial pour OWL-ViT (zero-shot)
59
+ if "owlvit" in model_id.lower():
60
+ if not custom_classes.strip():
61
+ custom_classes = "person, car, dog, cat, chair, table, bottle, cup"
62
+
63
+ class_list = [cls.strip() for cls in custom_classes.split(",")]
64
+ results = detector(image, candidate_labels=class_list)
65
+ else:
66
+ # Modèles de détection standard
67
+ results = detector(image)
68
+
69
+ # Filtrer par seuil de confiance
70
+ filtered_results = [
71
+ obj for obj in results
72
+ if obj['score'] >= confidence_threshold
73
+ ]
74
+
75
+ # Dessiner les détections
76
+ annotated_image = draw_detections(image.copy(), filtered_results)
77
+
78
+ # Créer le résumé
79
+ summary = create_summary(filtered_results, model_choice)
80
+
81
+ return annotated_image, summary
82
+
83
+ except Exception as e:
84
+ return image, f"❌ Erreur: {str(e)}"
85
+
86
+ def draw_detections(image, detections):
87
+ """Dessine les boîtes de détection sur l'image"""
88
+ draw = ImageDraw.Draw(image)
89
+
90
+ # Essayer de charger une police, sinon utiliser la police par défaut
91
+ try:
92
+ font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16)
93
+ except:
94
+ font = ImageFont.load_default()
95
+
96
+ colors = [
97
+ "#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4", "#FECA57",
98
+ "#FF9FF3", "#54A0FF", "#5F27CD", "#00D2D3", "#FF9F43"
99
+ ]
100
+
101
+ for i, detection in enumerate(detections):
102
+ box = detection['box']
103
+ label = detection['label']
104
+ score = detection['score']
105
+
106
+ # Coordonnées de la boîte
107
+ x1, y1 = box['xmin'], box['ymin']
108
+ x2, y2 = box['xmax'], box['ymax']
109
+
110
+ # Couleur pour cette classe
111
+ color = colors[i % len(colors)]
112
+
113
+ # Dessiner la boîte
114
+ draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
115
+
116
+ # Texte du label
117
+ text = f"{label} ({score:.2f})"
118
+
119
+ # Fond du texte
120
+ bbox = draw.textbbox((x1, y1-25), text, font=font)
121
+ draw.rectangle(bbox, fill=color)
122
+
123
+ # Texte
124
+ draw.text((x1, y1-25), text, fill="white", font=font)
125
+
126
+ return image
127
+
128
+ def create_summary(detections, model_name):
129
+ """Crée un résumé des détections"""
130
+ if not detections:
131
+ return "🔍 Aucun objet détecté"
132
+
133
+ summary = f"🎯 **{len(detections)} objets détectés** avec {model_name}\n\n"
134
+
135
+ # Grouper par classe
136
+ class_counts = {}
137
+ for det in detections:
138
+ label = det['label']
139
+ score = det['score']
140
+
141
+ if label not in class_counts:
142
+ class_counts[label] = []
143
+ class_counts[label].append(score)
144
+
145
+ # Afficher le résumé
146
+ for label, scores in class_counts.items():
147
+ count = len(scores)
148
+ avg_score = sum(scores) / len(scores)
149
+ max_score = max(scores)
150
+
151
+ summary += f"**{label}**: {count}x (confiance: {avg_score:.2f} avg, {max_score:.2f} max)\n"
152
+
153
+ return summary
154
+
155
+ # Interface Gradio
156
+ with gr.Blocks(title="🤖 Object Detection avec Transformers", theme=gr.themes.Soft()) as demo:
157
+
158
+ gr.Markdown("""
159
+ # 🤖 Object Detection avec Transformers
160
+
161
+ Utilisez les meilleurs modèles de détection d'objets disponibles sur Hugging Face Hub !
162
+
163
+ **✨ Fonctionnalités:**
164
+ - 🔄 Changement de modèle en temps réel
165
+ - 🎯 Seuil de confiance ajustable
166
+ - 🏷️ Classes personnalisées (OWL-ViT)
167
+ - 📊 Résumé détaillé des détections
168
+ """)
169
+
170
+ with gr.Row():
171
+ with gr.Column(scale=1):
172
+ # Input
173
+ image_input = gr.Image(
174
+ type="pil",
175
+ label="📸 Image à analyser",
176
+ height=400
177
+ )
178
+
179
+ # Sélection du modèle
180
+ model_dropdown = gr.Dropdown(
181
+ choices=list(AVAILABLE_MODELS.keys()),
182
+ value="DETR ResNet-50",
183
+ label="🤖 Modèle de détection",
184
+ info="Chaque modèle a ses spécialités"
185
+ )
186
+
187
+ # Paramètres
188
+ confidence_slider = gr.Slider(
189
+ minimum=0.1,
190
+ maximum=1.0,
191
+ value=0.5,
192
+ step=0.05,
193
+ label="🎯 Seuil de confiance minimum"
194
+ )
195
+
196
+ # Classes personnalisées pour OWL-ViT
197
+ custom_classes_input = gr.Textbox(
198
+ label="🏷️ Classes personnalisées (pour OWL-ViT)",
199
+ placeholder="person, car, dog, bottle, phone",
200
+ info="Séparées par des virgules. Uniquement pour OWL-ViT."
201
+ )
202
+
203
+ # Bouton de détection
204
+ detect_btn = gr.Button(
205
+ "🔍 Détecter les objets",
206
+ variant="primary",
207
+ size="lg"
208
+ )
209
+
210
+ with gr.Column(scale=1):
211
+ # Outputs
212
+ output_image = gr.Image(
213
+ label="📊 Résultats de détection",
214
+ height=400
215
+ )
216
+
217
+ detection_summary = gr.Textbox(
218
+ label="📈 Résumé des détections",
219
+ lines=8,
220
+ max_lines=15
221
+ )
222
+
223
+ # Event handlers
224
+ detect_btn.click(
225
+ fn=detect_objects,
226
+ inputs=[image_input, model_dropdown, confidence_slider, custom_classes_input],
227
+ outputs=[output_image, detection_summary]
228
+ )
229
+
230
+ # Auto-detect en changeant de modèle
231
+ model_dropdown.change(
232
+ fn=detect_objects,
233
+ inputs=[image_input, model_dropdown, confidence_slider, custom_classes_input],
234
+ outputs=[output_image, detection_summary]
235
+ )
236
+
237
+ with gr.Accordion("📚 Guide des modèles", open=False):
238
+ gr.Markdown("""
239
+ ## 🎯 Guide de sélection des modèles
240
+
241
+ ### **DETR (Detection Transformer)**
242
+ - **ResNet-50**: Équilibre vitesse/précision ⚖️
243
+ - **ResNet-101**: Plus précis, plus lent 🎯
244
+ - **Conditional DETR**: Version optimisée 🚀
245
+
246
+ ### **YOLOS (You Only Look Once Transformer)**
247
+ - **Tiny**: Ultra-rapide ⚡
248
+ - **Small**: Bon compromis 🎯
249
+ - **Base**: Maximum de précision 🔍
250
+
251
+ ### **OWL-ViT (Zero-shot Detection)**
252
+ - Détecte **n'importe quoi** que vous décrivez ! 🎨
253
+ - Tapez vos propres classes dans le champ "Classes personnalisées"
254
+
255
+ ### **RT-DETR**
256
+ - Optimisé pour le temps réel ⚡
257
+
258
+ ### **Table Transformer**
259
+ - Spécialisé dans la détection de tableaux 📊
260
+ """)
261
+
262
+ # Exemples
263
+ gr.Examples(
264
+ examples=[
265
+ ["example1.jpg", "DETR ResNet-50", 0.5, ""],
266
+ ["example2.jpg", "OWL-ViT", 0.3, "smartphone, laptop, coffee cup"],
267
+ ],
268
+ inputs=[image_input, model_dropdown, confidence_slider, custom_classes_input]
269
+ )
270
+
271
+ if __name__ == "__main__":
272
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ transformers>=4.30.0
2
+ gradio>=5.38.2
3
+ torch
4
+ torchvision
5
+ pillow
6
+ numpy
7
+ spaces