import os
import tempfile
import numpy as np
import gradio as gr
from moviepy import VideoFileClip
import torch
import clip
import cv2
from PIL import Image
from scenedetect import VideoManager, SceneManager
from scenedetect.detectors import ContentDetector, AdaptiveDetector, ThresholdDetector, HistogramDetector, HashDetector

# Device options
DEVICE_OPTIONS = {
    "cpu": "cpu",
    "cuda": "cuda" if torch.cuda.is_available() else "cpu",
    "mps": "mps" if torch.backends.mps.is_available() else "cpu"
}

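# Load the CLIP ViT-B/32 model and its image preprocessing transform on the requested device.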
def load_clip_model(device):
    return clip.load("ViT-B/32", device=device)

# --- Video Processing Functions ---
def extract_frames(video_path, fps=2):
    """Sample frames from the video at roughly `fps` frames per second."""
    cap = cv2.VideoCapture(video_path)
    frames = []
    video_fps = cap.get(cv2.CAP_PROP_FPS)
    # Guard against files that report an FPS of 0 to avoid a modulo-by-zero below.
    step = max(1, int(video_fps / fps)) if video_fps else 1
    count = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        if count % step == 0:
            frames.append(frame)
        count += 1
    cap.release()
    return frames

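# Encode each sampled frame with CLIP's image encoder; returns one feature vector per frame.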
def get_clip_features(frames, model, preprocess, device):
    features = []
    for frame in frames:
        img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        img_input = preprocess(img).unsqueeze(0).to(device)
        with torch.no_grad():
            feature = model.encode_image(img_input)
            features.append(feature.cpu().numpy()[0])
    return features

def compute_distance(a, b, method):
    # Note: "cosine" returns a similarity (higher = more alike); "l1"/"l2" return distances (lower = more alike).
    if method == "cosine":
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    elif method == "l2":
        return np.linalg.norm(a - b)
    elif method == "l1":
        return np.sum(np.abs(a - b))
    else:
        return np.linalg.norm(a - b)

def find_match(clip_feats, ref_feats, threshold=0.3, similarity="l2"):
    """Slide the clip's feature sequence across the reference features and return the
    start index of the best-scoring window, or -1 if it does not pass the threshold."""
    len_clip = len(clip_feats)
    best_match = -1
    # Distances (l1/l2): lower is better. Cosine similarity: higher is better.
    best_score = float('inf') if similarity != "cosine" else -float('inf')
    for i in range(len(ref_feats) - len_clip + 1):
        window = ref_feats[i:i + len_clip]
        dists = [compute_distance(a, b, similarity) for a, b in zip(clip_feats, window)]
        dist = np.mean(dists)
        if (similarity != "cosine" and dist < best_score) or (similarity == "cosine" and dist > best_score):
            best_score = dist
            best_match = i
    if (similarity != "cosine" and best_score < threshold) or (similarity == "cosine" and best_score > threshold):
        return best_match, best_score
    return -1, best_score

# Scene Detection
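# Map the detector name chosen in the UI to a PySceneDetect detector instance (defaults to ContentDetector).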
def get_detector(detector_name, threshold):
    if detector_name == "ContentDetector":
        return ContentDetector(threshold=threshold)
    elif detector_name == "AdaptiveDetector":
        return AdaptiveDetector()
    elif detector_name == "ThresholdDetector":
        return ThresholdDetector(threshold=threshold)
    elif detector_name == "HashDetector":
        return HashDetector(threshold=threshold)
    elif detector_name == "HistogramDetector":
        return HistogramDetector(threshold=threshold)
    else:
        return ContentDetector(threshold=threshold)

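# Run PySceneDetect over the video and return a list of (start_seconds, end_seconds) scene boundaries.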
def detect_scenes(video_path, detector_name, threshold):
    video_manager = VideoManager([video_path])
    scene_manager = SceneManager()
    detector = get_detector(detector_name, threshold)
    scene_manager.add_detector(detector)
    video_manager.set_downscale_factor()
    video_manager.start()
    scene_manager.detect_scenes(frame_source=video_manager)
    scene_list = scene_manager.get_scene_list()
    return [(scene[0].get_seconds(), scene[1].get_seconds()) for scene in scene_list]

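# Return the (start, end) scene that contains the matched timestamp, or None if none does.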
def find_scene_for_timestamp(scenes, match_time):
    for start, end in scenes:
        if start <= match_time <= end:
            return (start, end)
    return None

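# Cut the matched scene out of the source video with MoviePy and write it to output_path.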
def extract_scene(video_path, scene_range, output_path):
    start_time, end_time = scene_range
    clip = VideoFileClip(video_path).subclipped(start_time, end_time)
    clip.write_videofile(output_path, codec="libx264", audio_codec="aac")
    return output_path

# --- Main Pipeline ---

def process_videos(clip_path, ref_path, match_threshold, scene_threshold, detector_type, similarity_type, device_type, output_path):
    """Full pipeline: sample frames, embed both videos with CLIP, locate the clip in the
    reference, then detect scene boundaries and extract the scene containing the match."""
    device = DEVICE_OPTIONS.get(device_type, "cpu")
    model, preprocess = load_clip_model(device)

    clip_frames = extract_frames(clip_path)
    ref_frames = extract_frames(ref_path)

    clip_feats = get_clip_features(clip_frames, model, preprocess, device)
    ref_feats = get_clip_features(ref_frames, model, preprocess, device)

    match_index, score = find_match(clip_feats, ref_feats, match_threshold, similarity_type)

    if match_index == -1:
        return f"No match found (best score = {score:.4f})", None

    # Frames are sampled at 2 fps in extract_frames, so each match index corresponds to ~0.5 s.
    match_time = match_index * 0.5
    scenes = detect_scenes(ref_path, detector_type, scene_threshold)
    matched_scene = find_scene_for_timestamp(scenes, match_time)

    if not matched_scene:
        return "Match found, but no scene boundaries detected.", None
    output_path = os.path.join(output_path, "matched_scene.mp4")
    result_path = extract_scene(ref_path, matched_scene, output_path)

    return f"Match found at ~{match_time:.2f}s (score = {score:.4f})\nScene from {matched_scene[0]:.2f}s to {matched_scene[1]:.2f}s", result_path

# Gradio Interface
# Create a persistent temp directory for outputs. A `with tempfile.TemporaryDirectory()` block
# would delete the directory as soon as the block exits, i.e. before iface.launch() serves any
# request, so the matched scene could never be written there.
tmpdir = tempfile.mkdtemp()

iface = gr.Interface(
    fn=process_videos,
    inputs=[
        gr.Video(label="Clip Video"),
        gr.Video(label="Reference Video"),
        gr.Slider(0.1, 100.0, value=0.3, label="Matching Threshold (L1/L2: best score must fall below this; cosine: must exceed it)"),
        gr.Slider(0.01, 100, value=30, step=1, label="Scene Detection Threshold"),
        gr.Dropdown([
            "ContentDetector", "AdaptiveDetector", "ThresholdDetector", "HistogramDetector", "HashDetector"
        ], value="ContentDetector", label="Scene Detector Type"),
        gr.Dropdown(["l2", "l1", "cosine"], value="l2", label="Similarity Metric"),
        gr.Dropdown(["cpu", "cuda", "mps"], value="cpu", label="Processing Device"),
        # Hidden field that passes the output directory through to process_videos.
        gr.Text(value=tmpdir, visible=False)
    ],
    outputs=[
        gr.Text(label="Match Info"),
        gr.Video(label="Matched Scene")
    ],
    title="AI Video Clip Matcher",
    description="Upload a short video clip and a reference video. The system will try to find where the clip appears in the reference video and extract the full scene around it."
)

# --- Launch the App ---
if __name__ == "__main__":
    print("Launching Gradio interface...")

    # set `GRADIO_SERVER_NAME`, `GRADIO_SERVER_PORT` env vars to override default values
    # use `GRADIO_SERVER_NAME=0.0.0.0` for Docker
    iface.launch()