Spaces:
Running
Running
import os | |
import cv2 | |
import gradio as gr | |
import numpy as np | |
import supervision as sv | |
from pathlib import Path | |
from dds_cloudapi_sdk import Config, Client, TextPrompt | |
from dds_cloudapi_sdk.tasks.dinox import DinoxTask | |
from dds_cloudapi_sdk.tasks.detection import DetectionTask | |
from dds_cloudapi_sdk.tasks.types import DetectionTarget | |
# Constants
# NOTE(security): an API token was committed here in plain text. Supply it via
# the DDS_API_TOKEN environment variable (e.g. a Space secret) and rotate the
# committed value; the literal is kept only as a fallback so existing
# deployments keep working until then.
API_TOKEN = os.environ.get("DDS_API_TOKEN") or "361d32fa5ce22649133660c65cfcaf22"
# Dot-separated class lists sent to DINO-X as text prompts.
TEXT_PROMPT = "wheel . eye . helmet . mouse . mouth . vehicle . steering wheel . ear . nose"
VID_PROMPT = "wheel . mouse . pot . aquarium . box"
TEMP_DIR = "./temp"  # scratch space (e.g. per-frame uploads in process_video)
OUTPUT_DIR = "./outputs"  # annotated results are written here

# Ensure directories exist
os.makedirs(TEMP_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)
def initialize_dino_client():
    """Return a new DINO-X API ``Client`` authenticated with ``API_TOKEN``."""
    return Client(Config(API_TOKEN))
def get_class_mappings(text_prompt):
    """Parse a dot-separated prompt into class names and a name -> id map.

    Each segment is stripped and lower-cased. Empty segments and repeated
    names are dropped, so ids are contiguous from 0 in first-seen order.

    Example:
        "Wheel . eye . " -> (["wheel", "eye"], {"wheel": 0, "eye": 1})

    Args:
        text_prompt: Prompt string such as "wheel . eye . helmet".

    Returns:
        (classes, class_name_to_id): the ordered list of unique class names
        and a dict mapping each name to its index in that list.
    """
    # Strip *before* filtering; otherwise a whitespace-only segment (e.g. the
    # tail of "a . b . ") survives the filter and becomes an empty class name.
    names = (segment.strip().lower() for segment in text_prompt.split('.'))
    # dict.fromkeys de-duplicates while keeping first-occurrence order.
    classes = list(dict.fromkeys(name for name in names if name))
    class_name_to_id = {name: idx for idx, name in enumerate(classes)}
    return classes, class_name_to_id
def process_predictions(predictions, class_name_to_id):
    """Convert DINO-X result objects into arrays suitable for ``sv.Detections``.

    Args:
        predictions: Iterable of DINO-X result objects exposing ``bbox``,
            ``category`` and ``score``, and optionally an RLE ``mask``
            (with ``counts`` and ``size``). Each ``bbox`` is assumed to be
            four numbers in xyxy order -- TODO confirm against the SDK.
        class_name_to_id: Mapping from lower-cased class name to integer id,
            as built by ``get_class_mappings``. It is not modified.

    Returns:
        dict with keys:
          'boxes': float array of shape (N, 4); shape (0, 4) when empty.
          'masks': stacked decoded masks, or None unless every object had one.
          'class_ids': int array of shape (N,).
          'class_names': list of N stripped, lower-cased category names.
          'confidences': list of N scores.
    """
    boxes = []
    masks = []
    confidences = []
    class_names = []
    class_ids = []
    # Work on a copy so categories missing from the prompt mapping can get
    # fresh ids (above every existing id) instead of raising KeyError.
    name_to_id = dict(class_name_to_id)
    next_id = max(name_to_id.values(), default=-1) + 1
    for obj in predictions:
        boxes.append(obj.bbox)
        mask = getattr(obj, 'mask', None)
        if mask:
            masks.append(DetectionTask.rle2mask(
                DetectionTask.string2rle(mask.counts),
                mask.size
            ))
        cls_name = obj.category.lower().strip()
        if cls_name not in name_to_id:
            name_to_id[cls_name] = next_id
            next_id += 1
        class_names.append(cls_name)
        class_ids.append(name_to_id[cls_name])
        confidences.append(obj.score)
    return {
        # reshape keeps the (N, 4) shape sv.Detections requires even when N == 0.
        'boxes': np.array(boxes, dtype=float).reshape(-1, 4),
        # Masks must pair one-to-one with boxes, so a partial set is unusable.
        'masks': np.array(masks) if masks and len(masks) == len(boxes) else None,
        'class_ids': np.array(class_ids, dtype=int),
        'class_names': class_names,
        'confidences': confidences
    }
