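"""YOLO Vision Suite.

Gradio app combining open-vocabulary object detection (a YOLO-World model loaded
via Hugging Face AutoModel/AutoProcessor) with instance segmentation (Ultralytics
YOLOv8 "-seg" checkpoints). Each task returns an annotated image plus JSON results
expressed in percentage coordinates.
"""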
import json
import os

import cv2
import torch
import numpy as np
import gradio as gr
from PIL import Image
from transformers import AutoModel, AutoProcessor
from ultralytics import YOLO
# Custom CSS for shadcn/Radix UI inspired look
custom_css = """
:root {
    --primary: #0f172a;
    --primary-foreground: #f8fafc;
    --background: #f8fafc;
    --card: #ffffff;
    --card-foreground: #0f172a;
    --border: #e2e8f0;
    --ring: #94a3b8;
    --radius: 0.5rem;
}
.dark {
    --primary: #f8fafc;
    --primary-foreground: #0f172a;
    --background: #0f172a;
    --card: #1e293b;
    --card-foreground: #f8fafc;
    --border: #334155;
    --ring: #94a3b8;
}
.gradio-container {
    margin: 0 !important;
    padding: 0 !important;
    max-width: 100% !important;
}
.main-container {
    background-color: var(--background);
    border-radius: var(--radius);
    padding: 1.5rem;
}
.header {
    margin-bottom: 1.5rem;
    border-bottom: 1px solid var(--border);
    padding-bottom: 1rem;
}
.header h1 {
    font-size: 1.875rem;
    font-weight: 700;
    color: var(--primary);
    margin-bottom: 0.5rem;
}
.header p {
    color: var(--card-foreground);
    opacity: 0.8;
}
.tab-nav {
    background-color: var(--card);
    border: 1px solid var(--border);
    border-radius: var(--radius);
    padding: 0.25rem;
    margin-bottom: 1.5rem;
}
.tab-nav button {
    border-radius: calc(var(--radius) - 0.25rem) !important;
    font-weight: 500 !important;
    transition: all 0.2s ease-in-out !important;
}
.tab-nav button.selected {
    background-color: var(--primary) !important;
    color: var(--primary-foreground) !important;
}
.input-panel, .output-panel {
    background-color: var(--card);
    border: 1px solid var(--border);
    border-radius: var(--radius);
    padding: 1.5rem;
    box-shadow: 0 1px 3px rgba(0, 0, 0, 0.05);
}
.gr-button-primary {
    background-color: var(--primary) !important;
    color: var(--primary-foreground) !important;
    border-radius: var(--radius) !important;
    font-weight: 500 !important;
    transition: all 0.2s ease-in-out !important;
}
.gr-button-primary:hover {
    opacity: 0.9 !important;
}
.gr-form {
    border: none !important;
    background: transparent !important;
}
.gr-input, .gr-select {
    border: 1px solid var(--border) !important;
    border-radius: var(--radius) !important;
    padding: 0.5rem 0.75rem !important;
}
.gr-panel {
    border: none !important;
}
.footer {
    margin-top: 1.5rem;
    border-top: 1px solid var(--border);
    padding-top: 1rem;
    font-size: 0.875rem;
    color: var(--card-foreground);
    opacity: 0.7;
}
"""
# Available model sizes
DETECTION_MODELS = {
    "tiny": "yoloworld-t",
    "small": "yoloworld-s",
    "base": "yoloworld-b",
    "large": "yoloworld-l",
}

SEGMENTATION_MODELS = {
    "YOLOv8 Nano": "yolov8n-seg.pt",
    "YOLOv8 Small": "yolov8s-seg.pt",
    "YOLOv8 Medium": "yolov8m-seg.pt",
    "YOLOv8 Large": "yolov8l-seg.pt",
}
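

# Both detect() and segment() return (annotated_image, json_results), where
# json_results is a list of dicts whose coordinates are percentages of the
# image dimensions, e.g. (illustrative values):
#   {"bbox": {"x": 12.5, "y": 30.0, "width": 24.0, "height": 18.0},
#    "score": 0.87, "label": 0, "label_text": "person"}
# Segmentation entries additionally carry a "polygon" list of {x, y} points.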
class YOLOWorldDetector:
    def __init__(self, model_size="base"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_size = model_size
        self.model_name = DETECTION_MODELS[model_size]

        print(f"Loading {self.model_name} on {self.device}...")
        self.model = AutoModel.from_pretrained(
            f"deepdatacloud/{self.model_name}", trust_remote_code=True
        )
        self.model.to(self.device)
        self.processor = AutoProcessor.from_pretrained(f"deepdatacloud/{self.model_name}")
        print("Model loaded successfully!")

        # Segmentation models
        self.seg_models = {}

    def change_model(self, model_size):
        if model_size != self.model_size:
            self.model_size = model_size
            self.model_name = DETECTION_MODELS[model_size]

            print(f"Loading {self.model_name} on {self.device}...")
            self.model = AutoModel.from_pretrained(
                f"deepdatacloud/{self.model_name}", trust_remote_code=True
            )
            self.model.to(self.device)
            self.processor = AutoProcessor.from_pretrained(f"deepdatacloud/{self.model_name}")
            print("Model loaded successfully!")
        return f"Using {self.model_name} model"

    def load_seg_model(self, model_name):
        if model_name not in self.seg_models:
            print(f"Loading segmentation model {model_name}...")
            self.seg_models[model_name] = YOLO(SEGMENTATION_MODELS[model_name])
            print(f"Segmentation model {model_name} loaded successfully!")
        return self.seg_models[model_name]
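
    # Note: detection labels are treated as indices into the comma-separated
    # text prompt, so "person, car, dog" maps label 0 -> person, 1 -> car,
    # 2 -> dog. Labels outside that range fall back to the generic "Object".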
    def detect(self, image, text_prompt, confidence_threshold=0.3):
        if image is None:
            return None, "No image provided"

        if isinstance(image, str):
            image = Image.open(image).convert("RGB")
        elif isinstance(image, np.ndarray):
            # Gradio's Image component (type="numpy") already delivers RGB arrays
            image = Image.fromarray(image)

        # Class names come from the comma-separated prompt
        prompts = [p.strip() for p in text_prompt.split(",")]

        # Process inputs
        inputs = self.processor(text=text_prompt, images=image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Run inference
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Post-process into boxes/scores/labels in pixel coordinates
        target_sizes = torch.tensor([image.size[::-1]], device=self.device)
        results = self.processor.post_process_object_detection(
            outputs=outputs,
            target_sizes=target_sizes,
            threshold=confidence_threshold,
        )[0]

        # Convert image to numpy (RGB) for drawing
        image_np = np.array(image)
        img_height, img_width = image_np.shape[:2]

        json_results = []
        for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
            box = box.cpu().numpy()
            score = float(score.cpu().item())
            label = int(label.cpu().item())
            label_text = prompts[label] if label < len(prompts) else "Object"
            caption = f"{label_text}: {score:.2f}"

            # Integer pixel coordinates for drawing, floats for the JSON output
            x1, y1, x2, y2 = box.astype(int).tolist()
            bx1, by1, bx2, by2 = box.tolist()

            # Draw bounding box
            cv2.rectangle(image_np, (x1, y1), (x2, y2), (0, 255, 0), 2)

            # Draw label background
            text_size = cv2.getTextSize(caption, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
            cv2.rectangle(
                image_np,
                (x1, y1 - text_size[1] - 5),
                (x1 + text_size[0], y1),
                (0, 255, 0),
                -1,
            )

            # Draw text
            cv2.putText(
                image_np,
                caption,
                (x1, y1 - 5),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                (0, 0, 0),
                2,
            )

            # Collect results as percentages of the image dimensions
            json_results.append({
                "bbox": {
                    "x": (bx1 / img_width) * 100,
                    "y": (by1 / img_height) * 100,
                    "width": ((bx2 - bx1) / img_width) * 100,
                    "height": ((by2 - by1) / img_height) * 100,
                },
                "score": score,
                "label": label,
                "label_text": label_text,
            })

        # Return RGB; Gradio displays numpy images as RGB
        return image_np, json_results
    def segment(self, image, model_name, confidence_threshold=0.3):
        if image is None:
            return None, "No image provided"

        # Load segmentation model if not already loaded
        model = self.load_seg_model(model_name)

        # Run inference
        results = model(image, conf=confidence_threshold)

        # The annotated image from Ultralytics is BGR; convert to RGB for Gradio
        res_plotted = cv2.cvtColor(results[0].plot(), cv2.COLOR_BGR2RGB)

        # Convert results to JSON format (percentage coordinates)
        json_results = []
        if hasattr(results[0], "masks") and results[0].masks is not None:
            img_height, img_width = results[0].orig_shape
            for box, mask, cls, conf in zip(
                results[0].boxes.xyxy.cpu().numpy(),
                results[0].masks.data.cpu().numpy(),
                results[0].boxes.cls.cpu().numpy(),
                results[0].boxes.conf.cpu().numpy(),
            ):
                x1, y1, x2, y2 = box.tolist()

                # Convert mask to polygon for SVG-like representation.
                # Simplified approach - in production you might want a more
                # sophisticated polygon extraction.
                contours, _ = cv2.findContours(
                    (mask > 0.5).astype(np.uint8),
                    cv2.RETR_EXTERNAL,
                    cv2.CHAIN_APPROX_SIMPLE,
                )
                if contours:
                    # Get the largest contour and simplify it
                    largest_contour = max(contours, key=cv2.contourArea)
                    epsilon = 0.005 * cv2.arcLength(largest_contour, True)
                    approx = cv2.approxPolyDP(largest_contour, epsilon, True)

                    # Masks are returned at the model's inference resolution, which
                    # may differ from the original image, so normalize polygon points
                    # by the mask's own size
                    mask_height, mask_width = mask.shape
                    points = []
                    for point in approx:
                        x, y = point[0]
                        points.append({
                            "x": float(x / mask_width * 100),
                            "y": float(y / mask_height * 100),
                        })

                    json_results.append({
                        "bbox": {
                            "x": (x1 / img_width) * 100,
                            "y": (y1 / img_height) * 100,
                            "width": ((x2 - x1) / img_width) * 100,
                            "height": ((y2 - y1) / img_height) * 100,
                        },
                        "score": float(conf),
                        "label": int(cls),
                        "label_text": results[0].names[int(cls)],
                        "polygon": points,
                    })

        return res_plotted, json_results
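

# The YOLO-World detector (the "base" variant by default) is created once at
# startup and shared across requests; YOLOv8 segmentation weights are loaded
# and cached lazily the first time each variant is requested.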
# Initialize detector with default model
detector = YOLOWorldDetector(model_size="base")


def detection_inference(image, text_prompt, confidence, model_size):
    # Update model if needed
    detector.change_model(model_size)

    # Run detection
    result_image, json_results = detector.detect(
        image,
        text_prompt,
        confidence_threshold=confidence,
    )
    return result_image, json.dumps(json_results, indent=2)


def segmentation_inference(image, confidence, model_name):
    # Run segmentation
    result_image, json_results = detector.segment(
        image,
        model_name,
        confidence_threshold=confidence,
    )
    return result_image, json.dumps(json_results, indent=2)
# Create Gradio interface
with gr.Blocks(title="YOLO Vision Suite", css=custom_css) as demo:
    with gr.Column(elem_classes="main-container"):
        with gr.Column(elem_classes="header"):
            gr.Markdown("# YOLO Vision Suite")
            gr.Markdown("Advanced object detection and segmentation powered by YOLO models")

        with gr.Tabs(elem_classes="tab-nav") as tabs:
            with gr.TabItem("Object Detection", elem_id="detection-tab"):
                with gr.Row():
                    with gr.Column(elem_classes="input-panel"):
                        gr.Markdown("### Input")
                        input_image = gr.Image(label="Upload Image", type="numpy")
                        text_prompt = gr.Textbox(
                            label="Text Prompt",
                            placeholder="person, car, dog",
                            value="person, car, dog",
                            elem_classes="gr-input",
                        )
                        with gr.Row():
                            confidence = gr.Slider(
                                minimum=0.1,
                                maximum=1.0,
                                value=0.3,
                                step=0.05,
                                label="Confidence Threshold",
                            )
                            model_dropdown = gr.Dropdown(
                                choices=list(DETECTION_MODELS.keys()),
                                value="base",
                                label="Model Size",
                                elem_classes="gr-select",
                            )
                        detect_button = gr.Button("Detect Objects", elem_classes="gr-button-primary")

                    with gr.Column(elem_classes="output-panel"):
                        gr.Markdown("### Results")
                        output_image = gr.Image(label="Detection Result")
                        with gr.Accordion("JSON Output", open=False):
                            json_output = gr.Textbox(
                                label="Bounding Box Data (Percentage Coordinates)",
                                elem_classes="gr-input",
                            )

            with gr.TabItem("Segmentation", elem_id="segmentation-tab"):
                with gr.Row():
                    with gr.Column(elem_classes="input-panel"):
                        gr.Markdown("### Input")
                        seg_input_image = gr.Image(label="Upload Image", type="numpy")
                        with gr.Row():
                            seg_confidence = gr.Slider(
                                minimum=0.1,
                                maximum=1.0,
                                value=0.3,
                                step=0.05,
                                label="Confidence Threshold",
                            )
                            seg_model_dropdown = gr.Dropdown(
                                choices=list(SEGMENTATION_MODELS.keys()),
                                value="YOLOv8 Small",
                                label="Model Size",
                                elem_classes="gr-select",
                            )
                        segment_button = gr.Button("Segment Image", elem_classes="gr-button-primary")

                    with gr.Column(elem_classes="output-panel"):
                        gr.Markdown("### Results")
                        seg_output_image = gr.Image(label="Segmentation Result")
                        with gr.Accordion("JSON Output", open=False):
                            seg_json_output = gr.Textbox(
                                label="Segmentation Data (Percentage Coordinates)",
                                elem_classes="gr-input",
                            )

        with gr.Column(elem_classes="footer"):
            gr.Markdown("""
            ### Tips
            - For object detection, enter comma-separated text prompts to specify what to detect
            - For segmentation, the model will identify common objects automatically
            - Larger models provide better accuracy but require more processing power
            - The JSON output provides coordinates as percentages of image dimensions, compatible with SVG
            """)
    # Set up event handlers
    detect_button.click(
        detection_inference,
        inputs=[input_image, text_prompt, confidence, model_dropdown],
        outputs=[output_image, json_output],
    )

    segment_button.click(
        segmentation_inference,
        inputs=[seg_input_image, seg_confidence, seg_model_dropdown],
        outputs=[seg_output_image, seg_json_output],
    )
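
# The percentage coordinates in the JSON output map directly onto an SVG overlay
# that uses viewBox="0 0 100 100", e.g. (illustrative sketch only, not used by the app):
#   <svg viewBox="0 0 100 100" preserveAspectRatio="none">
#     <rect x="12.5" y="30.0" width="24.0" height="18.0" fill="none" stroke="lime"/>
#   </svg>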
if __name__ == "__main__":
    demo.launch()