Spaces:
Build error
Build error
japanese-denim
committed on
Commit
·
81c58dc
1
Parent(s):
f53f018
uploaded required files
Browse files- app.py +100 -0
- requirements.txt +0 -0
- test.py +15 -0
app.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
from PIL import Image
|
3 |
+
import math
|
4 |
+
import clip
|
5 |
+
import torch
|
6 |
+
from flask import Flask, request, jsonify
|
7 |
+
|
8 |
+
# Run inference on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# OpenAI CLIP (ViT-B/32) together with its matching image preprocessing transform.
model, preprocess = clip.load("ViT-B/32", device=device)

# Flask application object that the route decorators register against.
app = Flask(__name__)
|
12 |
+
|
13 |
+
# define the API endpoint
|
14 |
+
def _sample_frames(video_path, frame_step):
    """Decode every ``frame_step``-th frame of ``video_path`` as an RGB PIL image.

    Sampling starts at frame 0. An unreadable file yields an empty list.
    """
    capture = cv2.VideoCapture(video_path)
    frames = []
    try:
        position = 0
        while capture.isOpened():
            # Seek first, then read, so frame 0 is included in the sample.
            capture.set(cv2.CAP_PROP_POS_FRAMES, position)
            ret, frame = capture.read()
            if not ret:
                break
            # OpenCV decodes to BGR; reverse the channel axis to get RGB.
            frames.append(Image.fromarray(frame[:, :, ::-1]))
            position += frame_step
    finally:
        # Always free the decoder, even if frame conversion raises.
        capture.release()
    return frames


def _encode_frames(frames, batch_size):
    """Return L2-normalised CLIP image features for a non-empty list of frames."""
    encoded = []
    for start in range(0, len(frames), batch_size):
        batch = torch.stack(
            [preprocess(frame) for frame in frames[start:start + batch_size]]
        ).to(device)
        with torch.no_grad():
            features = model.encode_image(batch)
            features /= features.norm(dim=-1, keepdim=True)
        encoded.append(features)
    # Concatenate once; the result keeps whatever dtype the model produced.
    return torch.cat(encoded)


# define the API endpoint
@app.route('/search_videos', methods=['POST'])
def search_videos():
    """Return the filenames of uploaded videos that match a text query.

    Expects a multipart POST with a ``search_item`` form field and zero or
    more uploads under the ``video_files`` field.

    Returns:
        A JSON list of the filenames for which ``contain_search_item``
        returned a truthy value, or a JSON error object with status 400 when
        ``search_item`` is missing or empty.
    """
    import os
    import tempfile

    # get the search item from the request parameters
    search_item = request.form.get('search_item')
    if not search_item:
        return jsonify({'error': "form field 'search_item' is required"}), 400

    # get the video files from the request parameters
    video_files = request.files.getlist('video_files')

    frame_step = 120   # sample one frame out of every 120
    batch_size = 256   # frames encoded per CLIP forward pass

    # filenames of the matching videos
    matching_videos = []

    for video_file in video_files:
        # cv2.VideoCapture opens a path, not an upload object, so spool the
        # upload to a temporary file first and delete it once decoded.
        suffix = os.path.splitext(video_file.filename or '')[1]
        fd, temp_path = tempfile.mkstemp(suffix=suffix)
        os.close(fd)
        try:
            video_file.save(temp_path)
            video_frames = _sample_frames(temp_path, frame_step)
        finally:
            os.remove(temp_path)

        # Nothing could be decoded, so there is nothing to compare against.
        if not video_frames:
            continue

        video_features = _encode_frames(video_frames, batch_size)

        # determine if video contains the search item
        if contain_search_item(video_frames, video_features, search_item):
            # Report the name, not the upload object, so the response is
            # JSON-serialisable.
            matching_videos.append(video_file.filename)

    # return the list of matching videos
    return jsonify(matching_videos)
|
80 |
+
|
81 |
+
|
82 |
+
def contain_search_item(video_frames, video_features, search_query):
    """Return the frame most similar to ``search_query``, or ``None`` if there are no frames.

    Args:
        video_frames: Sequence of frames, index-aligned with the rows of
            ``video_features``.
        video_features: Tensor with one row of normalised CLIP image features
            per frame.
        search_query: Text to look for; passed to ``clip.tokenize``.

    Returns:
        The element of ``video_frames`` whose features score highest against
        the query, or ``None`` when ``video_frames`` or ``video_features`` is
        empty.

    NOTE(review): for non-empty input this always returns a frame, so callers
    that use the result as a boolean will treat every decodable video as a
    match. A minimum-similarity cut-off is probably intended -- TODO confirm.
    """
    # topk(1) over zero rows raises, so bail out before touching the model.
    if not video_frames or len(video_features) == 0:
        return None

    # encode and normalize the search query using CLIP
    with torch.no_grad():
        text_features = model.encode_text(clip.tokenize(search_query).to(device))
        text_features /= text_features.norm(dim=-1, keepdim=True)

    # compute the similarity between the search query and each frame
    similarities = 100.0 * video_features @ text_features.T
    _, best_frame_idx = similarities.topk(1, dim=0)

    # Convert to a plain int so the list lookup does not rely on tensor indexing.
    return video_frames[int(best_frame_idx[0, 0])]
|
96 |
+
|
97 |
+
|
98 |
+
if __name__ == '__main__':
    import os

    # Werkzeug's interactive debugger lets anyone who can reach the server run
    # arbitrary code, so it is opt-in via FLASK_DEBUG rather than always on.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in {'1', 'true', 'yes', 'on'}
    app.run(debug=debug_enabled)
|
100 |
+
|
requirements.txt
ADDED
Binary file (134 Bytes). View file
|
|
test.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from contextlib import ExitStack

import requests

# set the search item
search_item = 'a duck'
# provide a list of videos
video_files = ["video.mp4", "video1.mp4"]

# send a POST request to the API endpoint
# ExitStack closes every opened video, even if a later open() or the request fails.
with ExitStack() as stack:
    # The server reads uploads from the 'video_files' form field.
    uploads = [("video_files", stack.enter_context(open(video, "rb")))
               for video in video_files]
    response = requests.post('http://localhost:5000/search_videos',
                             data={'search_item': search_item},
                             files=uploads)

# Fail on an HTTP error status instead of trying to JSON-decode an error page.
response.raise_for_status()

# print the list of matching videos returned by the API
print(response.json())
|