import logging
import os
from typing import Tuple, List, Optional
from pathlib import Path
import shutil
import tempfile
import numpy as np
import cv2
import gradio as gr
from PIL import Image
from transformers import pipeline
from transformers.image_utils import load_image
import tqdm
# Configuration constants
CHECKPOINTS = [
"ustc-community/dfine_m_obj365",
"ustc-community/dfine_n_coco",
"ustc-community/dfine_s_coco",
"ustc-community/dfine_m_coco",
"ustc-community/dfine_l_coco",
"ustc-community/dfine_x_coco",
"ustc-community/dfine_s_obj365",
"ustc-community/dfine_l_obj365",
"ustc-community/dfine_x_obj365",
"ustc-community/dfine_s_obj2coco",
"ustc-community/dfine_m_obj2coco",
"ustc-community/dfine_l_obj2coco_e25",
"ustc-community/dfine_x_obj2coco",
]
MAX_NUM_FRAMES = 300
DEFAULT_CHECKPOINT = CHECKPOINTS[0]
DEFAULT_CONFIDENCE_THRESHOLD = 0.3
IMAGE_EXAMPLES = [
{"path": "./image.jpg", "use_url": False, "url": "", "label": "Local Image"},
{
"path": None,
"use_url": True,
"url": "https://live.staticflickr.com/65535/33021460783_1646d43c54_b.jpg",
"label": "Flickr Image",
},
]
VIDEO_EXAMPLES = [
{"path": "./video.mp4", "label": "Local Video"},
]
ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".avi", ".mov"}
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
VIDEO_OUTPUT_DIR = Path("static/videos")
VIDEO_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
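# Cache loaded pipelines per checkpoint so repeated calls (notably one call per
# video frame in process_video) reuse the model instead of re-initializing it
# every time.
_PIPELINES: dict = {}
def get_pipeline(checkpoint: str):
    if checkpoint not in _PIPELINES:
        _PIPELINES[checkpoint] = pipeline(
            "object-detection",
            model=checkpoint,
            image_processor=checkpoint,
            device="cpu",
        )
    return _PIPELINES[checkpoint]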
def detect_objects(
image: Optional[Image.Image],
checkpoint: str,
confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
use_url: bool = False,
url: str = "",
) -> Tuple[
Optional[Tuple[Image.Image, List[Tuple[Tuple[int, int, int, int], str]]]],
gr.Markdown,
]:
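    """Run detection on a PIL image (or URL) and return results in the format
    gr.AnnotatedImage expects: (image, [((xmin, ymin, xmax, ymax), label), ...]),
    plus a Markdown component used for error/warning messages."""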
if use_url and url:
try:
input_image = load_image(url)
except Exception as e:
logger.error(f"Failed to load image from URL {url}: {str(e)}")
return None, gr.Markdown(
f"**Error**: Failed to load image from URL: {str(e)}", visible=True
)
elif image is not None:
if not isinstance(image, Image.Image):
logger.error("Input image is not a PIL Image")
return None, gr.Markdown("**Error**: Invalid image format.", visible=True)
input_image = image
else:
return None, gr.Markdown(
"**Error**: Please provide an image or URL.", visible=True
)
    try:
        pipe = get_pipeline(checkpoint)
    except Exception as e:
        logger.error(f"Failed to initialize model pipeline for {checkpoint}: {str(e)}")
        return None, gr.Markdown(
            f"**Error**: Failed to load model: {str(e)}", visible=True
        )
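    # Each pipeline result is a dict of the form
    # {"score": float, "label": str, "box": {"xmin", "ymin", "xmax", "ymax"}}.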
results = pipe(input_image, threshold=confidence_threshold)
img_width, img_height = input_image.size
annotations = []
for result in results:
score = result["score"]
if score < confidence_threshold:
continue
label = f"{result['label']} ({score:.2f})"
box = result["box"]
# Validate and convert box to (xmin, ymin, xmax, ymax)
bbox_xmin = max(0, int(box["xmin"]))
bbox_ymin = max(0, int(box["ymin"]))
bbox_xmax = min(img_width, int(box["xmax"]))
bbox_ymax = min(img_height, int(box["ymax"]))
if bbox_xmax <= bbox_xmin or bbox_ymax <= bbox_ymin:
continue
bounding_box = (bbox_xmin, bbox_ymin, bbox_xmax, bbox_ymax)
annotations.append((bounding_box, label))
if not annotations:
return (input_image, []), gr.Markdown(
"**Warning**: No objects detected above the confidence threshold. Try lowering the threshold.",
visible=True,
)
return (input_image, annotations), gr.Markdown(visible=False)
def annotate_frame(
image: Image.Image, annotations: List[Tuple[Tuple[int, int, int, int], str]]
) -> np.ndarray:
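    """Draw white bounding boxes and black label text onto a BGR copy of the frame."""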
image_np = np.array(image)
image_bgr = image_np[:, :, ::-1].copy() # RGB to BGR
    for (xmin, ymin, xmax, ymax), label in annotations:
        cv2.rectangle(image_bgr, (xmin, ymin), (xmax, ymax), (255, 255, 255), 2)
        text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0]
        # Draw the label below the box's top edge when the box starts near y=0,
        # so the text background is not clipped off-frame.
        label_y = ymin if ymin - text_size[1] - 4 >= 0 else ymin + text_size[1] + 4
        cv2.rectangle(
            image_bgr,
            (xmin, label_y - text_size[1] - 4),
            (xmin + text_size[0], label_y),
            (255, 255, 255),
            -1,
        )
        cv2.putText(
            image_bgr,
            label,
            (xmin, label_y - 4),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (0, 0, 0),
            1,
        )
return image_bgr
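# OpenCV builds vary in codec support: "avc1"/"H264" (H.264) plays natively in
# browsers, while "mp4v" is more widely available as a fallback. This helper
# tries each fourcc in turn and returns the first writer that opens, or None.
def open_video_writer(
    path: str, fps: float, size: Tuple[int, int]
) -> Optional[cv2.VideoWriter]:
    for codec in ("avc1", "H264", "mp4v"):
        writer = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*codec), fps, size)
        if writer.isOpened():
            return writer
        writer.release()
    return None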
def process_video(
video_path: str,
checkpoint: str,
confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> Tuple[Optional[str], gr.Markdown]:
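    """Annotate up to MAX_NUM_FRAMES frames of the input video and return the
    path to the rendered MP4, plus a Markdown component for errors/warnings."""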
if not video_path or not os.path.isfile(video_path):
logger.error(f"Invalid video path: {video_path}")
return None, gr.Markdown(
"**Error**: Please provide a valid video file.", visible=True
)
ext = os.path.splitext(video_path)[1].lower()
if ext not in ALLOWED_VIDEO_EXTENSIONS:
logger.error(f"Unsupported video format: {ext}")
return None, gr.Markdown(
f"**Error**: Unsupported video format. Use MP4, AVI, or MOV.", visible=True
)
try:
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
logger.error(f"Failed to open video: {video_path}")
return None, gr.Markdown(
"**Error**: Failed to open video file.", visible=True
)
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps or fps <= 0:
            fps = 30.0  # Some containers report no FPS; fall back to a default.
        num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        temp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
        # Prefer H.264 for browser compatibility, with an mp4v fallback if this
        # OpenCV build lacks H.264 support (see open_video_writer above).
        writer = open_video_writer(temp_file.name, fps, (width, height))
        if writer is None:
            logger.error("Failed to initialize video writer")
            cap.release()
            temp_file.close()
            os.unlink(temp_file.name)
            return None, gr.Markdown(
                "**Error**: Failed to initialize video writer.", visible=True
            )
        frame_count = 0
        # CAP_PROP_FRAME_COUNT can be 0 or wrong for some containers; cap the
        # loop at MAX_NUM_FRAMES and rely on failed reads to stop early. The
        # tqdm bar is mirrored in the UI via gr.Progress(track_tqdm=True).
        total_frames = (
            min(MAX_NUM_FRAMES, num_frames) if num_frames > 0 else MAX_NUM_FRAMES
        )
        for _ in tqdm.tqdm(range(total_frames), desc="Processing video"):
ok, frame = cap.read()
if not ok:
break
rgb_frame = frame[:, :, ::-1] # BGR to RGB
pil_image = Image.fromarray(rgb_frame)
            detection_result, _ = detect_objects(
                pil_image, checkpoint, confidence_threshold, use_url=False, url=""
            )
            # detect_objects returns None (not a tuple) on failure, so check
            # before unpacking.
            if detection_result is None:
                continue
            annotated_image, annotations = detection_result
annotated_frame = annotate_frame(annotated_image, annotations)
writer.write(annotated_frame)
frame_count += 1
writer.release()
cap.release()
if frame_count == 0:
logger.warning("No valid frames processed in video")
temp_file.close()
os.unlink(temp_file.name)
return None, gr.Markdown(
"**Warning**: No valid frames processed. Try a different video or threshold.",
visible=True,
)
temp_file.close()
# Copy to persistent directory for Gradio access
output_filename = f"output_{os.path.basename(temp_file.name)}"
output_path = VIDEO_OUTPUT_DIR / output_filename
shutil.copy(temp_file.name, output_path)
os.unlink(temp_file.name) # Remove temporary file
logger.info(f"Video saved to {output_path}")
return str(output_path), gr.Markdown(visible=False)
    except Exception as e:
        logger.error(f"Video processing failed: {str(e)}")
        # Best-effort cleanup of whatever was opened before the failure.
        if "cap" in locals():
            cap.release()
        if "writer" in locals() and writer is not None:
            writer.release()
        if "temp_file" in locals():
            temp_file.close()
            if os.path.exists(temp_file.name):
                os.unlink(temp_file.name)
        return None, gr.Markdown(
            f"**Error**: Video processing failed: {str(e)}", visible=True
        )
def create_image_inputs() -> List[gr.components.Component]:
return [
gr.Image(
label="Upload Image",
type="pil",
sources=["upload", "webcam"],
interactive=True,
elem_classes="input-component",
),
gr.Checkbox(label="Use Image URL Instead", value=False),
gr.Textbox(
label="Image URL",
placeholder="https://example.com/image.jpg",
visible=False,
elem_classes="input-component",
),
gr.Dropdown(
choices=CHECKPOINTS,
label="Select Model Checkpoint",
value=DEFAULT_CHECKPOINT,
elem_classes="input-component",
),
gr.Slider(
minimum=0.1,
maximum=1.0,
value=DEFAULT_CONFIDENCE_THRESHOLD,
step=0.1,
label="Confidence Threshold",
elem_classes="input-component",
),
]
def create_video_inputs() -> List[gr.components.Component]:
return [
gr.Video(
label="Upload Video",
sources=["upload"],
interactive=True,
format="mp4", # Ensure MP4 format
elem_classes="input-component",
),
gr.Dropdown(
choices=CHECKPOINTS,
label="Select Model Checkpoint",
value=DEFAULT_CHECKPOINT,
elem_classes="input-component",
),
gr.Slider(
minimum=0.1,
maximum=1.0,
value=DEFAULT_CONFIDENCE_THRESHOLD,
step=0.1,
label="Confidence Threshold",
elem_classes="input-component",
),
]
def create_button_row(is_image: bool) -> List[gr.Button]:
prefix = "Image" if is_image else "Video"
return [
gr.Button(
f"{prefix} Detect Objects", variant="primary", elem_classes="action-button"
),
gr.Button(f"{prefix} Clear", variant="secondary", elem_classes="action-button"),
]
# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown(
"""
# Real-Time Object Detection Demo
        Experience state-of-the-art object detection with USTC's D-FINE models. Upload an image or video,
provide a URL, or try an example below. Select a model and adjust the confidence threshold to see detections in real time!
""",
elem_classes="header-text",
)
with gr.Tabs():
with gr.Tab("Image"):
with gr.Row():
with gr.Column(scale=1, min_width=300):
with gr.Group():
(
image_input,
use_url,
url_input,
image_checkpoint,
image_confidence_threshold,
) = create_image_inputs()
image_detect_button, image_clear_button = create_button_row(
is_image=True
)
with gr.Column(scale=2):
image_output = gr.AnnotatedImage(
label="Detection Results",
show_label=True,
color_map=None,
elem_classes="output-component",
)
image_error_message = gr.Markdown(
visible=False, elem_classes="error-text"
)
gr.Examples(
examples=[
[
example["path"],
example["use_url"],
example["url"],
DEFAULT_CHECKPOINT,
DEFAULT_CONFIDENCE_THRESHOLD,
]
for example in IMAGE_EXAMPLES
],
inputs=[
image_input,
use_url,
url_input,
image_checkpoint,
image_confidence_threshold,
],
outputs=[image_output, image_error_message],
fn=detect_objects,
cache_examples=False,
label="Select an image example to populate inputs",
)
with gr.Tab("Video"):
gr.Markdown(
f"The input video will be truncated to {MAX_NUM_FRAMES} frames."
)
with gr.Row():
with gr.Column(scale=1, min_width=300):
with gr.Group():
video_input, video_checkpoint, video_confidence_threshold = (
create_video_inputs()
)
video_detect_button, video_clear_button = create_button_row(
is_image=False
)
with gr.Column(scale=2):
video_output = gr.Video(
label="Detection Results",
format="mp4", # Explicit MP4 format
elem_classes="output-component",
)
video_error_message = gr.Markdown(
visible=False, elem_classes="error-text"
)
gr.Examples(
examples=[
[example["path"], DEFAULT_CHECKPOINT, DEFAULT_CONFIDENCE_THRESHOLD]
for example in VIDEO_EXAMPLES
],
inputs=[video_input, video_checkpoint, video_confidence_threshold],
outputs=[video_output, video_error_message],
fn=process_video,
cache_examples=False,
label="Select a video example to populate inputs",
)
# Dynamic visibility for URL input
use_url.change(
fn=lambda x: gr.update(visible=x),
inputs=use_url,
outputs=url_input,
)
# Image clear button
image_clear_button.click(
fn=lambda: (
None,
False,
"",
DEFAULT_CHECKPOINT,
DEFAULT_CONFIDENCE_THRESHOLD,
None,
gr.Markdown(visible=False),
),
outputs=[
image_input,
use_url,
url_input,
image_checkpoint,
image_confidence_threshold,
image_output,
image_error_message,
],
)
# Video clear button
video_clear_button.click(
fn=lambda: (
None,
DEFAULT_CHECKPOINT,
DEFAULT_CONFIDENCE_THRESHOLD,
None,
gr.Markdown(visible=False),
),
outputs=[
video_input,
video_checkpoint,
video_confidence_threshold,
video_output,
video_error_message,
],
)
# Image detect button
image_detect_button.click(
fn=detect_objects,
inputs=[
image_input,
image_checkpoint,
image_confidence_threshold,
use_url,
url_input,
],
outputs=[image_output, image_error_message],
)
# Video detect button
video_detect_button.click(
fn=process_video,
inputs=[video_input, video_checkpoint, video_confidence_threshold],
outputs=[video_output, video_error_message],
)
if __name__ == "__main__":
demo.queue(max_size=20).launch()