bla committed on
Commit
92767db
·
verified ·
1 Parent(s): 6962f1b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +341 -25
app.py CHANGED
@@ -122,6 +122,178 @@ custom_css = """
122
  font-size: 0.875rem;
123
  color: var(--card-foreground);
124
  opacity: 0.7;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  }
126
  """
127
 
@@ -336,6 +508,119 @@ class YOLOWorldDetector:
336
  # Initialize detector with default model
337
  detector = YOLOWorldDetector(model_size="small")
338
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
339
  def detection_inference(image, text_prompt, confidence, model_size):
340
  # Update model if needed
341
  detector.change_model(model_size)
@@ -347,7 +632,14 @@ def detection_inference(image, text_prompt, confidence, model_size):
347
  confidence_threshold=confidence
348
  )
349
 
350
- return result_image, str(json_results)
 
 
 
 
 
 
 
351
 
352
  def segmentation_inference(image, confidence, model_name):
353
  # Run segmentation
@@ -357,7 +649,14 @@ def segmentation_inference(image, confidence, model_name):
357
  confidence_threshold=confidence
358
  )
359
 
360
- return result_image, str(json_results)
 
 
 
 
 
 
 
361
 
362
  # Create Gradio interface
363
  with gr.Blocks(title="YOLO Vision Suite", css=custom_css) as demo:
@@ -368,10 +667,10 @@ with gr.Blocks(title="YOLO Vision Suite", css=custom_css) as demo:
368
 
369
  with gr.Tabs(elem_classes="tab-nav") as tabs:
370
  with gr.TabItem("Object Detection", elem_id="detection-tab"):
371
- with gr.Row():
372
- with gr.Column(elem_classes="input-panel"):
373
  gr.Markdown("### Input")
374
- input_image = gr.Image(label="Upload Image", type="numpy")
375
  text_prompt = gr.Textbox(
376
  label="Text Prompt",
377
  placeholder="person, car, dog",
@@ -394,20 +693,23 @@ with gr.Blocks(title="YOLO Vision Suite", css=custom_css) as demo:
394
  )
395
  detect_button = gr.Button("Detect Objects", elem_classes="gr-button-primary")
396
 
397
- with gr.Column(elem_classes="output-panel"):
398
  gr.Markdown("### Results")
399
- output_image = gr.Image(label="Detection Result")
400
- with gr.Accordion("JSON Output", open=False):
 
 
401
  json_output = gr.Textbox(
402
  label="Bounding Box Data (Percentage Coordinates)",
403
- elem_classes="gr-input"
 
404
  )
405
 
406
  with gr.TabItem("Segmentation", elem_id="segmentation-tab"):
407
- with gr.Row():
408
- with gr.Column(elem_classes="input-panel"):
409
  gr.Markdown("### Input")
410
- seg_input_image = gr.Image(label="Upload Image", type="numpy")
411
  with gr.Row():
412
  seg_confidence = gr.Slider(
413
  minimum=0.1,
@@ -424,35 +726,49 @@ with gr.Blocks(title="YOLO Vision Suite", css=custom_css) as demo:
424
  )
425
  segment_button = gr.Button("Segment Image", elem_classes="gr-button-primary")
426
 
427
- with gr.Column(elem_classes="output-panel"):
428
  gr.Markdown("### Results")
429
- seg_output_image = gr.Image(label="Segmentation Result")
430
- with gr.Accordion("JSON Output", open=False):
 
 
431
  seg_json_output = gr.Textbox(
432
  label="Segmentation Data (Percentage Coordinates)",
433
- elem_classes="gr-input"
 
434
  )
435
 
436
  with gr.Column(elem_classes="footer"):
437
- gr.Markdown("""
438
- ### Tips
439
- - For object detection, enter comma-separated text prompts to specify what to detect
440
- - For segmentation, the model will identify common objects automatically
441
- - Larger models provide better accuracy but require more processing power
442
- - The JSON output provides coordinates as percentages of image dimensions, compatible with SVG
443
- """)
 
 
 
 
 
 
 
 
 
 
 
