# Dino-X-API-Demo / app.py
# (Hugging Face Space page residue — "minar09's picture", "Update app.py",
# commit cc681d9 verified — kept as comments so the module parses.)
import os
import cv2
import gradio as gr
import numpy as np
import supervision as sv
from pathlib import Path
from dds_cloudapi_sdk import Config, Client, TextPrompt
from dds_cloudapi_sdk.tasks.dinox import DinoxTask
from dds_cloudapi_sdk.tasks.detection import DetectionTask
from dds_cloudapi_sdk.tasks.types import DetectionTarget
# Constants
# SECURITY(review): the API token was previously hard-coded here and is now
# exposed in the repository history — it should be rotated. Prefer supplying
# it via the DDS_API_TOKEN environment variable; the literal remains only as
# a backward-compatible fallback.
API_TOKEN = os.getenv("DDS_API_TOKEN", "361d32fa5ce22649133660c65cfcaf22")
# Period-separated class prompt used for image inputs.
TEXT_PROMPT = "wheel . eye . helmet . mouse . mouth . vehicle . steering wheel . ear . nose"
# Period-separated class prompt used for video inputs ("acquariam" is kept
# verbatim — changing the string would change what the API is asked to find).
VID_PROMPT = "wheel . mouse . pot . acquariam . box"
TEMP_DIR = "./temp"
OUTPUT_DIR = "./outputs"
# Ensure working directories exist before any processing runs.
os.makedirs(TEMP_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)
def initialize_dino_client():
    """Build an authenticated DDS Cloud API client for DINO-X requests."""
    return Client(Config(API_TOKEN))
def get_class_mappings(text_prompt):
    """Parse a period-separated prompt into class names and an id mapping.

    Args:
        text_prompt: classes separated by '.', e.g. "wheel . eye . helmet".

    Returns:
        A tuple ``(classes, class_name_to_id)`` where ``classes`` is the
        ordered list of lowercase class names and ``class_name_to_id`` maps
        each name to its index in that list.
    """
    # Strip BEFORE filtering: the previous check (`if x`) tested the raw
    # segment, so a prompt like "a . . b" produced an empty-string class
    # (the " " segment between the dots is truthy).
    classes = [seg.strip().lower() for seg in text_prompt.split('.') if seg.strip()]
    class_name_to_id = {name: idx for idx, name in enumerate(classes)}
    return classes, class_name_to_id
def process_predictions(predictions, class_name_to_id):
    """Process DINO-X predictions into detection format.

    Converts the SDK's per-object results into parallel arrays suitable for
    building an ``sv.Detections`` instance.

    Args:
        predictions: iterable of DINO-X result objects; each is expected to
            expose ``.bbox``, ``.category``, ``.score`` and optionally
            ``.mask`` (with an RLE ``.counts`` string and ``.size``) —
            assumed from usage here; confirm against the SDK docs.
        class_name_to_id: mapping from lowercase class name to integer id,
            as produced by ``get_class_mappings``.

    Returns:
        dict with keys 'boxes' (np.ndarray), 'masks' (np.ndarray or None),
        'class_ids' (np.ndarray), 'class_names' (list) and 'confidences'
        (list).

    Raises:
        KeyError: if an object's category is not present in
            ``class_name_to_id`` (i.e. the API returned a class outside the
            prompt).
    """
    boxes = []
    masks = []
    confidences = []
    class_names = []
    class_ids = []
    for obj in predictions:
        boxes.append(obj.bbox)
        # Decode the run-length-encoded mask into a dense array when present.
        if hasattr(obj, 'mask') and obj.mask:
            masks.append(DetectionTask.rle2mask(
                DetectionTask.string2rle(obj.mask.counts),
                obj.mask.size
            ))
        cls_name = obj.category.lower().strip()
        class_names.append(cls_name)
        class_ids.append(class_name_to_id[cls_name])
        confidences.append(obj.score)
    # NOTE(review): masks are appended only for objects that carry one, while
    # boxes are appended for every object — if only SOME objects have masks,
    # the 'masks' array is misaligned with 'boxes'. Confirm the API returns
    # masks for all objects or none per request.
    return {
        'boxes': np.array(boxes),
        'masks': np.array(masks) if masks else None,
        'class_ids': np.array(class_ids),
        'class_names': class_names,
        'confidences': confidences
    }
def process_image(image_path, prompt=TEXT_PROMPT):
    """Run DINO-X detection on one image and save an annotated copy.

    Args:
        image_path: path to a local image file.
        prompt: period-separated class prompt sent to the API.

    Returns:
        Path of the annotated output image, or an error-message string if
        anything goes wrong.
    """
    try:
        client = initialize_dino_client()
        _, class_name_to_id = get_class_mappings(prompt)

        # Upload the image and request boxes plus segmentation masks.
        uploaded_url = client.upload_file(image_path)
        task = DinoxTask(
            image_url=uploaded_url,
            prompts=[TextPrompt(text=prompt)],
            bbox_threshold=0.25,
            targets=[DetectionTarget.BBox, DetectionTarget.Mask]
        )
        client.run_task(task)

        results = process_predictions(task.result.objects, class_name_to_id)

        # Build supervision detections from the parsed results.
        frame = cv2.imread(image_path)
        has_masks = results['masks'] is not None
        detections = sv.Detections(
            xyxy=results['boxes'],
            mask=results['masks'].astype(bool) if has_masks else None,
            class_id=results['class_ids']
        )
        labels = [
            f"{name} {conf:.2f}"
            for name, conf in zip(results['class_names'], results['confidences'])
        ]

        # Draw boxes, then labels, then (when available) masks.
        annotated = sv.BoxAnnotator().annotate(scene=frame.copy(), detections=detections)
        annotated = sv.LabelAnnotator().annotate(
            scene=annotated,
            detections=detections,
            labels=labels
        )
        if has_masks:
            annotated = sv.MaskAnnotator().annotate(
                scene=annotated,
                detections=detections
            )

        output_path = os.path.join(OUTPUT_DIR, "result.jpg")
        cv2.imwrite(output_path, annotated)
        return output_path
    except Exception as e:
        return f"Error processing image: {str(e)}"
