import os
import cv2
import tqdm
import shutil
import tempfile
import logging
import supervision as sv
import torch
import spaces
import gradio as gr
from pathlib import Path
from functools import lru_cache
from typing import List, Optional, Tuple
from PIL import Image
from transformers import AutoModelForObjectDetection, AutoImageProcessor
from transformers.image_utils import load_image
# Configuration constants
CHECKPOINTS = [
    "ustc-community/dfine_m_obj2coco",
    "ustc-community/dfine_m_obj365",
    "ustc-community/dfine_n_coco",
    "ustc-community/dfine_s_coco",
    "ustc-community/dfine_m_coco",
    "ustc-community/dfine_l_coco",
    "ustc-community/dfine_x_coco",
    "ustc-community/dfine_s_obj365",
    "ustc-community/dfine_l_obj365",
    "ustc-community/dfine_x_obj365",
    "ustc-community/dfine_s_obj2coco",
    "ustc-community/dfine_l_obj2coco_e25",
    "ustc-community/dfine_x_obj2coco",
]
DEFAULT_CHECKPOINT = CHECKPOINTS[0]
DEFAULT_CONFIDENCE_THRESHOLD = 0.3
TORCH_DTYPE = torch.float32

# Image
IMAGE_EXAMPLES = [
    {"path": "./examples/images/crossroad.jpg", "use_url": False, "url": "", "label": "Local Image"},
    {
        "path": None,
        "use_url": True,
        "url": "https://live.staticflickr.com/65535/33021460783_1646d43c54_b.jpg",
        "label": "Flickr Image",
    },
]

# Video
MAX_NUM_FRAMES = 500
BATCH_SIZE = 4
ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".avi", ".mov"}
VIDEO_OUTPUT_DIR = Path("static/videos")
VIDEO_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
VIDEO_EXAMPLES = [
    {"path": "./examples/videos/traffic.mp4", "label": "Local Video"},
    {"path": "./examples/videos/fast_and_furious.mp4", "label": "Local Video"},
    {"path": "./examples/videos/break_dance.mp4", "label": "Local Video"},
]

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

@lru_cache(maxsize=3)  # cache recently used checkpoints so re-selecting one is instant
def get_model_and_image_processor(checkpoint: str, device: str = "cpu"):
    model = AutoModelForObjectDetection.from_pretrained(checkpoint, torch_dtype=TORCH_DTYPE).to(device)
    image_processor = AutoImageProcessor.from_pretrained(checkpoint)
    return model, image_processor

@spaces.GPU  # ZeroGPU: a GPU is attached only for the duration of each call
def detect_objects(
    checkpoint: str,
    images: Optional[List[Image.Image]] = None,
    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
    target_sizes: Optional[List[Tuple[int, int]]] = None,
):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, image_processor = get_model_and_image_processor(checkpoint, device=device)

    # Preprocess images
    inputs = image_processor(images=images, return_tensors="pt")
    inputs = inputs.to(device).to(TORCH_DTYPE)

    # Forward pass
    with torch.no_grad():
        outputs = model(**inputs)

    # Postprocess outputs back to the original image coordinates
    if not target_sizes:
        target_sizes = [(image.height, image.width) for image in images]

    results = image_processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=confidence_threshold
    )
    return results, model.config.id2label
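
# Illustrative standalone usage (the file path here is hypothetical, not part
# of the app flow):
#
#   image = Image.open("street.jpg")
#   results, id2label = detect_objects(DEFAULT_CHECKPOINT, images=[image])
#   for label, score in zip(results[0]["labels"], results[0]["scores"]):
#       print(id2label[label.item()], round(score.item(), 2))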

def process_image(
    checkpoint: str = DEFAULT_CHECKPOINT,
    image: Optional[Image.Image] = None,
    url: Optional[str] = None,
    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
):
    # Exactly one of `image` / `url` must be provided; the XOR below is True
    # precisely when both or neither are set.
    if (image is None) ^ bool(url):
        raise ValueError("Either an image or a URL must be provided, but not both.")

    if url:
        image = load_image(url)

    results, id2label = detect_objects(
        checkpoint=checkpoint,
        images=[image],
        confidence_threshold=confidence_threshold,
    )
    result = results[0]  # first image in batch (we have batch size 1)

    annotations = []
    for label, score, box in zip(result["labels"], result["scores"], result["boxes"]):
        text_label = id2label[label.item()]
        formatted_label = f"{text_label} ({score:.2f})"
        # Clamp boxes to the image bounds before drawing
        x_min, y_min, x_max, y_max = box.cpu().numpy().round().astype(int)
        x_min = max(0, x_min)
        y_min = max(0, y_min)
        x_max = min(image.width - 1, x_max)
        y_max = min(image.height - 1, y_max)
        annotations.append(((x_min, y_min, x_max, y_max), formatted_label))

    return (image, annotations)
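
# Note: gr.AnnotatedImage consumes exactly this return shape:
# (base_image, [((x_min, y_min, x_max, y_max), label), ...]).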

def get_target_size(image_height, image_width, max_size: int):
    if image_height < max_size and image_width < max_size:
        return image_width, image_height
    if image_height > image_width:
        new_height = max_size
        new_width = int(image_width * max_size / image_height)
    else:
        new_width = max_size
        new_height = int(image_height * max_size / image_width)
    return new_width, new_height
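
# Worked example: a 1920x1080 (height x width) frame capped at max_size=1080
# keeps its aspect ratio and returns (607, 1080) as (width, height),
# since int(1080 * 1080 / 1920) == 607.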

def process_video(
    video_path: str,
    checkpoint: str,
    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> str:
    if not video_path or not os.path.isfile(video_path):
        raise ValueError(f"Invalid video path: {video_path}")

    ext = os.path.splitext(video_path)[1].lower()
    if ext not in ALLOWED_VIDEO_EXTENSIONS:
        raise ValueError(f"Unsupported video format: {ext}, supported formats: {ALLOWED_VIDEO_EXTENSIONS}")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Failed to open video: {video_path}")

    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    fps = cap.get(cv2.CAP_PROP_FPS)
    num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Subsample to ~25 FPS; the max(1, ...) guards against sources slower than
    # 25 FPS, where fps // 25 would be 0 and cause a division by zero below.
    process_each_frame = max(1, int(fps // 25))
    target_fps = fps / process_each_frame
    target_width, target_height = get_target_size(height, width, 1080)

    # Write an MJPG-encoded AVI as an intermediate file (broadly supported by
    # OpenCV builds; browsers generally want H.264 MP4, see the note below)
    fourcc = cv2.VideoWriter_fourcc(*"MJPG")
    temp_file = tempfile.NamedTemporaryFile(suffix=".avi", delete=False)
    writer = cv2.VideoWriter(temp_file.name, fourcc, target_fps, (target_width, target_height))
    box_annotator = sv.BoxAnnotator(thickness=1)
    label_annotator = sv.LabelAnnotator(text_scale=0.5)

    if not writer.isOpened():
        cap.release()
        temp_file.close()
        os.unlink(temp_file.name)
        raise ValueError("Failed to initialize video writer")
    def run_batch(frames: List) -> None:
        # Detect on a batch of RGB frames and write the annotated results
        results, id2label = detect_objects(
            images=[Image.fromarray(f) for f in frames],
            checkpoint=checkpoint,
            confidence_threshold=confidence_threshold,
            target_sizes=[(target_height, target_width)] * len(frames),
        )
        for rgb_frame, result in zip(frames, results):
            rgb_frame = cv2.resize(rgb_frame, (target_width, target_height), interpolation=cv2.INTER_AREA)
            detections = sv.Detections.from_transformers(result, id2label=id2label)
            detections = detections.with_nms(threshold=0.95, class_agnostic=True)
            annotated_frame = box_annotator.annotate(scene=rgb_frame, detections=detections)
            annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections)
            writer.write(cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR))

    frames_to_process = int(min(MAX_NUM_FRAMES * process_each_frame, num_frames))
    batch = []
    for i in tqdm.tqdm(range(frames_to_process), desc="Processing video"):
        ok, frame = cap.read()
        if not ok:
            break
        if i % process_each_frame != 0:
            continue
        batch.append(frame[:, :, ::-1].copy())  # BGR to RGB
        if len(batch) == BATCH_SIZE:
            run_batch(batch)
            batch = []
    if batch:
        run_batch(batch)  # flush the final partial batch instead of dropping it
    writer.release()
    cap.release()
    temp_file.close()

    # Copy to a persistent directory so Gradio can serve the file
    output_filename = f"output_{os.path.basename(temp_file.name)}"
    output_path = VIDEO_OUTPUT_DIR / output_filename
    shutil.copy(temp_file.name, output_path)
    os.unlink(temp_file.name)  # remove the temporary file

    logger.info(f"Video saved to {output_path}")
    return str(output_path)
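
# The writer above emits an MJPG AVI, which some browsers will not play inline.
# A minimal transcoding sketch to H.264 MP4 (assumes the ffmpeg binary is on
# PATH; this helper is illustrative only and is not wired into the app):
def transcode_to_h264_mp4(src_path: str, dst_path: str) -> None:
    import subprocess

    subprocess.run(
        ["ffmpeg", "-y", "-i", src_path, "-c:v", "libx264", "-pix_fmt", "yuv420p", dst_path],
        check=True,  # raise if ffmpeg exits with a non-zero status
    )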

def create_image_inputs() -> List[gr.components.Component]:
    return [
        gr.Image(
            label="Upload Image",
            type="pil",
            sources=["upload", "webcam"],
            interactive=True,
            elem_classes="input-component",
        ),
        gr.Checkbox(label="Use Image URL Instead", value=False),
        gr.Textbox(
            label="Image URL",
            placeholder="https://example.com/image.jpg",
            visible=False,
            elem_classes="input-component",
        ),
        gr.Dropdown(
            choices=CHECKPOINTS,
            label="Select Model Checkpoint",
            value=DEFAULT_CHECKPOINT,
            elem_classes="input-component",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=DEFAULT_CONFIDENCE_THRESHOLD,
            step=0.1,
            label="Confidence Threshold",
            elem_classes="input-component",
        ),
    ]


def create_video_inputs() -> List[gr.components.Component]:
    return [
        gr.Video(
            label="Upload Video",
            sources=["upload"],
            interactive=True,
            format="mp4",  # ensure MP4 container
            elem_classes="input-component",
        ),
        gr.Dropdown(
            choices=CHECKPOINTS,
            label="Select Model Checkpoint",
            value=DEFAULT_CHECKPOINT,
            elem_classes="input-component",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=DEFAULT_CONFIDENCE_THRESHOLD,
            step=0.1,
            label="Confidence Threshold",
            elem_classes="input-component",
        ),
    ]


def create_button_row(is_image: bool) -> List[gr.Button]:
    prefix = "Image" if is_image else "Video"
    return [
        gr.Button(
            f"{prefix} Detect Objects", variant="primary", elem_classes="action-button"
        ),
        gr.Button(f"{prefix} Clear", variant="secondary", elem_classes="action-button"),
    ]

# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Real-Time Object Detection Demo
        Experience state-of-the-art object detection with D-FINE models from USTC. Upload an image or video,
        provide a URL, or try an example below. Select a model and adjust the confidence threshold to see detections in real time!
        """,
        elem_classes="header-text",
    )
    with gr.Tabs():
        with gr.Tab("Image"):
            with gr.Row():
                with gr.Column(scale=1, min_width=300):
                    with gr.Group():
                        (
                            image_input,
                            use_url,
                            url_input,
                            image_model_checkpoint,
                            image_confidence_threshold,
                        ) = create_image_inputs()
                        image_detect_button, image_clear_button = create_button_row(
                            is_image=True
                        )
                with gr.Column(scale=2):
                    image_output = gr.AnnotatedImage(
                        label="Detection Results",
                        show_label=True,
                        color_map=None,
                        elem_classes="output-component",
                    )
            gr.Examples(
                examples=[
                    [
                        example["path"],
                        example["use_url"],
                        example["url"],
                        DEFAULT_CHECKPOINT,
                        DEFAULT_CONFIDENCE_THRESHOLD,
                    ]
                    for example in IMAGE_EXAMPLES
                ],
                inputs=[
                    image_input,
                    use_url,
                    url_input,
                    image_model_checkpoint,
                    image_confidence_threshold,
                ],
                outputs=[image_output],
                fn=process_image,
                cache_examples=False,
                label="Select an image example to populate inputs",
            )
        with gr.Tab("Video"):
            gr.Markdown(
                f"The input video will be processed at ~25 FPS (up to {MAX_NUM_FRAMES} frames in the result)."
            )
            with gr.Row():
                with gr.Column(scale=1, min_width=300):
                    with gr.Group():
                        video_input, video_checkpoint, video_confidence_threshold = (
                            create_video_inputs()
                        )
                        video_detect_button, video_clear_button = create_button_row(
                            is_image=False
                        )
                with gr.Column(scale=2):
                    video_output = gr.Video(
                        label="Detection Results",
                        format="mp4",  # explicit MP4 format for browser playback
                        elem_classes="output-component",
                    )
            gr.Examples(
                examples=[
                    [example["path"], DEFAULT_CHECKPOINT, DEFAULT_CONFIDENCE_THRESHOLD]
                    for example in VIDEO_EXAMPLES
                ],
                inputs=[video_input, video_checkpoint, video_confidence_threshold],
                outputs=[video_output],
                fn=process_video,
                cache_examples=False,
                label="Select a video example to populate inputs",
            )
    # Dynamic visibility for the URL input
    use_url.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_url,
        outputs=url_input,
    )

    # Image clear button
    image_clear_button.click(
        fn=lambda: (
            None,
            False,
            "",
            DEFAULT_CHECKPOINT,
            DEFAULT_CONFIDENCE_THRESHOLD,
            None,
        ),
        outputs=[
            image_input,
            use_url,
            url_input,
            image_model_checkpoint,
            image_confidence_threshold,
            image_output,
        ],
    )

    # Video clear button
    video_clear_button.click(
        fn=lambda: (
            None,
            DEFAULT_CHECKPOINT,
            DEFAULT_CONFIDENCE_THRESHOLD,
            None,
        ),
        outputs=[
            video_input,
            video_checkpoint,
            video_confidence_threshold,
            video_output,
        ],
    )

    # Image detect button
    image_detect_button.click(
        fn=process_image,
        inputs=[
            image_model_checkpoint,
            image_input,
            url_input,
            image_confidence_threshold,
        ],
        outputs=[image_output],
    )

    # Video detect button
    video_detect_button.click(
        fn=process_video,
        inputs=[video_input, video_checkpoint, video_confidence_threshold],
        outputs=[video_output],
    )

if __name__ == "__main__":
    demo.queue(max_size=20).launch()