# Dino-X-API-Demo / app.py
# Hugging Face Space by minar09 (commit 6c6cd1e, verified; 7.59 kB).
# (Original page chrome — "raw / history / blame" — converted to comments so the file parses.)
import os
import cv2
import gradio as gr
import numpy as np
import supervision as sv
from pathlib import Path
from dds_cloudapi_sdk import Config, Client, TextPrompt
from dds_cloudapi_sdk.tasks.dinox import DinoxTask
from dds_cloudapi_sdk.tasks.detection import DetectionTask
from dds_cloudapi_sdk.tasks.types import DetectionTarget
# Constants
# SECURITY NOTE: the API token was previously hard-coded in source. Prefer the
# DDS_API_TOKEN environment variable; the literal fallback preserves the
# original demo behaviour but should be rotated and removed.
API_TOKEN = os.getenv("DDS_API_TOKEN", "361d32fa5ce22649133660c65cfcaf22")
# '.'-separated list of class names DINO-X should look for by default.
TEXT_PROMPT = "wheel . eye . helmet . mouse . mouth . vehicle . steering wheel . ear . nose"
TEMP_DIR = "./temp"       # scratch space for per-frame JPEGs during video processing
OUTPUT_DIR = "./outputs"  # annotated results are written here

# Ensure working directories exist before any request is processed.
for _dir in (TEMP_DIR, OUTPUT_DIR):
    os.makedirs(_dir, exist_ok=True)
def initialize_dino_client():
    """Build and return a DDS cloud-API client authenticated with API_TOKEN."""
    return Client(Config(API_TOKEN))
def get_class_mappings(text_prompt):
    """Split a '.'-separated prompt into class names and an id mapping.

    Parameters:
        text_prompt: prompt string such as "wheel . eye . helmet".

    Returns:
        (classes, class_name_to_id): the ordered list of lower-cased,
        stripped class names, and a dict mapping each name to its index.
    """
    # Strip BEFORE filtering: the old `if x` test ran on the raw segment, so
    # whitespace-only pieces (e.g. from "a . . b") slipped through and became
    # empty-string class names with their own ids.
    classes = [c for c in (part.strip().lower() for part in text_prompt.split('.')) if c]
    class_name_to_id = {name: idx for idx, name in enumerate(classes)}
    return classes, class_name_to_id
def process_predictions(predictions, class_name_to_id):
    """Convert raw DINO-X prediction objects into plain arrays and lists.

    Parameters:
        predictions: iterable of SDK result objects exposing ``.bbox``,
            ``.score``, ``.category`` and optionally an RLE-encoded ``.mask``.
        class_name_to_id: dict mapping lower-cased class name -> integer id.

    Returns:
        dict with keys 'boxes' (Nx4 array), 'masks' (array or None when the
        API returned no masks), 'class_ids', 'class_names', 'confidences'.

    NOTE(review): masks are collected only for objects that carry one, so if
    only some objects have masks the mask array can misalign with the boxes —
    behaviour kept as-is; confirm the API returns all-or-none.
    """
    boxes, masks, scores, names, ids = [], [], [], [], []
    for prediction in predictions:
        boxes.append(prediction.bbox)
        # Masks are optional; decode the RLE payload only when present.
        if hasattr(prediction, 'mask') and prediction.mask:
            rle = DetectionTask.string2rle(prediction.mask.counts)
            masks.append(DetectionTask.rle2mask(rle, prediction.mask.size))
        label = prediction.category.lower().strip()
        names.append(label)
        ids.append(class_name_to_id[label])
        scores.append(prediction.score)
    return {
        'boxes': np.array(boxes),
        'masks': np.array(masks) if masks else None,
        'class_ids': np.array(ids),
        'class_names': names,
        'confidences': scores,
    }
def process_image(image_path, prompt=TEXT_PROMPT):
    """Run DINO-X detection on one image and save an annotated copy.

    Parameters:
        image_path: path to a local image file.
        prompt: '.'-separated class prompt (defaults to TEXT_PROMPT).

    Returns:
        Path of the annotated JPEG on success, otherwise an error string.
    """
    try:
        client = initialize_dino_client()
        _, name_to_id = get_class_mappings(prompt)

        # Upload the image to the DDS cloud and run the DinoX task,
        # requesting both bounding boxes and segmentation masks.
        task = DinoxTask(
            image_url=client.upload_file(image_path),
            prompts=[TextPrompt(text=prompt)],
            bbox_threshold=0.25,
            targets=[DetectionTarget.BBox, DetectionTarget.Mask],
        )
        client.run_task(task)
        results = process_predictions(task.result.objects, name_to_id)

        # Build supervision detections; masks may be absent.
        has_masks = results['masks'] is not None
        detections = sv.Detections(
            xyxy=results['boxes'],
            mask=results['masks'].astype(bool) if has_masks else None,
            class_id=results['class_ids'],
        )
        labels = [
            f"{name} {conf:.2f}"
            for name, conf in zip(results['class_names'], results['confidences'])
        ]

        # Draw boxes, then labels, then (optionally) masks on a copy.
        frame = cv2.imread(image_path)
        annotated = sv.BoxAnnotator().annotate(scene=frame.copy(), detections=detections)
        annotated = sv.LabelAnnotator().annotate(
            scene=annotated, detections=detections, labels=labels
        )
        if has_masks:
            annotated = sv.MaskAnnotator().annotate(scene=annotated, detections=detections)

        output_path = os.path.join(OUTPUT_DIR, "result.jpg")
        cv2.imwrite(output_path, annotated)
        return output_path
    except Exception as e:
        return f"Error processing image: {str(e)}"