444
 
445
  # Set up event handlers
446
  detect_button.click(
447
  detection_inference,
448
  inputs=[input_image, text_prompt, confidence, model_dropdown],
449
- outputs=[output_image, json_output]
450
  )
451
 
452
  segment_button.click(
453
  segmentation_inference,
454
  inputs=[seg_input_image, seg_confidence, seg_model_dropdown],
455
- outputs=[seg_output_image, seg_json_output]
456
  )
457
 
458
  if __name__ == "__main__":
 
122
  font-size: 0.875rem;
123
  color: var(--card-foreground);
124
  opacity: 0.7;
125
+ }"""
126
# Custom CSS for a more modern UI inspired by NextUI
custom_css = """
:root {
    --primary: #0070f3;
    --primary-foreground: #ffffff;
    --background: #f5f5f5;
    --card: #ffffff;
    --card-foreground: #111111;
    --border: #eaeaea;
    --ring: #0070f3;
    --shadow: 0 4px 14px 0 rgba(0, 118, 255, 0.1);
    /* BUG FIX: --radius was referenced by var(--radius) throughout this
       sheet but never declared, so every border-radius fell back to the
       property's initial value. Declare it once here. */
    --radius: 0.75rem;
}

.dark {
    --primary: #0070f3;
    --primary-foreground: #ffffff;
    --background: #000000;
    --card: #111111;
    --card-foreground: #ffffff;
    --border: #333333;
    --ring: #0070f3;
}

.gradio-container {
    margin: 0 !important;
    padding: 0 !important;
    max-width: 100% !important;
}

.main-container {
    background-color: var(--background);
    padding: 2rem;
    min-height: 100vh;
}

.header {
    margin-bottom: 2rem;
    text-align: center;
}

.header h1 {
    font-size: 2.5rem;
    font-weight: 800;
    color: var(--card-foreground);
    margin-bottom: 0.5rem;
    background: linear-gradient(to right, #0070f3, #00bfff);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
}

.header p {
    color: var(--card-foreground);
    opacity: 0.8;
    font-size: 1.1rem;
}

.tab-nav {
    background-color: var(--card);
    border-radius: var(--radius);
    padding: 0.5rem;
    margin-bottom: 2rem;
    box-shadow: var(--shadow);
}

.tab-nav button {
    border-radius: var(--radius) !important;
    font-weight: 600 !important;
    transition: all 0.2s ease-in-out !important;
    padding: 0.75rem 1.5rem !important;
}

.tab-nav button.selected {
    background-color: var(--primary) !important;
    color: var(--primary-foreground) !important;
    transform: translateY(-2px);
    box-shadow: 0 4px 14px 0 rgba(0, 118, 255, 0.25);
}

.input-panel, .output-panel {
    background-color: var(--card);
    border-radius: var(--radius);
    padding: 1.5rem;
    box-shadow: var(--shadow);
    height: 100%;
    display: flex;
    flex-direction: column;
}

.input-panel h3, .output-panel h3 {
    font-size: 1.25rem;
    font-weight: 600;
    margin-bottom: 1rem;
    color: var(--card-foreground);
    border-bottom: 2px solid var(--primary);
    padding-bottom: 0.5rem;
    display: inline-block;
}

.gr-button-primary {
    background-color: var(--primary) !important;
    color: var(--primary-foreground) !important;
    border-radius: var(--radius) !important;
    font-weight: 600 !important;
    transition: all 0.2s ease-in-out !important;
    padding: 0.75rem 1.5rem !important;
    box-shadow: 0 4px 14px 0 rgba(0, 118, 255, 0.25) !important;
    width: 100%;
    margin-top: 1rem;
}

.gr-button-primary:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 6px 20px rgba(0, 118, 255, 0.35) !important;
}

.gr-form {
    border: none !important;
    background: transparent !important;
}

