File size: 4,745 Bytes
53319ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec41aa1
53319ee
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
from transformers import pipeline
import cv2
import numpy as np
import tempfile
import os

# Build the object-detection pipeline once, at import time, so every request
# reuses the same DETR (ResNet-50 backbone) model instead of reloading it.
object_detector = pipeline(
    task="object-detection",
    model="facebook/detr-resnet-50",
)

def draw_bounding_boxes(frame, detections):
    """
    Return a new BGR frame with every detection's box and caption drawn on it.

    Args:
        frame: Image array in OpenCV's BGR channel order.
        detections: Iterable of dicts, each with a ``'box'`` dict holding
            ``xmin``/``ymin``/``xmax``/``ymax``, a ``'label'`` and a numeric
            ``'score'``.

    Returns:
        A new BGR array. Each detection gets a 3px red rectangle and a
        white-on-red "<label> <score>" caption anchored at the box's top-left
        corner.
    """
    # Pillow draws in RGB while OpenCV frames are BGR, so convert on the way in.
    canvas = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    pen = ImageDraw.Draw(canvas)
    font = ImageFont.load_default()

    for det in detections:
        coords = det['box']
        left, top, right, bottom = (
            int(coords[key]) for key in ('xmin', 'ymin', 'xmax', 'ymax')
        )
        pen.rectangle([(left, top), (right, bottom)], outline="red", width=3)

        caption = f"{det['label']} {det['score']:.2f}"
        anchor = (left, top)
        # Paint the caption's own bounding box first so the text stays legible
        # regardless of what is underneath it.
        x0, y0, x1, y1 = pen.textbbox(anchor, caption, font=font)
        pen.rectangle([(x0, y0), (x1, y1)], fill="red")
        pen.text(anchor, caption, fill="white", font=font)

    # Back to OpenCV's BGR order for the video writer.
    return cv2.cvtColor(np.array(canvas), cv2.COLOR_RGB2BGR)

def process_video(video_path, process_every_n_frames=2):
    """
    Run object detection over a video and write an annotated copy to a temp file.

    The detector runs on one frame out of every ``process_every_n_frames``,
    starting with the first frame, to save time. The most recent detections
    are re-drawn on the frames in between so boxes do not flicker on and off.

    Args:
        video_path: Path of the input video, as accepted by ``cv2.VideoCapture``.
        process_every_n_frames: Run the detector once per this many frames.
            Values below 1 are treated as 1 (detect on every frame).

    Returns:
        Path to a new temporary ``.mp4`` file holding the annotated video.
        Returns None if the video cannot be opened, contains no readable
        frames, or an error occurs; in that case the error is printed and any
        partial output file is deleted. On success the temp file is not
        deleted here; the caller owns it.
    """
    step = max(1, int(process_every_n_frames))
    cap = None
    writer = None
    output_path = None
    succeeded = False
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            return None

        # Keep the fractional rate (e.g. 29.97): truncating it to int makes the
        # output play back at the wrong speed. `not fps > 0` also catches the
        # 0 / NaN some containers report.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 30.0
        # May be 0 or negative when the container does not know its length,
        # so it is only used for the progress display, never as a divisor blindly.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Reserve a unique output path; the writer reopens it by name.
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_output:
            output_path = temp_output.name

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        detections = []
        frame_index = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            if writer is None:
                # Size the writer from a real decoded frame rather than container
                # metadata; OpenCV may silently skip frames whose size differs
                # from the size the writer was opened with.
                height, width = frame.shape[:2]
                writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
                if not writer.isOpened():
                    raise RuntimeError(f"cannot open video writer for {output_path}")

            if frame_index % step == 0:
                # The transformers pipeline accepts a PIL image (or a path/URL),
                # not a raw ndarray, so wrap the RGB frame before calling it.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                detections = object_detector(Image.fromarray(frame_rgb))

            # Re-draw the latest detections on skipped frames to avoid flicker.
            if detections:
                frame = draw_bounding_boxes(frame, detections)
            writer.write(frame)

            frame_index += 1
            if total_frames > 0:
                progress = (frame_index / total_frames) * 100
                print(f"Processing: {progress:.1f}% complete", end='\r')

        # A video with no decodable frames never opened the writer.
        succeeded = writer is not None
        return output_path if succeeded else None

    except Exception as e:
        print(f"Error processing video: {str(e)}")
        return None

    finally:
        # Always release OpenCV handles, even when an exception was raised.
        if cap is not None:
            cap.release()
        if writer is not None:
            writer.release()
        if not succeeded and output_path is not None:
            # Best-effort cleanup of an empty or half-written temp file.
            try:
                os.remove(output_path)
            except OSError:
                pass

def detect_objects_in_video(video):
    """
    Gradio callback: annotate an uploaded video with object detections.

    Args:
        video: Value from the ``gr.Video`` input, handed straight to
            ``process_video`` as the video's file path; None when nothing
            was uploaded.

    Returns:
        Whatever ``process_video`` returns (the annotated video's path, or
        None on failure). Returns None when there is no input or when
        ``process_video`` raises; the exception is printed, not propagated.
    """
    if video is not None:
        try:
            return process_video(video)
        except Exception as e:
            print(f"Error during video processing: {str(e)}")
    return None

# Create the Gradio interface: one uploaded video in, one annotated video out.
demo = gr.Interface(
    fn=detect_objects_in_video,
    inputs=[
        gr.Video(label="Upload Video")
    ],
    outputs=[
        gr.Video(label="Processed Video")
    ],
    title="Video Object Detection",
    # The app runs a per-frame detector only; there is no cross-frame tracking,
    # so the description must not promise it.
    description="""
    Upload a video to detect objects in it.
    The application runs an object detector over the video and draws bounding boxes
    around detected objects with their labels and confidence scores.
    Note: objects are detected frame by frame (they are not tracked across frames),
    and processing may take some time depending on the video length.
    """
)

if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    demo.launch()