import json

import cv2
import numpy as np
import torch
import gradio as gr
from ultralytics import YOLO
# Custom CSS for a modern UI inspired by NextUI
custom_css = """
:root {
    --primary: #0070f3;
    --primary-foreground: #ffffff;
    --background: #f5f5f5;
    --card: #ffffff;
    --card-foreground: #111111;
    --border: #eaeaea;
    --ring: #0070f3;
    --radius: 0.5rem;
    --shadow: 0 4px 14px 0 rgba(0, 118, 255, 0.1);
}
.dark {
    --primary: #0070f3;
    --primary-foreground: #ffffff;
    --background: #000000;
    --card: #111111;
    --card-foreground: #ffffff;
    --border: #333333;
    --ring: #0070f3;
}
.gradio-container {
    margin: 0 !important;
    padding: 0 !important;
    max-width: 100% !important;
}
.main-container {
    background-color: var(--background);
    padding: 2rem;
}
.header {
    margin-bottom: 2rem;
    text-align: center;
}
.header h1 {
    font-size: 2.5rem;
    font-weight: 800;
    color: var(--card-foreground);
    margin-bottom: 0.5rem;
    background: linear-gradient(to right, #0070f3, #00bfff);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
}
.header p {
    color: var(--card-foreground);
    opacity: 0.8;
    font-size: 1.1rem;
}
.tab-nav {
    background-color: var(--card);
    border-radius: var(--radius);
    padding: 0.5rem;
    margin-bottom: 2rem;
    box-shadow: var(--shadow);
}
.tab-nav button {
    border-radius: var(--radius) !important;
    font-weight: 600 !important;
    transition: all 0.2s ease-in-out !important;
    padding: 0.75rem 1.5rem !important;
}
.tab-nav button.selected {
    background-color: var(--primary) !important;
    color: var(--primary-foreground) !important;
    transform: translateY(-2px);
    box-shadow: 0 4px 14px 0 rgba(0, 118, 255, 0.25);
}
.input-panel, .output-panel {
    background-color: var(--card);
    border-radius: var(--radius);
    padding: 1.5rem;
    box-shadow: var(--shadow);
    height: 100%;
    display: flex;
    flex-direction: column;
}
.input-panel h3, .output-panel h3 {
    font-size: 1.25rem;
    font-weight: 600;
    margin-bottom: 1rem;
    color: var(--card-foreground);
    border-bottom: 2px solid var(--primary);
    padding-bottom: 0.5rem;
    display: inline-block;
}
.gr-button-primary {
    background-color: var(--primary) !important;
    color: var(--primary-foreground) !important;
    border-radius: var(--radius) !important;
    font-weight: 600 !important;
    transition: all 0.2s ease-in-out !important;
    padding: 0.75rem 1.5rem !important;
    box-shadow: 0 4px 14px 0 rgba(0, 118, 255, 0.25) !important;
    width: 100%;
    margin-top: 1rem;
}
.gr-button-primary:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 6px 20px rgba(0, 118, 255, 0.35) !important;
}
.gr-form {
    border: none !important;
    background: transparent !important;
}
.gr-input, .gr-select {
    border: 1px solid var(--border) !important;
    border-radius: var(--radius) !important;
    padding: 0.75rem 1rem !important;
    transition: all 0.2s ease-in-out !important;
}
.gr-input:focus, .gr-select:focus {
    border-color: var(--primary) !important;
    box-shadow: 0 0 0 2px rgba(0, 118, 255, 0.25) !important;
}
.gr-panel {
    border: none !important;
}
.gr-accordion {
    border: 1px solid var(--border) !important;
    border-radius: var(--radius) !important;
    overflow: hidden;
}
.footer {
    margin-top: 2rem;
    border-top: 1px solid var(--border);
    padding-top: 1.5rem;
    font-size: 0.9rem;
    color: var(--card-foreground);
    opacity: 0.7;
    text-align: center;
}
.footer-card {
    background-color: var(--card);
    border-radius: var(--radius);
    padding: 1.5rem;
    box-shadow: var(--shadow);
}
.tips-grid {
    display: grid;
    grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
    gap: 1rem;
    margin-top: 1rem;
}
.tip-card {
    background-color: var(--card);
    border-radius: var(--radius);
    padding: 1rem;
    border-left: 3px solid var(--primary);
}
"""
# Available model sizes
DETECTION_MODELS = {
    "small": "yolov8s-worldv2.pt",
    "medium": "yolov8m-worldv2.pt",
    "large": "yolov8l-worldv2.pt",
    "xlarge": "yolov8x-worldv2.pt",
}

SEGMENTATION_MODELS = {
    "YOLOv8 Nano": "yolov8n-seg.pt",
    "YOLOv8 Small": "yolov8s-seg.pt",
    "YOLOv8 Medium": "yolov8m-seg.pt",
    "YOLOv8 Large": "yolov8l-seg.pt",
}
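# Note: these are official Ultralytics checkpoint names; the `ultralytics`
# package downloads the weights automatically on first use.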
class YOLOWorldDetector:
    def __init__(self, model_size="small"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_size = model_size
        self.model_name = DETECTION_MODELS[model_size]
        self._load_detection_model()
        # Segmentation models, loaded lazily and cached by display name
        self.seg_models = {}

    def _load_detection_model(self):
        """Load the YOLO-World checkpoint, falling back to plain YOLOv8 on failure."""
        print(f"Loading {self.model_name} on {self.device}...")
        try:
            # Try to load using Ultralytics YOLOWorld
            from ultralytics import YOLOWorld
            self.model = YOLOWorld(self.model_name)
            self.model_type = "yoloworld"
            print("YOLOWorld model loaded successfully!")
        except Exception as e:
            print(f"Error loading YOLOWorld model: {e}")
            print("Falling back to standard YOLOv8 for detection...")
            self.model = YOLO("yolov8n.pt")
            self.model_type = "yolov8"
            print("YOLOv8 fallback model loaded successfully!")

    def change_model(self, model_size):
        # Reload only when the requested size actually differs
        if model_size != self.model_size:
            self.model_size = model_size
            self.model_name = DETECTION_MODELS[model_size]
            self._load_detection_model()
        return f"Using {self.model_name} model"

    def load_seg_model(self, model_name):
        # Load and cache the requested segmentation model on first use
        if model_name not in self.seg_models:
            print(f"Loading segmentation model {model_name}...")
            self.seg_models[model_name] = YOLO(SEGMENTATION_MODELS[model_name])
            print(f"Segmentation model {model_name} loaded successfully!")
        return self.seg_models[model_name]
    def detect(self, image, text_prompt, confidence_threshold=0.3):
        if image is None:
            return None, "No image provided"

        # Normalize the input to a numpy array so we can read its dimensions
        if isinstance(image, str):
            img_for_json = cv2.imread(image)
        elif isinstance(image, np.ndarray):
            img_for_json = image.copy()
        else:
            # Convert PIL Image to numpy array if needed
            img_for_json = np.array(image)

        # Run inference based on model type
        if self.model_type == "yoloworld":
            try:
                # Split the comma-separated prompt into class names
                classes = [cls.strip() for cls in text_prompt.split(",") if cls.strip()] if text_prompt else []
                if classes:
                    # YOLOWorld supports open-vocabulary text prompts
                    self.model.set_classes(classes)
                results = self.model.predict(
                    source=image,
                    conf=confidence_threshold,
                    verbose=False,
                )
            except Exception as e:
                print(f"Error during YOLOWorld inference: {e}")
                print("Falling back to standard YOLO inference...")
                results = self.model.predict(
                    source=image,
                    conf=confidence_threshold,
                    verbose=False,
                )
        else:
            # Standard YOLO does not take text prompts
            results = self.model.predict(
                source=image,
                conf=confidence_threshold,
                verbose=False,
            )

        # Annotated image with boxes drawn
        res_plotted = results[0].plot()

        # Convert results to JSON format (percentage coordinates)
        json_results = []
        img_height, img_width = img_for_json.shape[:2]
        for box, cls, conf in zip(
            results[0].boxes.xyxy.cpu().numpy(),
            results[0].boxes.cls.cpu().numpy(),
            results[0].boxes.conf.cpu().numpy(),
        ):
            x1, y1, x2, y2 = box
            json_results.append({
                "bbox": {
                    "x": float(x1 / img_width * 100),
                    "y": float(y1 / img_height * 100),
                    "width": float((x2 - x1) / img_width * 100),
                    "height": float((y2 - y1) / img_height * 100),
                },
                "score": float(conf),
                "label": int(cls),
                "label_text": results[0].names[int(cls)],
            })
        return res_plotted, json_results
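    # For reference, `detect` returns a list shaped like the sketch below
    # (values are illustrative; coordinates are percentages of the image size):
    # [
    #   {
    #     "bbox": {"x": 12.5, "y": 8.3, "width": 30.1, "height": 45.2},
    #     "score": 0.87,
    #     "label": 0,
    #     "label_text": "person"
    #   }
    # ]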
    def segment(self, image, model_name, confidence_threshold=0.3):
        if image is None:
            return None, "No image provided"

        # Load segmentation model if not already loaded
        model = self.load_seg_model(model_name)

        # Run inference
        results = model(image, conf=confidence_threshold, verbose=False)

        # Annotated image with masks drawn
        res_plotted = results[0].plot()

        # Convert results to JSON format (percentage coordinates)
        json_results = []
        if results[0].masks is not None:
            img_height, img_width = results[0].orig_shape
            for box, polygon, cls, conf in zip(
                results[0].boxes.xyxy.cpu().numpy(),
                results[0].masks.xy,  # mask outlines in original-image pixel coordinates
                results[0].boxes.cls.cpu().numpy(),
                results[0].boxes.conf.cpu().numpy(),
            ):
                if len(polygon) < 3:
                    continue
                x1, y1, x2, y2 = box
                # Simplify the outline into an SVG-friendly polygon.
                # Simplified approach - in production you might want a more
                # sophisticated polygon extraction.
                contour = polygon.reshape(-1, 1, 2).astype(np.float32)
                epsilon = 0.005 * cv2.arcLength(contour, True)
                approx = cv2.approxPolyDP(contour, epsilon, True)
                # Convert to percentage coordinates
                points = [
                    {"x": float(x) / img_width * 100, "y": float(y) / img_height * 100}
                    for x, y in approx.reshape(-1, 2)
                ]
                json_results.append({
                    "bbox": {
                        "x": float(x1 / img_width * 100),
                        "y": float(y1 / img_height * 100),
                        "width": float((x2 - x1) / img_width * 100),
                        "height": float((y2 - y1) / img_height * 100),
                    },
                    "score": float(conf),
                    "label": int(cls),
                    "label_text": results[0].names[int(cls)],
                    "polygon": points,
                })
        return res_plotted, json_results
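# For reference, `segment` reuses the detection schema above and adds a
# "polygon" key per instance, e.g. "polygon": [{"x": 10.0, "y": 20.0}, ...]
# (illustrative values; the simplified mask outline in percentage coordinates).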
# Initialize detector with default model
detector = YOLOWorldDetector(model_size="small")
def create_svg_from_detections(json_results, img_width, img_height):
    """Convert detection results to SVG markup."""
    svg_header = f'<svg width="{img_width}" height="{img_height}" xmlns="http://www.w3.org/2000/svg">'
    svg_content = ""
    # Color palette, cycled per detection
    colors = [
        "#FF3B30", "#FF9500", "#FFCC00", "#4CD964",
        "#5AC8FA", "#007AFF", "#5856D6", "#FF2D55",
    ]
    for i, result in enumerate(json_results):
        bbox = result["bbox"]
        label = result.get("label_text", f"Object {i}")
        score = result.get("score", 0)
        # Convert percentage to absolute coordinates
        x = (bbox["x"] / 100) * img_width
        y = (bbox["y"] / 100) * img_height
        width = (bbox["width"] / 100) * img_width
        height = (bbox["height"] / 100) * img_height
        color = colors[i % len(colors)]
        # Rectangle plus a text label above it
        svg_content += f'''
    <rect
        x="{x:.2f}"
        y="{y:.2f}"
        width="{width:.2f}"
        height="{height:.2f}"
        stroke="{color}"
        stroke-width="2"
        fill="none"
        data-label="{label}"
        data-score="{score:.2f}"
    />
    <text
        x="{x:.2f}"
        y="{y - 5:.2f}"
        font-family="Arial"
        font-size="12"
        fill="{color}"
    >{label} ({score:.2f})</text>'''
    svg_footer = "\n</svg>"
    return svg_header + svg_content + svg_footer
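# For reference, a single detection renders roughly as (illustrative values):
#   <rect x="120.00" y="80.00" width="200.00" height="150.00" stroke="#FF3B30"
#         stroke-width="2" fill="none" data-label="person" data-score="0.87" />
#   <text x="120.00" y="75.00" font-family="Arial" font-size="12" fill="#FF3B30">person (0.87)</text>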
def create_svg_from_segmentation(json_results, img_width, img_height):
    """Convert segmentation results to SVG markup."""
    svg_header = f'<svg width="{img_width}" height="{img_height}" xmlns="http://www.w3.org/2000/svg">'
    svg_content = ""
    # Color palette, cycled per instance
    colors = [
        "#FF3B30", "#FF9500", "#FFCC00", "#4CD964",
        "#5AC8FA", "#007AFF", "#5856D6", "#FF2D55",
    ]
    for i, result in enumerate(json_results):
        label = result.get("label_text", f"Object {i}")
        score = result.get("score", 0)
        color = colors[i % len(colors)]
        # Semi-transparent filled polygon for the mask outline, if available
        if "polygon" in result:
            points_str = " ".join(
                f"{(p['x'] / 100) * img_width:.2f},{(p['y'] / 100) * img_height:.2f}"
                for p in result["polygon"]
            )
            svg_content += f'''
    <polygon
        points="{points_str}"
        stroke="{color}"
        stroke-width="2"
        fill="{color}33"
        data-label="{label}"
        data-score="{score:.2f}"
    />'''
        # Dashed bounding box plus a text label above it
        bbox = result["bbox"]
        x = (bbox["x"] / 100) * img_width
        y = (bbox["y"] / 100) * img_height
        width = (bbox["width"] / 100) * img_width
        height = (bbox["height"] / 100) * img_height
        svg_content += f'''
    <rect
        x="{x:.2f}"
        y="{y:.2f}"
        width="{width:.2f}"
        height="{height:.2f}"
        stroke="{color}"
        stroke-width="1"
        fill="none"
        stroke-dasharray="5,5"
    />
    <text
        x="{x:.2f}"
        y="{y - 5:.2f}"
        font-family="Arial"
        font-size="12"
        fill="{color}"
    >{label} ({score:.2f})</text>'''
    svg_footer = "\n</svg>"
    return svg_header + svg_content + svg_footer
def detection_inference(image, text_prompt, confidence, model_size):
    # Switch detection models if the dropdown selection changed
    detector.change_model(model_size)
    # Run detection
    result_image, json_results = detector.detect(
        image,
        text_prompt,
        confidence_threshold=confidence,
    )
    # Create SVG overlay from detection results
    if isinstance(json_results, list) and json_results:
        img_height, img_width = result_image.shape[:2]
        svg_output = create_svg_from_detections(json_results, img_width, img_height)
    else:
        svg_output = "<svg></svg>"
    return result_image, json.dumps(json_results, indent=2), svg_output
def segmentation_inference(image, confidence, model_name):
    # Run segmentation
    result_image, json_results = detector.segment(
        image,
        model_name,
        confidence_threshold=confidence,
    )
    # Create SVG overlay from segmentation results
    if isinstance(json_results, list) and json_results:
        img_height, img_width = result_image.shape[:2]
        svg_output = create_svg_from_segmentation(json_results, img_width, img_height)
    else:
        svg_output = "<svg></svg>"
    return result_image, json.dumps(json_results, indent=2), svg_output
# Create Gradio interface
with gr.Blocks(title="YOLO Vision Suite", css=custom_css) as demo:
    with gr.Column(elem_classes="main-container"):
        with gr.Column(elem_classes="header"):
            gr.Markdown("# YOLO Vision Suite")
            gr.Markdown("Advanced object detection and segmentation powered by YOLO models")

        with gr.Tabs(elem_classes="tab-nav") as tabs:
            with gr.TabItem("Object Detection", elem_id="detection-tab"):
                with gr.Row(equal_height=True):
                    with gr.Column(elem_classes="input-panel", scale=1):
                        gr.Markdown("### Input")
                        input_image = gr.Image(label="Upload Image", type="numpy", height=300)
                        text_prompt = gr.Textbox(
                            label="Text Prompt",
                            placeholder="person, car, dog",
                            value="person, car, dog",
                            elem_classes="gr-input",
                        )
                        with gr.Row():
                            confidence = gr.Slider(
                                minimum=0.1,
                                maximum=1.0,
                                value=0.3,
                                step=0.05,
                                label="Confidence Threshold",
                            )
                            model_dropdown = gr.Dropdown(
                                choices=list(DETECTION_MODELS.keys()),
                                value="small",
                                label="Model Size",
                                elem_classes="gr-select",
                            )
                        detect_button = gr.Button("Detect Objects", elem_classes="gr-button-primary")
                    with gr.Column(elem_classes="output-panel", scale=1):
                        gr.Markdown("### Results")
                        output_image = gr.Image(label="Detection Result", height=300)
                        with gr.Accordion("SVG Output", open=False, elem_classes="gr-accordion"):
                            svg_output = gr.HTML(label="SVG Visualization")
                        with gr.Accordion("JSON Output", open=False, elem_classes="gr-accordion"):
                            json_output = gr.Textbox(
                                label="Bounding Box Data (Percentage Coordinates)",
                                elem_classes="gr-input",
                                lines=5,
                            )
            with gr.TabItem("Segmentation", elem_id="segmentation-tab"):
                with gr.Row(equal_height=True):
                    with gr.Column(elem_classes="input-panel", scale=1):
                        gr.Markdown("### Input")
                        seg_input_image = gr.Image(label="Upload Image", type="numpy", height=300)
                        with gr.Row():
                            seg_confidence = gr.Slider(
                                minimum=0.1,
                                maximum=1.0,
                                value=0.3,
                                step=0.05,
                                label="Confidence Threshold",
                            )
                            seg_model_dropdown = gr.Dropdown(
                                choices=list(SEGMENTATION_MODELS.keys()),
                                value="YOLOv8 Small",
                                label="Model Size",
                                elem_classes="gr-select",
                            )
                        segment_button = gr.Button("Segment Image", elem_classes="gr-button-primary")
                    with gr.Column(elem_classes="output-panel", scale=1):
                        gr.Markdown("### Results")
                        seg_output_image = gr.Image(label="Segmentation Result", height=300)
                        with gr.Accordion("SVG Output", open=False, elem_classes="gr-accordion"):
                            seg_svg_output = gr.HTML(label="SVG Visualization")
                        with gr.Accordion("JSON Output", open=False, elem_classes="gr-accordion"):
                            seg_json_output = gr.Textbox(
                                label="Segmentation Data (Percentage Coordinates)",
                                elem_classes="gr-input",
                                lines=5,
                            )

        with gr.Column(elem_classes="footer"):
            with gr.Column(elem_classes="footer-card"):
                gr.Markdown("### Tips & Information")
                with gr.Row(elem_classes="tips-grid"):
                    with gr.Column(elem_classes="tip-card"):
                        gr.Markdown("**Detection**")
                        gr.Markdown("Enter comma-separated text prompts to specify what objects to detect")
                    with gr.Column(elem_classes="tip-card"):
                        gr.Markdown("**Segmentation**")
                        gr.Markdown("The model will identify and segment common objects automatically")
                    with gr.Column(elem_classes="tip-card"):
                        gr.Markdown("**Models**")
                        gr.Markdown("Larger models provide better accuracy but require more processing power")
                    with gr.Column(elem_classes="tip-card"):
                        gr.Markdown("**Output**")
                        gr.Markdown("JSON output provides coordinates as percentages, compatible with SVG")
    # Set up event handlers (must be registered inside the Blocks context)
    detect_button.click(
        detection_inference,
        inputs=[input_image, text_prompt, confidence, model_dropdown],
        outputs=[output_image, json_output, svg_output],
    )
    segment_button.click(
        segmentation_inference,
        inputs=[seg_input_image, seg_confidence, seg_model_dropdown],
        outputs=[seg_output_image, seg_json_output, seg_svg_output],
    )
if __name__ == "__main__":
    demo.launch(share=True)  # Set share=True to create a public link
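# Minimal standalone usage sketch (no Gradio UI), assuming the checkpoints above
# can be resolved by ultralytics; "photo.jpg" is a placeholder path:
#
#   det = YOLOWorldDetector(model_size="small")
#   annotated, boxes = det.detect("photo.jpg", "person, bicycle", confidence_threshold=0.4)
#   overlay, instances = det.segment("photo.jpg", "YOLOv8 Small", confidence_threshold=0.4)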