# app.py for a Hugging Face Space ("Running on Zero"): D-FINE object detection demo.
import logging
import os
import shutil
import tempfile
from pathlib import Path
from typing import List, Optional, Tuple

import cv2
import gradio as gr
import numpy as np
import tqdm
from PIL import Image
from transformers import pipeline
from transformers.image_utils import load_image
# Configuration constants
CHECKPOINTS = [
    "ustc-community/dfine_m_obj365",
    "ustc-community/dfine_n_coco",
    "ustc-community/dfine_s_coco",
    "ustc-community/dfine_m_coco",
    "ustc-community/dfine_l_coco",
    "ustc-community/dfine_x_coco",
    "ustc-community/dfine_s_obj365",
    "ustc-community/dfine_l_obj365",
    "ustc-community/dfine_x_obj365",
    "ustc-community/dfine_s_obj2coco",
    "ustc-community/dfine_m_obj2coco",
    "ustc-community/dfine_l_obj2coco_e25",
    "ustc-community/dfine_x_obj2coco",
]
MAX_NUM_FRAMES = 300
DEFAULT_CHECKPOINT = CHECKPOINTS[0]
DEFAULT_CONFIDENCE_THRESHOLD = 0.3
IMAGE_EXAMPLES = [
    {"path": "./image.jpg", "use_url": False, "url": "", "label": "Local Image"},
    {
        "path": None,
        "use_url": True,
        "url": "https://live.staticflickr.com/65535/33021460783_1646d43c54_b.jpg",
        "label": "Flickr Image",
    },
]
VIDEO_EXAMPLES = [
    {"path": "./video.mp4", "label": "Local Video"},
]
ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".avi", ".mov"}
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

VIDEO_OUTPUT_DIR = Path("static/videos")
VIDEO_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
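

# Loading a checkpoint is the slow path here, and detect_objects runs once per
# video frame, so keep one pipeline per checkpoint in memory instead of
# rebuilding it on every call. maxsize bounds how many models stay resident.
from functools import lru_cache


@lru_cache(maxsize=3)
def get_pipeline(checkpoint: str):
    """Return a cached CPU object-detection pipeline for `checkpoint`."""
    return pipeline(
        "object-detection",
        model=checkpoint,
        image_processor=checkpoint,
        device="cpu",
    )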


def detect_objects(
    image: Optional[Image.Image],
    checkpoint: str,
    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
    use_url: bool = False,
    url: str = "",
) -> Tuple[
    Optional[Tuple[Image.Image, List[Tuple[Tuple[int, int, int, int], str]]]],
    gr.Markdown,
]:
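    """Detect objects in a PIL image or an image fetched from a URL.

    Returns an (image, annotations) pair for gr.AnnotatedImage, where each
    annotation is ((xmin, ymin, xmax, ymax), label), plus a gr.Markdown used
    to surface warnings and errors.
    """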
    if use_url and url:
        try:
            input_image = load_image(url)
        except Exception as e:
            logger.error(f"Failed to load image from URL {url}: {str(e)}")
            return None, gr.Markdown(
                f"**Error**: Failed to load image from URL: {str(e)}", visible=True
            )
    elif image is not None:
        if not isinstance(image, Image.Image):
            logger.error("Input image is not a PIL Image")
            return None, gr.Markdown("**Error**: Invalid image format.", visible=True)
        input_image = image
    else:
        return None, gr.Markdown(
            "**Error**: Please provide an image or URL.", visible=True
        )
    try:
        # Cached per checkpoint; see get_pipeline above.
        pipe = get_pipeline(checkpoint)
    except Exception as e:
        logger.error(f"Failed to initialize model pipeline for {checkpoint}: {str(e)}")
        return None, gr.Markdown(
            f"**Error**: Failed to load model: {str(e)}", visible=True
        )
    results = pipe(input_image, threshold=confidence_threshold)
    img_width, img_height = input_image.size
    annotations = []
    for result in results:
        score = result["score"]
        if score < confidence_threshold:
            continue
        label = f"{result['label']} ({score:.2f})"
        box = result["box"]
        # Clamp the box to the image bounds and convert to (xmin, ymin, xmax, ymax)
        bbox_xmin = max(0, int(box["xmin"]))
        bbox_ymin = max(0, int(box["ymin"]))
        bbox_xmax = min(img_width, int(box["xmax"]))
        bbox_ymax = min(img_height, int(box["ymax"]))
        if bbox_xmax <= bbox_xmin or bbox_ymax <= bbox_ymin:
            continue  # Skip degenerate boxes
        bounding_box = (bbox_xmin, bbox_ymin, bbox_xmax, bbox_ymax)
        annotations.append((bounding_box, label))
    if not annotations:
        return (input_image, []), gr.Markdown(
            "**Warning**: No objects detected above the confidence threshold. Try lowering the threshold.",
            visible=True,
        )
    return (input_image, annotations), gr.Markdown(visible=False)


def annotate_frame(
    image: Image.Image, annotations: List[Tuple[Tuple[int, int, int, int], str]]
) -> np.ndarray:
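    """Draw boxes and labels on an image; return a BGR array for cv2.VideoWriter."""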
    image_np = np.array(image)
    image_bgr = image_np[:, :, ::-1].copy()  # RGB to BGR for OpenCV drawing
    for (xmin, ymin, xmax, ymax), label in annotations:
        cv2.rectangle(image_bgr, (xmin, ymin), (xmax, ymax), (255, 255, 255), 2)
        text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0]
        # Keep the label inside the frame when a box touches the top edge
        y_label = max(ymin, text_size[1] + 4)
        cv2.rectangle(
            image_bgr,
            (xmin, y_label - text_size[1] - 4),
            (xmin + text_size[0], y_label),
            (255, 255, 255),
            -1,
        )
        cv2.putText(
            image_bgr,
            label,
            (xmin, y_label - 4),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (0, 0, 0),
            1,
        )
    return image_bgr


def process_video(
    video_path: str,
    checkpoint: str,
    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> Tuple[Optional[str], gr.Markdown]:
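    """Annotate up to MAX_NUM_FRAMES frames of a video and write the result as MP4.

    Returns the path of the annotated video (or None on failure) and a
    gr.Markdown used to surface warnings and errors.
    """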
    if not video_path or not os.path.isfile(video_path):
        logger.error(f"Invalid video path: {video_path}")
        return None, gr.Markdown(
            "**Error**: Please provide a valid video file.", visible=True
        )
    ext = os.path.splitext(video_path)[1].lower()
    if ext not in ALLOWED_VIDEO_EXTENSIONS:
        logger.error(f"Unsupported video format: {ext}")
        return None, gr.Markdown(
            "**Error**: Unsupported video format. Use MP4, AVI, or MOV.", visible=True
        )
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            logger.error(f"Failed to open video: {video_path}")
            return None, gr.Markdown(
                "**Error**: Failed to open video file.", visible=True
            )
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # Some containers report 0 FPS
        num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if num_frames <= 0:  # Frame count can be unreported; fall back to the cap
            num_frames = MAX_NUM_FRAMES
        # OpenCV builds often lack an H.264 encoder, so "mp4v" is used here; see
        # the optional re-encode sketch after this function for browser playback.
        # fourcc = cv2.VideoWriter_fourcc(*"H264")
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        temp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
        writer = cv2.VideoWriter(temp_file.name, fourcc, fps, (width, height))
        if not writer.isOpened():
            logger.error("Failed to initialize video writer")
            cap.release()
            temp_file.close()
            os.unlink(temp_file.name)
            return None, gr.Markdown(
                "**Error**: Failed to initialize video writer.", visible=True
            )
        frame_count = 0
        for _ in tqdm.tqdm(
            range(min(MAX_NUM_FRAMES, num_frames)), desc="Processing video"
        ):
            ok, frame = cap.read()
            if not ok:
                break
            rgb_frame = frame[:, :, ::-1]  # BGR to RGB
            pil_image = Image.fromarray(rgb_frame)
            # detect_objects returns None on failure, so unpack in two steps
            detection, _ = detect_objects(
                pil_image, checkpoint, confidence_threshold, use_url=False, url=""
            )
            if detection is None:
                continue
            annotated_image, annotations = detection
            annotated_frame = annotate_frame(annotated_image, annotations)
            writer.write(annotated_frame)
            frame_count += 1
        writer.release()
        cap.release()
        if frame_count == 0:
            logger.warning("No valid frames processed in video")
            temp_file.close()
            os.unlink(temp_file.name)
            return None, gr.Markdown(
                "**Warning**: No valid frames processed. Try a different video or threshold.",
                visible=True,
            )
        temp_file.close()
        # Copy to a persistent directory so Gradio can serve the file
        output_filename = f"output_{os.path.basename(temp_file.name)}"
        output_path = VIDEO_OUTPUT_DIR / output_filename
        shutil.copy(temp_file.name, output_path)
        os.unlink(temp_file.name)  # Remove the temporary file
        logger.info(f"Video saved to {output_path}")
        return str(output_path), gr.Markdown(visible=False)
    except Exception as e:
        logger.error(f"Video processing failed: {str(e)}")
        if "temp_file" in locals():
            temp_file.close()
            if os.path.exists(temp_file.name):
                os.unlink(temp_file.name)
        return None, gr.Markdown(
            f"**Error**: Video processing failed: {str(e)}", visible=True
        )
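

# Optional helper (a sketch, not wired into process_video): OpenCV's "mp4v"
# output often does not play inline in browsers. Assuming the ffmpeg binary is
# available on the host (this app does not check), the finished file can be
# re-encoded to H.264, which browsers and Gradio's Video component handle well.
import subprocess


def reencode_to_h264(src: str, dst: str) -> None:
    """Re-encode `src` to a browser-friendly H.264 MP4 at `dst` via ffmpeg."""
    subprocess.run(
        ["ffmpeg", "-y", "-i", src, "-c:v", "libx264", "-pix_fmt", "yuv420p", dst],
        check=True,
    )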


def create_image_inputs() -> List[gr.components.Component]:
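    """Build the image-tab inputs: image, URL toggle, URL box, checkpoint, threshold."""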
    return [
        gr.Image(
            label="Upload Image",
            type="pil",
            sources=["upload", "webcam"],
            interactive=True,
            elem_classes="input-component",
        ),
        gr.Checkbox(label="Use Image URL Instead", value=False),
        gr.Textbox(
            label="Image URL",
            placeholder="https://example.com/image.jpg",
            visible=False,
            elem_classes="input-component",
        ),
        gr.Dropdown(
            choices=CHECKPOINTS,
            label="Select Model Checkpoint",
            value=DEFAULT_CHECKPOINT,
            elem_classes="input-component",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=DEFAULT_CONFIDENCE_THRESHOLD,
            step=0.1,
            label="Confidence Threshold",
            elem_classes="input-component",
        ),
    ]


def create_video_inputs() -> List[gr.components.Component]:
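    """Build the video-tab inputs: video, checkpoint, threshold."""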
    return [
        gr.Video(
            label="Upload Video",
            sources=["upload"],
            interactive=True,
            format="mp4",  # Ensure MP4 format
            elem_classes="input-component",
        ),
        gr.Dropdown(
            choices=CHECKPOINTS,
            label="Select Model Checkpoint",
            value=DEFAULT_CHECKPOINT,
            elem_classes="input-component",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=DEFAULT_CONFIDENCE_THRESHOLD,
            step=0.1,
            label="Confidence Threshold",
            elem_classes="input-component",
        ),
    ]


def create_button_row(is_image: bool) -> List[gr.Button]:
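    """Build a detect/clear button pair, labeled for the image or video tab."""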
    prefix = "Image" if is_image else "Video"
    return [
        gr.Button(
            f"{prefix} Detect Objects", variant="primary", elem_classes="action-button"
        ),
        gr.Button(f"{prefix} Clear", variant="secondary", elem_classes="action-button"),
    ]


# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Real-Time Object Detection Demo
        Experience state-of-the-art object detection with USTC's D-FINE models. Upload an image or video,
        provide a URL, or try an example below. Select a model and adjust the confidence threshold to see detections in real time!
        """,
        elem_classes="header-text",
    )
    with gr.Tabs():
        with gr.Tab("Image"):
            with gr.Row():
                with gr.Column(scale=1, min_width=300):
                    with gr.Group():
                        (
                            image_input,
                            use_url,
                            url_input,
                            image_checkpoint,
                            image_confidence_threshold,
                        ) = create_image_inputs()
                    image_detect_button, image_clear_button = create_button_row(
                        is_image=True
                    )
                with gr.Column(scale=2):
                    image_output = gr.AnnotatedImage(
                        label="Detection Results",
                        show_label=True,
                        color_map=None,
                        elem_classes="output-component",
                    )
                    image_error_message = gr.Markdown(
                        visible=False, elem_classes="error-text"
                    )
            gr.Examples(
                # Example rows and inputs follow detect_objects' parameter order
                # (image, checkpoint, threshold, use_url, url) so fn can run them.
                examples=[
                    [
                        example["path"],
                        DEFAULT_CHECKPOINT,
                        DEFAULT_CONFIDENCE_THRESHOLD,
                        example["use_url"],
                        example["url"],
                    ]
                    for example in IMAGE_EXAMPLES
                ],
                inputs=[
                    image_input,
                    image_checkpoint,
                    image_confidence_threshold,
                    use_url,
                    url_input,
                ],
                outputs=[image_output, image_error_message],
                fn=detect_objects,
                cache_examples=False,
                label="Select an image example to populate inputs",
            )
with gr.Tab("Video"): | |
gr.Markdown( | |
f"The input video will be truncated to {MAX_NUM_FRAMES} frames." | |
) | |
with gr.Row(): | |
with gr.Column(scale=1, min_width=300): | |
with gr.Group(): | |
video_input, video_checkpoint, video_confidence_threshold = ( | |
create_video_inputs() | |
) | |
video_detect_button, video_clear_button = create_button_row( | |
is_image=False | |
) | |
with gr.Column(scale=2): | |
video_output = gr.Video( | |
label="Detection Results", | |
format="mp4", # Explicit MP4 format | |
elem_classes="output-component", | |
) | |
video_error_message = gr.Markdown( | |
visible=False, elem_classes="error-text" | |
) | |
gr.Examples( | |
examples=[ | |
[example["path"], DEFAULT_CHECKPOINT, DEFAULT_CONFIDENCE_THRESHOLD] | |
for example in VIDEO_EXAMPLES | |
], | |
inputs=[video_input, video_checkpoint, video_confidence_threshold], | |
outputs=[video_output, video_error_message], | |
fn=process_video, | |
cache_examples=False, | |
label="Select a video example to populate inputs", | |
) | |
    # Dynamic visibility for URL input
    use_url.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_url,
        outputs=url_input,
    )
    # Image clear button
    image_clear_button.click(
        fn=lambda: (
            None,
            False,
            "",
            DEFAULT_CHECKPOINT,
            DEFAULT_CONFIDENCE_THRESHOLD,
            None,
            gr.Markdown(visible=False),
        ),
        outputs=[
            image_input,
            use_url,
            url_input,
            image_checkpoint,
            image_confidence_threshold,
            image_output,
            image_error_message,
        ],
    )
    # Video clear button
    video_clear_button.click(
        fn=lambda: (
            None,
            DEFAULT_CHECKPOINT,
            DEFAULT_CONFIDENCE_THRESHOLD,
            None,
            gr.Markdown(visible=False),
        ),
        outputs=[
            video_input,
            video_checkpoint,
            video_confidence_threshold,
            video_output,
            video_error_message,
        ],
    )
    # Image detect button
    image_detect_button.click(
        fn=detect_objects,
        inputs=[
            image_input,
            image_checkpoint,
            image_confidence_threshold,
            use_url,
            url_input,
        ],
        outputs=[image_output, image_error_message],
    )
    # Video detect button
    video_detect_button.click(
        fn=process_video,
        inputs=[video_input, video_checkpoint, video_confidence_threshold],
        outputs=[video_output, video_error_message],
    )


if __name__ == "__main__":
    demo.queue(max_size=20).launch()