japanese-denim commited on
Commit
81c58dc
·
1 Parent(s): f53f018

uploaded required files

Browse files
Files changed (3) hide show
  1. app.py +100 -0
  2. requirements.txt +0 -0
  3. test.py +15 -0
app.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ from PIL import Image
3
+ import math
4
+ import clip
5
+ import torch
6
+ from flask import Flask, request, jsonify
7
+
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+ # load OpenAI CLIP model
10
+ model, preprocess = clip.load("ViT-B/32", device=device)
11
+ app = Flask(__name__)
12
+
13
+ # define the API endpoint
14
+ @app.route('/search_videos', methods=['POST'])
15
+ def search_videos():
16
+ # get the search item from the request parameters
17
+ search_item = request.form.get('search_item')
18
+
19
+ # get the video files from the request parameters
20
+ video_files = request.files.getlist('video_files')
21
+
22
+ # store the matching videos
23
+ matching_videos = []
24
+
25
+ # loop through all videos
26
+ for video_file in video_files:
27
+ # no. of frames to skip
28
+ n = 120
29
+
30
+ # store the video frames
31
+ video_frames = []
32
+
33
+ # open the video
34
+ capture = cv2.VideoCapture(video_file)
35
+ fps = capture.get(cv2.CAP_PROP_FPS)
36
+
37
+ current_frame = 0
38
+ # read the current frame
39
+ ret, frame = capture.read()
40
+ while capture.isOpened() and ret:
41
+ ret,frame = capture.read()
42
+
43
+ if ret:
44
+ video_frames.append(Image.fromarray(frame[:, :, ::-1]))
45
+
46
+ # skip n frames
47
+ current_frame += n
48
+ capture.set(cv2.CAP_PROP_POS_FRAMES, current_frame)
49
+
50
+ # ENCODE THE FRAMES
51
+ batch_size = 256
52
+ batches = math.ceil(len(video_frames) / batch_size)
53
+
54
+ # store the encoded features
55
+ video_features = torch.empty([0, 512], dtype=torch.float16).to(device)
56
+
57
+ # process each batch
58
+ for i in range(batches):
59
+ # get the relevant frames
60
+ batch_frames = video_frames[i*batch_size : (i+1)*batch_size]
61
+
62
+ # preprocess the frames for the batch
63
+ batch_preprocessed = torch.stack([preprocess(frame) for frame in batch_frames]).to(device)
64
+
65
+ # encode with CLIP and normalize
66
+ with torch.no_grad():
67
+ batch_features = model.encode_image(batch_preprocessed)
68
+ batch_features /= batch_features.norm(dim=-1, keepdim=True)
69
+
70
+ # append the batch to the list containing all features
71
+ video_features = torch.cat((video_features, batch_features))
72
+
73
+ # determine if video contains the search item
74
+ if contain_search_item(video_frames, video_features, search_item):
75
+ matching_videos.append(video_file)
76
+ # break
77
+
78
+ # return the list of matching videos
79
+ return matching_videos
80
+
81
+
82
+ def contain_search_item(video_frames, video_features, search_query):
83
+ # encode and normalize the search query using CLIP
84
+ with torch.no_grad():
85
+ text_features = model.encode_text(clip.tokenize(search_query).to(device))
86
+ text_features /= text_features.norm(dim=-1, keepdim=True)
87
+
88
+ # compute the similarity between the search query and each frame
89
+ similarities = (100.0 * video_features @ text_features.T)
90
+ values, best_photo_idx = similarities.topk(1, dim=0)
91
+
92
+ for frame_id in best_photo_idx:
93
+ frame = video_frames[frame_id]
94
+
95
+ return frame
96
+
97
+
98
+ if __name__ == '__main__':
99
+ app.run(debug=True)
100
+
requirements.txt ADDED
Binary file (134 Bytes). View file
 
test.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+
3
+ # set the search item
4
+ search_item = 'a duck'
5
+ # provide a list of videos
6
+ video_files = ["video.mp4", "video1.mp4"]
7
+
8
+ # send a POST request to the API endpoint
9
+ response = requests.post('http://localhost:5000/search_videos',
10
+ data={'search_item': search_item},
11
+ files=[("videos", open(video, "rb")) for video in video_files])
12
+
13
+
14
+ # print the list of matching videos returned by the API
15
+ print(response.json())