def process_video(video_path, prompt=VID_PROMPT):
    """Run DINO-X detection on a video and save an annotated copy.

    Only every 3rd frame is sent to the API (to limit cost/latency); the
    remaining frames are written through unannotated so the output keeps the
    same frame count, fps, and duration as the input.

    Args:
        video_path: path to a local video file.
        prompt: period-separated class prompt sent to the API.

    Returns:
        Path of the annotated output video, or an error-message string if
        anything goes wrong.
    """
    cap = None
    out = None
    try:
        client = initialize_dino_client()
        _, class_name_to_id = get_class_mappings(prompt)
        cap = cv2.VideoCapture(video_path)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        output_path = os.path.join(OUTPUT_DIR, "result.mp4")
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        frame_count = 0
        temp_frame_path = os.path.join(TEMP_DIR, "temp_frame.jpg")
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1
            if frame_count % 3 != 0:  # Process every 3rd frame for speed
                # BUG FIX: skipped frames were previously dropped entirely,
                # shortening the output to 1/3 length at the same fps (the
                # result played ~3x too fast). Write them through as-is.
                out.write(frame)
                continue
            cv2.imwrite(temp_frame_path, frame)
            image_url = client.upload_file(temp_frame_path)
            task = DinoxTask(
                image_url=image_url,
                prompts=[TextPrompt(text=prompt)],
                bbox_threshold=0.25
            )
            client.run_task(task)
            results = process_predictions(task.result.objects, class_name_to_id)
            detections = sv.Detections(
                xyxy=results['boxes'],
                class_id=results['class_ids']
            )
            labels = [
                f"{name} {conf:.2f}"
                for name, conf in zip(results['class_names'], results['confidences'])
            ]
            annotator = sv.BoxAnnotator()
            annotated_frame = annotator.annotate(scene=frame.copy(), detections=detections)
            label_annotator = sv.LabelAnnotator()
            annotated_frame = label_annotator.annotate(
                scene=annotated_frame,
                detections=detections,
                labels=labels
            )
            out.write(annotated_frame)
        if os.path.exists(temp_frame_path):
            os.remove(temp_frame_path)
        return output_path
    except Exception as e:
        return f"Error processing video: {str(e)}"
    finally:
        # Release the capture and writer even on failure so file handles
        # close and the partial output is flushed.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
def process_input(input_file, prompt=TEXT_PROMPT):
    """Dispatch an uploaded file to image or video processing by extension.

    Args:
        input_file: the Gradio upload — either a filepath string or a
            file-like wrapper exposing ``.name`` (Gradio's return type
            varies across versions).
        prompt: detection prompt forwarded to the processor.

    Returns:
        The processor's result path, or a human-readable error message.
    """
    if input_file is None:
        return "Please provide an input file"
    # ROBUSTNESS: newer Gradio versions pass a plain filepath string instead
    # of a tempfile wrapper with a .name attribute; accept both.
    file_path = input_file if isinstance(input_file, str) else input_file.name
    extension = os.path.splitext(file_path)[1].lower()
    if extension in ['.jpg', '.jpeg', '.png']:
        return process_image(file_path, prompt)
    elif extension in ['.mp4', '.avi', '.mov']:
        return process_video(file_path, prompt)
    else:
        return "Unsupported file format. Please use jpg/jpeg/png for images or mp4/avi/mov for videos."
# Create Gradio interface
# NOTE(review): the output component is gr.Image, but process_input can
# return a video path (.mp4) or a plain error string — neither renders in an
# Image component. Confirm whether the output should be gr.File instead.
demo = gr.Interface(
    fn=process_input,
    inputs=[
        gr.File(
            label="Upload Image/Video",
            file_types=["image", "video"]
        ),
        gr.Textbox(
            label="Detection Prompt",
            value=TEXT_PROMPT,
            lines=2
        )
    ],
    outputs=gr.Image(label="Detection Result"),
    title="DINO-X Object Detection",
    description="Upload an image or video to detect objects using DINO-X. You can modify the detection prompt to specify what objects to look for.",
    # Example files must exist under ./assets; with cache_examples=True they
    # are executed (hitting the paid API) once at startup to pre-render.
    examples=[
        ["assets/demo.png", TEXT_PROMPT],
        ["assets/demo.mp4", VID_PROMPT]
    ],
    cache_examples=True
)
if __name__ == "__main__":
    demo.launch()