bla committed · Commit 9c14dee · verified · 1 Parent(s): 1e3996b

Create app.py

Files changed (1):
app.py +458 -0
app.py ADDED
@@ -0,0 +1,458 @@
+ import os
+ import json
+ import cv2
+ import torch
+ import numpy as np
+ import gradio as gr
+ from PIL import Image
+ import matplotlib.pyplot as plt
+ from transformers import AutoModel, AutoProcessor
+ from ultralytics import YOLO
+
+ # Custom CSS for shadcn/Radix UI inspired look
+ custom_css = """
+ :root {
+     --primary: #0f172a;
+     --primary-foreground: #f8fafc;
+     --background: #f8fafc;
+     --card: #ffffff;
+     --card-foreground: #0f172a;
+     --border: #e2e8f0;
+     --ring: #94a3b8;
+     --radius: 0.5rem;
+ }
+
+ .dark {
+     --primary: #f8fafc;
+     --primary-foreground: #0f172a;
+     --background: #0f172a;
+     --card: #1e293b;
+     --card-foreground: #f8fafc;
+     --border: #334155;
+     --ring: #94a3b8;
+ }
+
+ .gradio-container {
+     margin: 0 !important;
+     padding: 0 !important;
+     max-width: 100% !important;
+ }
+
+ .main-container {
+     background-color: var(--background);
+     border-radius: var(--radius);
+     padding: 1.5rem;
+ }
+
+ .header {
+     margin-bottom: 1.5rem;
+     border-bottom: 1px solid var(--border);
+     padding-bottom: 1rem;
+ }
+
+ .header h1 {
+     font-size: 1.875rem;
+     font-weight: 700;
+     color: var(--primary);
+     margin-bottom: 0.5rem;
+ }
+
+ .header p {
+     color: var(--card-foreground);
+     opacity: 0.8;
+ }
+
+ .tab-nav {
+     background-color: var(--card);
+     border: 1px solid var(--border);
+     border-radius: var(--radius);
+     padding: 0.25rem;
+     margin-bottom: 1.5rem;
+ }
+
+ .tab-nav button {
+     border-radius: calc(var(--radius) - 0.25rem) !important;
+     font-weight: 500 !important;
+     transition: all 0.2s ease-in-out !important;
+ }
+
+ .tab-nav button.selected {
+     background-color: var(--primary) !important;
+     color: var(--primary-foreground) !important;
+ }
+
+ .input-panel, .output-panel {
+     background-color: var(--card);
+     border: 1px solid var(--border);
+     border-radius: var(--radius);
+     padding: 1.5rem;
+     box-shadow: 0 1px 3px rgba(0, 0, 0, 0.05);
+ }
+
+ .gr-button-primary {
+     background-color: var(--primary) !important;
+     color: var(--primary-foreground) !important;
+     border-radius: var(--radius) !important;
+     font-weight: 500 !important;
+     transition: all 0.2s ease-in-out !important;
+ }
+
+ .gr-button-primary:hover {
+     opacity: 0.9 !important;
+ }
+
+ .gr-form {
+     border: none !important;
+     background: transparent !important;
+ }
+
+ .gr-input, .gr-select {
+     border: 1px solid var(--border) !important;
+     border-radius: var(--radius) !important;
+     padding: 0.5rem 0.75rem !important;
+ }
+
+ .gr-panel {
+     border: none !important;
+ }
+
+ .footer {
+     margin-top: 1.5rem;
+     border-top: 1px solid var(--border);
+     padding-top: 1rem;
+     font-size: 0.875rem;
+     color: var(--card-foreground);
+     opacity: 0.7;
+ }
+ """
+
+ # Available model sizes
+ DETECTION_MODELS = {
+     "tiny": "yoloworld-t",
+     "small": "yoloworld-s",
+     "base": "yoloworld-b",
+     "large": "yoloworld-l",
+ }
+
+ SEGMENTATION_MODELS = {
+     "YOLOv8 Nano": "yolov8n-seg.pt",
+     "YOLOv8 Small": "yolov8s-seg.pt",
+     "YOLOv8 Medium": "yolov8m-seg.pt",
+     "YOLOv8 Large": "yolov8l-seg.pt",
+ }
+
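+ # Detector wrapper: holds the YOLO-World model/processor and lazily caches the
+ # YOLOv8 segmentation models so each checkpoint is only loaded once per process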
+ class YOLOWorldDetector:
+     def __init__(self, model_size="base"):
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.model_size = model_size
+         self.model_name = DETECTION_MODELS[model_size]
+
+         print(f"Loading {self.model_name} on {self.device}...")
+         self.model = AutoModel.from_pretrained(f"deepdatacloud/{self.model_name}",
+                                                trust_remote_code=True)
+         self.model.to(self.device)
+         self.processor = AutoProcessor.from_pretrained(f"deepdatacloud/{self.model_name}")
+         print("Model loaded successfully!")
+
+         # Segmentation models
+         self.seg_models = {}
+
+     def change_model(self, model_size):
+         if model_size != self.model_size:
+             self.model_size = model_size
+             self.model_name = DETECTION_MODELS[model_size]
+
+             print(f"Loading {self.model_name} on {self.device}...")
+             self.model = AutoModel.from_pretrained(f"deepdatacloud/{self.model_name}",
+                                                    trust_remote_code=True)
+             self.model.to(self.device)
+             self.processor = AutoProcessor.from_pretrained(f"deepdatacloud/{self.model_name}")
+             print("Model loaded successfully!")
+         return f"Using {self.model_name} model"
+
+     def load_seg_model(self, model_name):
+         if model_name not in self.seg_models:
+             print(f"Loading segmentation model {model_name}...")
+             self.seg_models[model_name] = YOLO(SEGMENTATION_MODELS[model_name])
+             print(f"Segmentation model {model_name} loaded successfully!")
+         return self.seg_models[model_name]
+
+     def detect(self, image, text_prompt, confidence_threshold=0.3):
+         if image is None:
+             return None, "No image provided"
+
+         if isinstance(image, str):
+             image = Image.open(image).convert("RGB")
+         elif isinstance(image, np.ndarray):
+             # Gradio's gr.Image(type="numpy") delivers RGB arrays, so no BGR->RGB swap is needed
+             image = Image.fromarray(image)
+
+         # Process inputs
+         inputs = self.processor(text=text_prompt, images=image, return_tensors="pt")
+         inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+         # Run inference
+         with torch.no_grad():
+             outputs = self.model(**inputs)
+
+         # Process results
+         target_sizes = torch.tensor([image.size[::-1]], device=self.device)
+         results = self.processor.post_process_object_detection(
+             outputs=outputs,
+             target_sizes=target_sizes,
+             threshold=confidence_threshold
+         )[0]
+
+         # Convert image to numpy for drawing
+         image_np = np.array(image)
+
+         # Draw bounding boxes
+         for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
+             box = box.cpu().numpy().astype(int)
+             score = score.cpu().item()
+             label = label.cpu().item()
+
+             # Get class name from model's config
+             class_name = f"{text_prompt.split(',')[label] if label < len(text_prompt.split(',')) else 'Object'}: {score:.2f}"
+
+             # Draw rectangle
+             cv2.rectangle(
+                 image_np,
+                 (box[0], box[1]),
+                 (box[2], box[3]),
+                 (0, 255, 0),
+                 2
+             )
+
+             # Draw label background
+             text_size = cv2.getTextSize(class_name, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
+             cv2.rectangle(
+                 image_np,
+                 (box[0], box[1] - text_size[1] - 5),
+                 (box[0] + text_size[0], box[1]),
+                 (0, 255, 0),
+                 -1
+             )
+
+             # Draw text
+             cv2.putText(
+                 image_np,
+                 class_name,
+                 (box[0], box[1] - 5),
+                 cv2.FONT_HERSHEY_SIMPLEX,
+                 0.5,
+                 (0, 0, 0),
+                 2
+             )
+
+         # Convert results to JSON format (percentages)
+         json_results = []
+         img_height, img_width = image_np.shape[:2]
+
+         for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
+             box = box.cpu().numpy()
+             x1, y1, x2, y2 = box
+
+             json_results.append({
+                 "bbox": {
+                     "x": (x1 / img_width) * 100,
+                     "y": (y1 / img_height) * 100,
+                     "width": ((x2 - x1) / img_width) * 100,
+                     "height": ((y2 - y1) / img_height) * 100
+                 },
+                 "score": float(score.cpu().item()),
+                 "label": int(label.cpu().item()),
+                 "label_text": text_prompt.split(',')[label] if label < len(text_prompt.split(',')) else 'Object'
+             })
+
+         # image_np has stayed in RGB order throughout, which is what Gradio expects back
+         return image_np, json_results
+
+     def segment(self, image, model_name, confidence_threshold=0.3):
+         if image is None:
+             return None, "No image provided"
+
+         # Load segmentation model if not already loaded
+         model = self.load_seg_model(model_name)
+
+         # Run inference
+         results = model(image, conf=confidence_threshold)
+
+         # Create visualization
+         fig, ax = plt.subplots(1, 1, figsize=(12, 9))
+         ax.axis('off')
+
+         # Plot segmentation results; plot() returns a BGR array (OpenCV convention),
+         # so convert it to RGB for display in Gradio
+         res_plotted = cv2.cvtColor(results[0].plot(), cv2.COLOR_BGR2RGB)
+
+         # Convert results to JSON format (percentages)
+         json_results = []
+         if hasattr(results[0], 'masks') and results[0].masks is not None:
+             img_height, img_width = results[0].orig_shape
+
+             for i, (box, mask, cls, conf) in enumerate(zip(
+                 results[0].boxes.xyxy.cpu().numpy(),
+                 results[0].masks.data.cpu().numpy(),
+                 results[0].boxes.cls.cpu().numpy(),
+                 results[0].boxes.conf.cpu().numpy()
+             )):
+                 x1, y1, x2, y2 = box
+
+                 # Convert mask to polygon for SVG-like representation
+                 # Simplified approach - in production you might want a more sophisticated polygon extraction
+                 # masks.data is at the model's inference resolution, so resize the mask to the
+                 # original image size before extracting contours for the percentage coordinates below
+                 binary_mask = cv2.resize((mask > 0.5).astype(np.uint8), (img_width, img_height),
+                                          interpolation=cv2.INTER_NEAREST)
+                 contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL,
+                                                cv2.CHAIN_APPROX_SIMPLE)
+
+                 if contours:
+                     # Get the largest contour
+                     largest_contour = max(contours, key=cv2.contourArea)
+                     # Simplify the contour
+                     epsilon = 0.005 * cv2.arcLength(largest_contour, True)
+                     approx = cv2.approxPolyDP(largest_contour, epsilon, True)
+
+                     # Convert to percentage coordinates
+                     points = []
+                     for point in approx:
+                         x, y = point[0]
+                         points.append({
+                             "x": (x / img_width) * 100,
+                             "y": (y / img_height) * 100
+                         })
+
+                     json_results.append({
+                         "bbox": {
+                             "x": (x1 / img_width) * 100,
+                             "y": (y1 / img_height) * 100,
+                             "width": ((x2 - x1) / img_width) * 100,
+                             "height": ((y2 - y1) / img_height) * 100
+                         },
+                         "score": float(conf),
+                         "label": int(cls),
+                         "label_text": results[0].names[int(cls)],
+                         "polygon": points
+                     })
+
+         return res_plotted, json_results
+
+ # Initialize detector with default model
+ detector = YOLOWorldDetector(model_size="base")
+
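+ # Gradio callbacks: run inference and serialize the results for the JSON textboxes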
+ def detection_inference(image, text_prompt, confidence, model_size):
+     # Update model if needed
+     detector.change_model(model_size)
+
+     # Run detection
+     result_image, json_results = detector.detect(
+         image,
+         text_prompt,
+         confidence_threshold=confidence
+     )
+
+     # default=float coerces numpy scalars so the output is valid JSON
+     return result_image, json.dumps(json_results, indent=2, default=float)
+
+ def segmentation_inference(image, confidence, model_name):
+     # Run segmentation
+     result_image, json_results = detector.segment(
+         image,
+         model_name,
+         confidence_threshold=confidence
+     )
+
+     return result_image, json.dumps(json_results, indent=2, default=float)
+
+ # Create Gradio interface
+ with gr.Blocks(title="YOLO Vision Suite", css=custom_css) as demo:
+     with gr.Column(elem_classes="main-container"):
+         with gr.Column(elem_classes="header"):
+             gr.Markdown("# YOLO Vision Suite")
+             gr.Markdown("Advanced object detection and segmentation powered by YOLO models")
+
+         with gr.Tabs(elem_classes="tab-nav") as tabs:
+             with gr.TabItem("Object Detection", elem_id="detection-tab"):
+                 with gr.Row():
+                     with gr.Column(elem_classes="input-panel"):
+                         gr.Markdown("### Input")
+                         input_image = gr.Image(label="Upload Image", type="numpy")
+                         text_prompt = gr.Textbox(
+                             label="Text Prompt",
+                             placeholder="person, car, dog",
+                             value="person, car, dog",
+                             elem_classes="gr-input"
+                         )
+                         with gr.Row():
+                             confidence = gr.Slider(
+                                 minimum=0.1,
+                                 maximum=1.0,
+                                 value=0.3,
+                                 step=0.05,
+                                 label="Confidence Threshold"
+                             )
+                             model_dropdown = gr.Dropdown(
+                                 choices=list(DETECTION_MODELS.keys()),
+                                 value="base",
+                                 label="Model Size",
+                                 elem_classes="gr-select"
+                             )
+                         detect_button = gr.Button("Detect Objects", elem_classes="gr-button-primary")
+
+                     with gr.Column(elem_classes="output-panel"):
+                         gr.Markdown("### Results")
+                         output_image = gr.Image(label="Detection Result")
+                         with gr.Accordion("JSON Output", open=False):
+                             json_output = gr.Textbox(
+                                 label="Bounding Box Data (Percentage Coordinates)",
+                                 elem_classes="gr-input"
+                             )
+
+             with gr.TabItem("Segmentation", elem_id="segmentation-tab"):
+                 with gr.Row():
+                     with gr.Column(elem_classes="input-panel"):
+                         gr.Markdown("### Input")
+                         seg_input_image = gr.Image(label="Upload Image", type="numpy")
+                         with gr.Row():
+                             seg_confidence = gr.Slider(
+                                 minimum=0.1,
+                                 maximum=1.0,
+                                 value=0.3,
+                                 step=0.05,
+                                 label="Confidence Threshold"
+                             )
+                             seg_model_dropdown = gr.Dropdown(
+                                 choices=list(SEGMENTATION_MODELS.keys()),
+                                 value="YOLOv8 Small",
+                                 label="Model Size",
+                                 elem_classes="gr-select"
+                             )
+                         segment_button = gr.Button("Segment Image", elem_classes="gr-button-primary")
+
+                     with gr.Column(elem_classes="output-panel"):
+                         gr.Markdown("### Results")
+                         seg_output_image = gr.Image(label="Segmentation Result")
+                         with gr.Accordion("JSON Output", open=False):
+                             seg_json_output = gr.Textbox(
+                                 label="Segmentation Data (Percentage Coordinates)",
+                                 elem_classes="gr-input"
+                             )
+
+         with gr.Column(elem_classes="footer"):
+             gr.Markdown("""
+             ### Tips
+             - For object detection, enter comma-separated text prompts to specify what to detect
+             - For segmentation, the model will identify common objects automatically
+             - Larger models provide better accuracy but require more processing power
+             - The JSON output provides coordinates as percentages of image dimensions, compatible with SVG
+             """)
+
+     # Set up event handlers
+     detect_button.click(
+         detection_inference,
+         inputs=[input_image, text_prompt, confidence, model_dropdown],
+         outputs=[output_image, json_output]
+     )
+
+     segment_button.click(
+         segmentation_inference,
+         inputs=[seg_input_image, seg_confidence, seg_model_dropdown],
+         outputs=[seg_output_image, seg_json_output]
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
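
The tips above describe the JSON output as percentage-based and SVG-compatible. A minimal sketch of that mapping, assuming an entry shaped like the ones detect() appends to json_results (the svg_rect helper and the sample values below are illustrative only, not part of app.py):

# Rough sketch: map one percentage-based bbox entry onto an SVG <rect>
def svg_rect(entry, view_w=100, view_h=100):
    # Percentage coordinates scale linearly to any target viewBox
    b = entry["bbox"]
    return (
        f'<rect x="{b["x"] * view_w / 100:.2f}" y="{b["y"] * view_h / 100:.2f}" '
        f'width="{b["width"] * view_w / 100:.2f}" height="{b["height"] * view_h / 100:.2f}" '
        f'fill="none" stroke="lime" />'
    )

# Illustrative entry in the shape detect() produces
example = {"bbox": {"x": 12.5, "y": 20.0, "width": 30.0, "height": 45.0},
           "score": 0.87, "label": 0, "label_text": "person"}
print(f'<svg viewBox="0 0 100 100">{svg_rect(example)}</svg>')

Because the coordinates are percentages, the same markup overlays correctly on any rendered size of the image without rescaling.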