def process_image(image_path, prompt=TEXT_PROMPT):
    """Run DINO-X box + mask detection on one image and save an annotated copy.

    Args:
        image_path: Path to an image file readable by ``cv2.imread``.
        prompt: Dot-separated class prompt, e.g. "wheel . eye".

    Returns:
        The path of the annotated image (``OUTPUT_DIR/result.jpg``) on
        success, or a string starting with "Error processing image:" if any
        step raises.
    """
    try:
        # Decode locally first so an unreadable file fails fast with a clear
        # message, before spending an upload and an API call on it.
        img = cv2.imread(image_path)
        if img is None:
            raise ValueError(f"could not read image file {image_path!r}")
        client = initialize_dino_client()
        _, class_name_to_id = get_class_mappings(prompt)
        # Upload and process image
        image_url = client.upload_file(image_path)
        task = DinoxTask(
            image_url=image_url,
            prompts=[TextPrompt(text=prompt)],
            bbox_threshold=0.25,
            targets=[DetectionTarget.BBox, DetectionTarget.Mask]
        )
        client.run_task(task)
        results = process_predictions(task.result.objects, class_name_to_id)
        masks = results['masks']
        # sv.Detections needs exactly one mask per box; fall back to boxes only.
        if masks is not None and len(masks) != len(results['boxes']):
            masks = None
        detections = sv.Detections(
            xyxy=results['boxes'],
            mask=masks.astype(bool) if masks is not None else None,
            class_id=results['class_ids']
        )
        labels = [
            f"{name} {conf:.2f}"
            for name, conf in zip(results['class_names'], results['confidences'])
        ]
        annotated_frame = img.copy()
        # Paint masks first so the boxes and labels drawn afterwards stay on top.
        if masks is not None:
            annotated_frame = sv.MaskAnnotator().annotate(
                scene=annotated_frame,
                detections=detections
            )
        annotated_frame = sv.BoxAnnotator().annotate(
            scene=annotated_frame,
            detections=detections
        )
        annotated_frame = sv.LabelAnnotator().annotate(
            scene=annotated_frame,
            detections=detections,
            labels=labels
        )
        output_path = os.path.join(OUTPUT_DIR, "result.jpg")
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(output_path, annotated_frame):
            raise OSError(f"could not write {output_path!r}")
        return output_path
    except Exception as e:
        return f"Error processing image: {str(e)}"