.gr-input, .gr-select {
    border: 1px solid var(--border) !important;
    border-radius: var(--radius) !important;
    padding: 0.75rem 1rem !important;
    transition: all 0.2s ease-in-out !important;
}

.gr-input:focus, .gr-select:focus {
    border-color: var(--primary) !important;
    box-shadow: 0 0 0 2px rgba(0, 118, 255, 0.25) !important;
}

.gr-panel {
    border: none !important;
}

.gr-accordion {
    border: 1px solid var(--border) !important;
    border-radius: var(--radius) !important;
    overflow: hidden;
}

.footer {
    margin-top: 2rem;
    border-top: 1px solid var(--border);
    padding-top: 1.5rem;
    font-size: 0.9rem;
    color: var(--card-foreground);
    opacity: 0.7;
    text-align: center;
}

.footer-card {
    background-color: var(--card);
    border-radius: var(--radius);
    padding: 1.5rem;
    box-shadow: var(--shadow);
}

.tips-grid {
    display: grid;
    grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
    gap: 1rem;
    margin-top: 1rem;
}

.tip-card {
    background-color: var(--card);
    border-radius: var(--radius);
    padding: 1rem;
    border-left: 3px solid var(--primary);
}
"""
299
 
 
508
  # Initialize detector with default model
509
  detector = YOLOWorldDetector(model_size="small")
510
 
511
def create_svg_from_detections(json_results, img_width, img_height):
    """Convert detection results to an SVG overlay string.

    Args:
        json_results: list of dicts, each with a "bbox" dict of percentage
            coordinates ("x", "y", "width", "height") and optional
            "label_text" / "score" keys.
        img_width: target image width in pixels.
        img_height: target image height in pixels.

    Returns:
        A complete ``<svg>`` document string with one labelled rectangle
        per detection.
    """
    # Local stdlib import so the file-level import block stays untouched.
    from xml.sax.saxutils import escape, quoteattr

    svg_header = f'<svg width="{img_width}" height="{img_height}" xmlns="http://www.w3.org/2000/svg">'
    svg_content = ""

    # Color palette for different classes
    colors = [
        "#FF3B30", "#FF9500", "#FFCC00", "#4CD964",
        "#5AC8FA", "#007AFF", "#5856D6", "#FF2D55"
    ]

    for i, result in enumerate(json_results):
        bbox = result["bbox"]
        label = str(result.get("label_text", f"Object {i}"))
        score = result.get("score", 0)

        # Convert percentage to absolute pixel coordinates
        x = (bbox["x"] / 100) * img_width
        y = (bbox["y"] / 100) * img_height
        width = (bbox["width"] / 100) * img_width
        height = (bbox["height"] / 100) * img_height

        # Select color based on detection index (wraps around the palette)
        color = colors[i % len(colors)]

        # BUG FIX: labels come from user text prompts and were interpolated
        # into the markup raw, so a label containing &, < or " produced
        # malformed SVG. Escape text content and quote attribute values.
        safe_label = escape(label)

        svg_content += f'''
    <rect
        x="{x:.2f}"
        y="{y:.2f}"
        width="{width:.2f}"
        height="{height:.2f}"
        stroke="{color}"
        stroke-width="2"
        fill="none"
        data-label={quoteattr(label)}
        data-score="{score:.2f}"
    />
    <text
        x="{x:.2f}"
        y="{y - 5:.2f}"
        font-family="Arial"
        font-size="12"
        fill="{color}"
    >{safe_label} ({score:.2f})</text>'''

    svg_footer = "\n</svg>"
    return svg_header + svg_content + svg_footer
560
def create_svg_from_segmentation(json_results, img_width, img_height):
    """Convert segmentation results to an SVG overlay string.

    Args:
        json_results: list of dicts, each with a "bbox" dict of percentage
            coordinates, an optional "polygon" list of {"x", "y"} percentage
            points, and optional "label_text" / "score" keys.
        img_width: target image width in pixels.
        img_height: target image height in pixels.

    Returns:
        A complete ``<svg>`` document string with a filled polygon (when
        available) plus a dashed, labelled bounding box per result.
    """
    # Local stdlib import so the file-level import block stays untouched.
    from xml.sax.saxutils import escape, quoteattr

    svg_header = f'<svg width="{img_width}" height="{img_height}" xmlns="http://www.w3.org/2000/svg">'
    svg_content = ""

    # Color palette for different classes
    colors = [
        "#FF3B30", "#FF9500", "#FFCC00", "#4CD964",
        "#5AC8FA", "#007AFF", "#5856D6", "#FF2D55"
    ]

    for i, result in enumerate(json_results):
        label = str(result.get("label_text", f"Object {i}"))
        score = result.get("score", 0)

        # Select color based on result index (wraps around the palette)
        color = colors[i % len(colors)]

        # BUG FIX: labels are model/user supplied and were interpolated into
        # the markup raw, so a label containing &, < or " produced malformed
        # SVG. Escape text content and quote attribute values.
        safe_label = escape(label)

        # Create polygon if available
        if "polygon" in result:
            points_str = " ".join(
                f"{(p['x'] / 100) * img_width:.2f},{(p['y'] / 100) * img_height:.2f}"
                for p in result["polygon"]
            )

            # "#RRGGBB33" appends a ~20% alpha channel to the fill color
            svg_content += f'''
    <polygon
        points="{points_str}"
        stroke="{color}"
        stroke-width="2"
        fill="{color}33"
        data-label={quoteattr(label)}
        data-score="{score:.2f}"
    />'''

        # Also add bounding box (assumes "bbox" is always present, matching
        # the detection output schema — TODO confirm against the detector)
        bbox = result["bbox"]
        x = (bbox["x"] / 100) * img_width
        y = (bbox["y"] / 100) * img_height
        width = (bbox["width"] / 100) * img_width
        height = (bbox["height"] / 100) * img_height

        svg_content += f'''
    <rect
        x="{x:.2f}"
        y="{y:.2f}"
        width="{width:.2f}"
        height="{height:.2f}"
        stroke="{color}"
        stroke-width="1"
        fill="none"
        stroke-dasharray="5,5"
    />
    <text
        x="{x:.2f}"
        y="{y - 5:.2f}"
        font-family="Arial"
        font-size="12"
        fill="{color}"
    >{safe_label} ({score:.2f})</text>'''

    svg_footer = "\n</svg>"
    return svg_header + svg_content + svg_footer
624
  def detection_inference(image, text_prompt, confidence, model_size):
625
  # Update model if needed
626
  detector.change_model(model_size)
 
632
  confidence_threshold=confidence
633
  )
634
 
635
+ # Create SVG from detection results
636
+ if isinstance(json_results, list) and len(json_results) > 0:
637
+ img_height, img_width = result_image.shape[:2]
638
+ svg_output = create_svg_from_detections(json_results, img_width, img_height)
639
+ else:
640
+ svg_output = "<svg></svg>"
641
+
642
+ return result_image, str(json_results), svg_output
643
 
644
  def segmentation_inference(image, confidence, model_name):
645
  # Run segmentation
 
649
  confidence_threshold=confidence
650
  )
651
 
652
+ # Create SVG from segmentation results
653
+ if isinstance(json_results, list) and len(json_results) > 0:
654
+ img_height, img_width = result_image.shape[:2]
655
+ svg_output = create_svg_from_segmentation(json_results, img_width, img_height)
656
+ else:
657
+ svg_output = "<svg></svg>"
658
+
659
+ return result_image, str(json_results), svg_output
660
 
661
  # Create Gradio interface
662
  with gr.Blocks(title="YOLO Vision Suite", css=custom_css) as demo:
 
667
 
668
  with gr.Tabs(elem_classes="tab-nav") as tabs:
669
  with gr.TabItem("Object Detection", elem_id="detection-tab"):
670
+ with gr.Row(equal_height=True):
671
+ with gr.Column(elem_classes="input-panel", scale=1):
672
  gr.Markdown("### Input")
673
+ input_image = gr.Image(label="Upload Image", type="numpy", height=300)
674
  text_prompt = gr.Textbox(
675
  label="Text Prompt",
676
  placeholder="person, car, dog",
 
693
  )
694
  detect_button = gr.Button("Detect Objects", elem_classes="gr-button-primary")
695
 
696
+ with gr.Column(elem_classes="output-panel", scale=1):
697
  gr.Markdown("### Results")
698
+ output_image = gr.Image(label="Detection Result", height=300)
699
+ with gr.Accordion("SVG Output", open=False, elem_classes="gr-accordion"):
700
+ svg_output = gr.HTML(label="SVG Visualization")
701
+ with gr.Accordion("JSON Output", open=False, elem_classes="gr-accordion"):
702
  json_output = gr.Textbox(
703
  label="Bounding Box Data (Percentage Coordinates)",
704
+ elem_classes="gr-input",
705
+ lines=5
706
  )
707
 
708
  with gr.TabItem("Segmentation", elem_id="segmentation-tab"):
709
+ with gr.Row(equal_height=True):
710
+ with gr.Column(elem_classes="input-panel", scale=1):
711
  gr.Markdown("### Input")
712
+ seg_input_image = gr.Image(label="Upload Image", type="numpy", height=300)
713
  with gr.Row():
714
  seg_confidence = gr.Slider(
715
  minimum=0.1,
 
726
  )
727
  segment_button = gr.Button("Segment Image", elem_classes="gr-button-primary")
728
 
729
+ with gr.Column(elem_classes="output-panel", scale=1):
730
  gr.Markdown("### Results")
731
+ seg_output_image = gr.Image(label="Segmentation Result", height=300)
732
+ with gr.Accordion("SVG Output", open=False, elem_classes="gr-accordion"):
733
+ seg_svg_output = gr.HTML(label="SVG Visualization")
734
+ with gr.Accordion("JSON Output", open=False, elem_classes="gr-accordion"):
735
  seg_json_output = gr.Textbox(
736
  label="Segmentation Data (Percentage Coordinates)",
737
+ elem_classes="gr-input",
738
+ lines=5
739
  )
740
 
741
  with gr.Column(elem_classes="footer"):
742
+ with gr.Column(elem_classes="footer-card"):
743
+ gr.Markdown("### Tips & Information")
744
+ with gr.Row(elem_classes="tips-grid"):
745
+ with gr.Column(elem_classes="tip-card"):
746
+ gr.Markdown("**Detection**")
747
+ gr.Markdown("Enter comma-separated text prompts to specify what objects to detect")
748
+
749
+ with gr.Column(elem_classes="tip-card"):
750
+ gr.Markdown("**Segmentation**")
751
+ gr.Markdown("The model will identify and segment common objects automatically")
752
+
753
+ with gr.Column(elem_classes="tip-card"):
754
+ gr.Markdown("**Models**")
755
+ gr.Markdown("Larger models provide better accuracy but require more processing power")
756
+
757
+ with gr.Column(elem_classes="tip-card"):
758
+ gr.Markdown("**Output**")
759
+ gr.Markdown("JSON output provides coordinates as percentages, compatible with SVG")
760
 
761
  # Set up event handlers
762
  detect_button.click(
763
  detection_inference,
764
  inputs=[input_image, text_prompt, confidence, model_dropdown],
765
+ outputs=[output_image, json_output, svg_output]
766
  )
767
 
768
  segment_button.click(
769
  segmentation_inference,
770
  inputs=[seg_input_image, seg_confidence, seg_model_dropdown],
771
+ outputs=[seg_output_image, seg_json_output, seg_svg_output]
772
  )
773
 
774
  if __name__ == "__main__":