import logging
import os
import shutil
import tempfile
from functools import lru_cache
from pathlib import Path
from typing import List, Optional, Tuple

import cv2
import gradio as gr
import numpy as np
import tqdm
from PIL import Image
from transformers import pipeline
from transformers.image_utils import load_image

# Configuration constants
CHECKPOINTS = [
    "ustc-community/dfine_m_obj365",
    "ustc-community/dfine_n_coco",
    "ustc-community/dfine_s_coco",
    "ustc-community/dfine_m_coco",
    "ustc-community/dfine_l_coco",
    "ustc-community/dfine_x_coco",
    "ustc-community/dfine_s_obj365",
    "ustc-community/dfine_l_obj365",
    "ustc-community/dfine_x_obj365",
    "ustc-community/dfine_s_obj2coco",
    "ustc-community/dfine_m_obj2coco",
    "ustc-community/dfine_l_obj2coco_e25",
    "ustc-community/dfine_x_obj2coco",
]
MAX_NUM_FRAMES = 300
DEFAULT_CHECKPOINT = CHECKPOINTS[0]
DEFAULT_CONFIDENCE_THRESHOLD = 0.3

IMAGE_EXAMPLES = [
    {"path": "./image.jpg", "use_url": False, "url": "", "label": "Local Image"},
    {
        "path": None,
        "use_url": True,
        "url": "https://live.staticflickr.com/65535/33021460783_1646d43c54_b.jpg",
        "label": "Flickr Image",
    },
]
VIDEO_EXAMPLES = [
    {"path": "./video.mp4", "label": "Local Video"},
]
ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".avi", ".mov"}

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

VIDEO_OUTPUT_DIR = Path("static/videos")
VIDEO_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)


@lru_cache(maxsize=3)
def get_pipeline(checkpoint: str):
    # Cache loaded pipelines so repeated calls (e.g., once per video frame)
    # do not reload the model from disk every time.
    return pipeline(
        "object-detection",
        model=checkpoint,
        image_processor=checkpoint,
        device="cpu",
    )


def detect_objects(
    image: Optional[Image.Image],
    checkpoint: str,
    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
    use_url: bool = False,
    url: str = "",
) -> Tuple[
    Optional[Tuple[Image.Image, List[Tuple[Tuple[int, int, int, int], str]]]],
    gr.Markdown,
]:
    if use_url and url:
        try:
            input_image = load_image(url)
        except Exception as e:
            logger.error(f"Failed to load image from URL {url}: {str(e)}")
            return None, gr.Markdown(
                f"**Error**: Failed to load image from URL: {str(e)}", visible=True
            )
    elif image is not None:
        if not isinstance(image, Image.Image):
            logger.error("Input image is not a PIL Image")
            return None, gr.Markdown("**Error**: Invalid image format.", visible=True)
        input_image = image
    else:
        return None, gr.Markdown(
            "**Error**: Please provide an image or URL.", visible=True
        )

    try:
        pipe = get_pipeline(checkpoint)
    except Exception as e:
        logger.error(f"Failed to initialize model pipeline for {checkpoint}: {str(e)}")
        return None, gr.Markdown(
            f"**Error**: Failed to load model: {str(e)}", visible=True
        )

    results = pipe(input_image, threshold=confidence_threshold)

    img_width, img_height = input_image.size
    annotations = []
    for result in results:
        score = result["score"]
        if score < confidence_threshold:
            continue
        label = f"{result['label']} ({score:.2f})"
        box = result["box"]
        # Clamp the box to the image bounds and convert to (xmin, ymin, xmax, ymax)
        bbox_xmin = max(0, int(box["xmin"]))
        bbox_ymin = max(0, int(box["ymin"]))
        bbox_xmax = min(img_width, int(box["xmax"]))
        bbox_ymax = min(img_height, int(box["ymax"]))
        if bbox_xmax <= bbox_xmin or bbox_ymax <= bbox_ymin:
            continue
        bounding_box = (bbox_xmin, bbox_ymin, bbox_xmax, bbox_ymax)
        annotations.append((bounding_box, label))

    if not annotations:
        return (input_image, []), gr.Markdown(
            "**Warning**: No objects detected above the confidence threshold. "
            "Try lowering the threshold.",
            visible=True,
        )

    return (input_image, annotations), gr.Markdown(visible=False)
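
# A minimal sketch of calling detect_objects outside the Gradio UI, assuming a
# local ./image.jpg exists; the second return value is the Markdown component
# used for inline error display and can be ignored in a script:
#
#   result, _ = detect_objects(Image.open("image.jpg"), DEFAULT_CHECKPOINT)
#   if result is not None:
#       _, annotations = result
#       for (xmin, ymin, xmax, ymax), label in annotations:
#           print(label, xmin, ymin, xmax, ymax)
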
def annotate_frame(
    image: Image.Image, annotations: List[Tuple[Tuple[int, int, int, int], str]]
) -> np.ndarray:
    image_np = np.array(image.convert("RGB"))  # Normalize mode (e.g., drop alpha)
    image_bgr = image_np[:, :, ::-1].copy()  # RGB to BGR for OpenCV
    for (xmin, ymin, xmax, ymax), label in annotations:
        cv2.rectangle(image_bgr, (xmin, ymin), (xmax, ymax), (255, 255, 255), 2)
        text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0]
        # Keep the label background inside the frame when a box touches the top edge
        text_top = max(ymin - text_size[1] - 4, 0)
        cv2.rectangle(
            image_bgr,
            (xmin, text_top),
            (xmin + text_size[0], text_top + text_size[1] + 4),
            (255, 255, 255),
            -1,
        )
        cv2.putText(
            image_bgr,
            label,
            (xmin, text_top + text_size[1]),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (0, 0, 0),
            1,
        )
    return image_bgr


def process_video(
    video_path: str,
    checkpoint: str,
    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> Tuple[Optional[str], gr.Markdown]:
    if not video_path or not os.path.isfile(video_path):
        logger.error(f"Invalid video path: {video_path}")
        return None, gr.Markdown(
            "**Error**: Please provide a valid video file.", visible=True
        )
    ext = os.path.splitext(video_path)[1].lower()
    if ext not in ALLOWED_VIDEO_EXTENSIONS:
        logger.error(f"Unsupported video format: {ext}")
        return None, gr.Markdown(
            "**Error**: Unsupported video format. Use MP4, AVI, or MOV.", visible=True
        )

    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            logger.error(f"Failed to open video: {video_path}")
            return None, gr.Markdown(
                "**Error**: Failed to open video file.", visible=True
            )

        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        fps = cap.get(cv2.CAP_PROP_FPS)
        num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Use the H.264 codec for browser compatibility. Not every OpenCV build
        # ships H.264 support; the isOpened() check below catches builds without it.
        fourcc = cv2.VideoWriter_fourcc(*"H264")
        temp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
        writer = cv2.VideoWriter(temp_file.name, fourcc, fps, (width, height))
        if not writer.isOpened():
            logger.error("Failed to initialize video writer")
            cap.release()
            temp_file.close()
            os.unlink(temp_file.name)
            return None, gr.Markdown(
                "**Error**: Failed to initialize video writer.", visible=True
            )

        frame_count = 0
        for _ in tqdm.tqdm(
            range(min(MAX_NUM_FRAMES, num_frames)), desc="Processing video"
        ):
            ok, frame = cap.read()
            if not ok:
                break
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # BGR to RGB
            pil_image = Image.fromarray(rgb_frame)
            detection, _ = detect_objects(
                pil_image, checkpoint, confidence_threshold, use_url=False, url=""
            )
            if detection is None:
                # Detection errors return None; skip the frame rather than crash.
                continue
            annotated_image, annotations = detection
            annotated_frame = annotate_frame(annotated_image, annotations)
            writer.write(annotated_frame)
            frame_count += 1

        writer.release()
        cap.release()

        if frame_count == 0:
            logger.warning("No valid frames processed in video")
            temp_file.close()
            os.unlink(temp_file.name)
            return None, gr.Markdown(
                "**Warning**: No valid frames processed. "
                "Try a different video or threshold.",
                visible=True,
            )

        temp_file.close()

        # Copy to a persistent directory so Gradio can serve the file
        output_filename = f"output_{os.path.basename(temp_file.name)}"
        output_path = VIDEO_OUTPUT_DIR / output_filename
        shutil.copy(temp_file.name, output_path)
        os.unlink(temp_file.name)  # Remove the temporary file

        logger.info(f"Video saved to {output_path}")
        return str(output_path), gr.Markdown(visible=False)

    except Exception as e:
        logger.error(f"Video processing failed: {str(e)}")
        if "temp_file" in locals():
            temp_file.close()
            if os.path.exists(temp_file.name):
                os.unlink(temp_file.name)
        return None, gr.Markdown(
            f"**Error**: Video processing failed: {str(e)}", visible=True
        )
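
# Sketch of invoking process_video from a script (assumes a local ./video.mp4);
# on success the returned path points into static/videos/, and the Markdown
# component carries any error text:
#
#   out_path, _ = process_video("./video.mp4", DEFAULT_CHECKPOINT, 0.5)
#   print(out_path)
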
def create_image_inputs() -> List[gr.components.Component]:
    return [
        gr.Image(
            label="Upload Image",
            type="pil",
            sources=["upload", "webcam"],
            interactive=True,
            elem_classes="input-component",
        ),
        gr.Checkbox(label="Use Image URL Instead", value=False),
        gr.Textbox(
            label="Image URL",
            placeholder="https://example.com/image.jpg",
            visible=False,
            elem_classes="input-component",
        ),
        gr.Dropdown(
            choices=CHECKPOINTS,
            label="Select Model Checkpoint",
            value=DEFAULT_CHECKPOINT,
            elem_classes="input-component",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=DEFAULT_CONFIDENCE_THRESHOLD,
            step=0.1,
            label="Confidence Threshold",
            elem_classes="input-component",
        ),
    ]


def create_video_inputs() -> List[gr.components.Component]:
    return [
        gr.Video(
            label="Upload Video",
            sources=["upload"],
            interactive=True,
            format="mp4",  # Ensure MP4 format
            elem_classes="input-component",
        ),
        gr.Dropdown(
            choices=CHECKPOINTS,
            label="Select Model Checkpoint",
            value=DEFAULT_CHECKPOINT,
            elem_classes="input-component",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=DEFAULT_CONFIDENCE_THRESHOLD,
            step=0.1,
            label="Confidence Threshold",
            elem_classes="input-component",
        ),
    ]


def create_button_row(is_image: bool) -> List[gr.Button]:
    prefix = "Image" if is_image else "Video"
    return [
        gr.Button(
            f"{prefix} Detect Objects", variant="primary", elem_classes="action-button"
        ),
        gr.Button(f"{prefix} Clear", variant="secondary", elem_classes="action-button"),
    ]
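
# Note: the factories above return components in a fixed positional order; the
# Blocks layout below unpacks them by position, so any new input component must
# be added consistently in both places.
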
# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# Real-Time Object Detection Demo

Experience state-of-the-art object detection with USTC's D-FINE models. Upload
an image or video, provide a URL, or try an example below. Select a model and
adjust the confidence threshold to see detections in real time!
        """,
        elem_classes="header-text",
    )

    with gr.Tabs():
        with gr.Tab("Image"):
            with gr.Row():
                with gr.Column(scale=1, min_width=300):
                    with gr.Group():
                        (
                            image_input,
                            use_url,
                            url_input,
                            image_checkpoint,
                            image_confidence_threshold,
                        ) = create_image_inputs()
                        image_detect_button, image_clear_button = create_button_row(
                            is_image=True
                        )
                with gr.Column(scale=2):
                    image_output = gr.AnnotatedImage(
                        label="Detection Results",
                        show_label=True,
                        color_map=None,
                        elem_classes="output-component",
                    )
                    image_error_message = gr.Markdown(
                        visible=False, elem_classes="error-text"
                    )
            # Example rows and inputs are ordered to match the detect_objects
            # signature: (image, checkpoint, confidence_threshold, use_url, url).
            gr.Examples(
                examples=[
                    [
                        example["path"],
                        DEFAULT_CHECKPOINT,
                        DEFAULT_CONFIDENCE_THRESHOLD,
                        example["use_url"],
                        example["url"],
                    ]
                    for example in IMAGE_EXAMPLES
                ],
                inputs=[
                    image_input,
                    image_checkpoint,
                    image_confidence_threshold,
                    use_url,
                    url_input,
                ],
                outputs=[image_output, image_error_message],
                fn=detect_objects,
                cache_examples=False,
                label="Select an image example to populate inputs",
            )

        with gr.Tab("Video"):
            gr.Markdown(
                f"The input video will be truncated to the first {MAX_NUM_FRAMES} frames."
            )
            with gr.Row():
                with gr.Column(scale=1, min_width=300):
                    with gr.Group():
                        video_input, video_checkpoint, video_confidence_threshold = (
                            create_video_inputs()
                        )
                        video_detect_button, video_clear_button = create_button_row(
                            is_image=False
                        )
                with gr.Column(scale=2):
                    video_output = gr.Video(
                        label="Detection Results",
                        format="mp4",  # Explicit MP4 format
                        elem_classes="output-component",
                    )
                    video_error_message = gr.Markdown(
                        visible=False, elem_classes="error-text"
                    )
            gr.Examples(
                examples=[
                    [example["path"], DEFAULT_CHECKPOINT, DEFAULT_CONFIDENCE_THRESHOLD]
                    for example in VIDEO_EXAMPLES
                ],
                inputs=[video_input, video_checkpoint, video_confidence_threshold],
                outputs=[video_output, video_error_message],
                fn=process_video,
                cache_examples=False,
                label="Select a video example to populate inputs",
            )

    # Dynamic visibility for the URL input
    use_url.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_url,
        outputs=url_input,
    )

    # Image clear button
    image_clear_button.click(
        fn=lambda: (
            None,
            False,
            "",
            DEFAULT_CHECKPOINT,
            DEFAULT_CONFIDENCE_THRESHOLD,
            None,
            gr.Markdown(visible=False),
        ),
        outputs=[
            image_input,
            use_url,
            url_input,
            image_checkpoint,
            image_confidence_threshold,
            image_output,
            image_error_message,
        ],
    )

    # Video clear button
    video_clear_button.click(
        fn=lambda: (
            None,
            DEFAULT_CHECKPOINT,
            DEFAULT_CONFIDENCE_THRESHOLD,
            None,
            gr.Markdown(visible=False),
        ),
        outputs=[
            video_input,
            video_checkpoint,
            video_confidence_threshold,
            video_output,
            video_error_message,
        ],
    )

    # Image detect button
    image_detect_button.click(
        fn=detect_objects,
        inputs=[
            image_input,
            image_checkpoint,
            image_confidence_threshold,
            use_url,
            url_input,
        ],
        outputs=[image_output, image_error_message],
    )

    # Video detect button
    video_detect_button.click(
        fn=process_video,
        inputs=[video_input, video_checkpoint, video_confidence_threshold],
        outputs=[video_output, video_error_message],
    )

if __name__ == "__main__":
    demo.queue(max_size=20).launch()
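
# To reach the app from outside localhost (e.g., inside a container), bind all
# interfaces explicitly -- a sketch, not required for local use:
#   demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860)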