def process_video(video_path, prompt=VID_PROMPT):
    """Run DINO-X box detection on a video and write an annotated copy.

    Only every third frame is uploaded to the API, to limit latency and API
    usage. The frames in between are annotated with the most recent
    detections, and frames before the first query are written unannotated.
    As a result the output keeps the input's frame count and duration.

    Args:
        video_path: Path to a video file readable by ``cv2.VideoCapture``.
        prompt: Dot-separated class prompt, e.g. "wheel . mouse".

    Returns:
        The path of the annotated video (``OUTPUT_DIR/result.mp4``) on
        success, or a string starting with "Error processing video:" if any
        step raises.
    """
    cap = None
    out = None
    temp_frame_path = os.path.join(TEMP_DIR, "temp_frame.jpg")
    try:
        client = initialize_dino_client()
        _, class_name_to_id = get_class_mappings(prompt)
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError(f"could not open video {video_path!r}")
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep fractional rates such as 29.97 (int() would drift the timing).
        # Use 30 fps when the container reports 0.
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        output_path = os.path.join(OUTPUT_DIR, "result.mp4")
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        box_annotator = sv.BoxAnnotator()
        label_annotator = sv.LabelAnnotator()
        detections = None  # most recent API result, reused on skipped frames
        labels = []
        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1
            if frame_count % 3 == 0:  # query the API on every 3rd frame only
                cv2.imwrite(temp_frame_path, frame)
                image_url = client.upload_file(temp_frame_path)
                task = DinoxTask(
                    image_url=image_url,
                    prompts=[TextPrompt(text=prompt)],
                    bbox_threshold=0.25
                )
                client.run_task(task)
                results = process_predictions(task.result.objects, class_name_to_id)
                detections = sv.Detections(
                    xyxy=results['boxes'],
                    class_id=results['class_ids']
                )
                labels = [
                    f"{name} {conf:.2f}"
                    for name, conf in zip(results['class_names'], results['confidences'])
                ]
            if detections is None:
                # No detections yet; still write the frame so no time is lost.
                out.write(frame)
                continue
            annotated_frame = box_annotator.annotate(scene=frame.copy(), detections=detections)
            annotated_frame = label_annotator.annotate(
                scene=annotated_frame,
                detections=detections,
                labels=labels
            )
            out.write(annotated_frame)
        return output_path
    except Exception as e:
        return f"Error processing video: {str(e)}"
    finally:
        # Release handles and remove the scratch frame on success *and* failure.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
        if os.path.exists(temp_frame_path):
            os.remove(temp_frame_path)
def process_input(input_file, prompt=TEXT_PROMPT):
    """Dispatch an uploaded file to process_image or process_video by extension.

    Args:
        input_file: Value produced by ``gr.File``. Depending on the Gradio
            version this is a file wrapper with a ``.name`` path or a plain
            path string/``os.PathLike``. None when nothing was uploaded.
        prompt: Dot-separated class prompt forwarded to the processor.

    Returns:
        The processor's result (an output path or its error message), or a
        user-facing message for a missing or unsupported file.
    """
    if input_file is None:
        return "Please provide an input file"
    # Handle path-like values first: a pathlib.Path also has a ``.name``, but
    # that is only the basename, not the full path.
    if isinstance(input_file, (str, os.PathLike)):
        file_path = os.fspath(input_file)
    else:
        file_path = input_file.name
    extension = os.path.splitext(file_path)[1].lower()
    if extension in {'.jpg', '.jpeg', '.png'}:
        return process_image(file_path, prompt)
    if extension in {'.mp4', '.avi', '.mov'}:
        return process_video(file_path, prompt)
    return "Unsupported file format. Please use jpg/jpeg/png for images or mp4/avi/mov for videos."
# Create Gradio interface
def _run_detection(input_file, prompt):
    """Adapt process_input to the interface's (image, video) output pair.

    process_input returns either the path of an annotated result file or a
    user-facing error message. Messages are raised as ``gr.Error`` so the UI
    shows them instead of trying to load them as files. Video results go to
    the video component, because ``gr.Image`` cannot display them.
    """
    # An empty prompt gives the detector nothing to look for; use the default.
    prompt = (prompt or "").strip() or TEXT_PROMPT
    result = process_input(input_file, prompt)
    if not (isinstance(result, str) and os.path.isfile(result)):
        raise gr.Error(str(result))
    if os.path.splitext(result)[1].lower() in {".mp4", ".avi", ".mov"}:
        return None, result
    return result, None


demo = gr.Interface(
    fn=_run_detection,
    inputs=[
        gr.File(
            label="Upload Image/Video",
            file_types=["image", "video"]
        ),
        gr.Textbox(
            label="Detection Prompt",
            value=TEXT_PROMPT,
            lines=2
        )
    ],
    outputs=[
        gr.Image(label="Detection Result"),
        gr.Video(label="Detection Result (Video)"),
    ],
    title="DINO-X Object Detection",
    description="Upload an image or video to detect objects using DINO-X. You can modify the detection prompt to specify what objects to look for.",
    examples=[
        ["assets/demo.png", TEXT_PROMPT],
        ["assets/demo.mp4", VID_PROMPT]
    ],
    cache_examples=True
)

if __name__ == "__main__":
    demo.launch()