def process_video(video_path, prompt=TEXT_PROMPT):
    """Run DINO-X detection on a video and save an annotated copy.

    Only every 3rd frame is sent to the API (for speed); skipped frames are
    written using the most recent annotated frame so the output keeps the
    source duration.

    Parameters:
        video_path: path to a local video file.
        prompt: '.'-separated class prompt (defaults to TEXT_PROMPT).

    Returns:
        Path of the annotated MP4 on success, otherwise an error string.
    """
    try:
        client = initialize_dino_client()
        _, class_name_to_id = get_class_mappings(prompt)

        cap = cv2.VideoCapture(video_path)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Some containers report 0 FPS, which would make the writer invalid.
        fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30

        output_path = os.path.join(OUTPUT_DIR, "result.mp4")
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        temp_frame_path = os.path.join(TEMP_DIR, "temp_frame.jpg")
        frame_count = 0
        last_annotated = None
        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                frame_count += 1
                if frame_count % 3 != 0:  # Process every 3rd frame for speed
                    # BUGFIX: skipped frames used to be dropped entirely, so the
                    # output had 1/3 of the frames at full FPS and played ~3x too
                    # fast. Reuse the last annotated frame (or the raw frame
                    # before any detection) to keep the original duration.
                    out.write(last_annotated if last_annotated is not None else frame)
                    continue
                cv2.imwrite(temp_frame_path, frame)
                image_url = client.upload_file(temp_frame_path)
                task = DinoxTask(
                    image_url=image_url,
                    prompts=[TextPrompt(text=prompt)],
                    bbox_threshold=0.25
                )
                client.run_task(task)
                results = process_predictions(task.result.objects, class_name_to_id)
                detections = sv.Detections(
                    xyxy=results['boxes'],
                    class_id=results['class_ids']
                )
                labels = [
                    f"{name} {conf:.2f}"
                    for name, conf in zip(results['class_names'], results['confidences'])
                ]
                annotated_frame = sv.BoxAnnotator().annotate(
                    scene=frame.copy(), detections=detections
                )
                annotated_frame = sv.LabelAnnotator().annotate(
                    scene=annotated_frame,
                    detections=detections,
                    labels=labels
                )
                out.write(annotated_frame)
                last_annotated = annotated_frame
        finally:
            # BUGFIX: release capture/writer even when an API call raises,
            # so the partially-written MP4 is finalized and handles are freed.
            cap.release()
            out.release()

        if os.path.exists(temp_frame_path):
            os.remove(temp_frame_path)
        return output_path
    except Exception as e:
        return f"Error processing video: {str(e)}"
def process_input(input_file, prompt=TEXT_PROMPT):
    """Dispatch an uploaded file to the image or video pipeline by extension.

    Parameters:
        input_file: Gradio file object (exposes ``.name`` with the temp path),
            or None when nothing was uploaded.
        prompt: '.'-separated class prompt (defaults to TEXT_PROMPT).

    Returns:
        The result path from the matching pipeline, or an error message.
    """
    if input_file is None:
        return "Please provide an input file"

    path = input_file.name
    ext = os.path.splitext(path)[1].lower()

    image_exts = {'.jpg', '.jpeg', '.png'}
    video_exts = {'.mp4', '.avi', '.mov'}
    if ext in image_exts:
        return process_image(path, prompt)
    if ext in video_exts:
        return process_video(path, prompt)
    return "Unsupported file format. Please use jpg/jpeg/png for images or mp4/avi/mov for videos."
# Create Gradio interface
# NOTE(review): the output component is gr.Image, but process_video returns an
# .mp4 path and error cases return plain strings — verify how this Space is
# expected to render video results.
demo = gr.Interface(
    fn=process_input,
    inputs=[
        # Accepts a single image or video file; extension-based dispatch
        # happens in process_input.
        gr.File(
            label="Upload Image/Video",
            file_types=["image", "video"]
        ),
        # Editable '.'-separated class prompt, pre-filled with the default.
        gr.Textbox(
            label="Detection Prompt",
            value=TEXT_PROMPT,
            lines=2
        )
    ],
    outputs=gr.Image(label="Detection Result"),
    title="DINO-X Object Detection",
    description="Upload an image or video to detect objects using DINO-X. You can modify the detection prompt to specify what objects to look for.",
    # NOTE(review): cached examples require assets/demo.png and assets/demo.mp4
    # to exist at startup — confirm they ship with the Space.
    examples=[
        ["assets/demo.png", TEXT_PROMPT],
        ["assets/demo.mp4", TEXT_PROMPT]
    ],
    cache_examples=True
)

# Launch the Gradio app only when run as a script.
if __name__ == "__main__":
    demo.launch()