First update
- app.py +306 -80
- bridgetower_custom.py +183 -0
- requirements.txt +5 -1
app.py
CHANGED
@@ -1,90 +1,316 @@
import cv2
import gradio as gr
from PIL import Image
from transformers import BridgeTowerForImageAndTextRetrieval, BridgeTowerProcessor

            break

-        if frame_count % (fps * sample_rate) == 0:
-            frame = Image.fromarray(frame)
-            score = process_frame(frame, text)
-            # print(f"{frame_count} {scores}")
-
-            if float(score[text]) > min_score:
-                if clip_started:
-                    end_time = frame_count / fps
-                else:
-                    clip_started = True
-                    start_time = frame_count / fps
-                    end_time = start_time
-                    start_score = score[text]
-                clip_images.append(frame)
-            elif clip_started:
-                clip_started = False
-                end_time = frame_count / fps
-                clips.append((start_score, start_time, end_time))
-        frame_count += 1
-    return clip_images, clips
-
-
-# Inputs
-video = gr.Video(label="Video")
-text = gr.Text(label="Text query")
-sample_rate = gr.Number(value=5, label="Sample rate (1 frame every 'n' seconds)")
-min_score = gr.Number(value=3, label="Minimum score")
-
-# Output
-gallery = gr.Gallery(label="Images")
-clips = gr.Text(label="Clips (score, start time, end time)")

description = "This Space lets you run semantic search on a video."

-description
+# In[]:
+import sys
+import os
import cv2
import gradio as gr
from PIL import Image
+import numpy as np
+
+from torch.nn.utils.rnn import pad_sequence
from transformers import BridgeTowerForImageAndTextRetrieval, BridgeTowerProcessor

+from bridgetower_custom import BridgeTowerTextFeatureExtractor, BridgeTowerForITC
+
+import pickle
+from tqdm import tqdm
+from PIL import Image
+
+import torch
+import re
+import urllib.parse
+import faiss
+
+import webvtt
+import json
+
+from pytube import YouTube
+from youtube_transcript_api import YouTubeTranscriptApi
+from youtube_transcript_api.formatters import WebVTTFormatter
+
+device = 'cpu'
+model_name = 'BridgeTower/bridgetower-large-itm-mlm-itc'
+model = BridgeTowerForITC.from_pretrained(model_name).to(device)
+text_model = BridgeTowerTextFeatureExtractor.from_pretrained(model_name).to(device)
+
+processor = BridgeTowerProcessor.from_pretrained(model_name)
+
+
+def download_video(video_url, path='/tmp/'):
+
+    yt = YouTube(video_url)
+    yt = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first()
+    if not os.path.exists(path):
+        os.makedirs(path)
+    filepath = os.path.join(path, yt.default_filename)
+    if not os.path.exists(filepath):
+        print('Downloading video from YouTube...')
+        yt.download(path)
+    return filepath
+
+
+# Get transcript in webvtt
+def get_transcript_vtt(video_id, path='/tmp'):
+    filepath = os.path.join(path, 'test_vm.vtt')
+    if os.path.exists(filepath):
+        return filepath
+
+    transcript = YouTubeTranscriptApi.get_transcript(video_id)
+    formatter = WebVTTFormatter()
+    webvtt_formatted = formatter.format_transcript(transcript)
+
+    with open(filepath, 'w', encoding='utf-8') as webvtt_file:
+        webvtt_file.write(webvtt_formatted)
+
+    return filepath
+
+# https://stackoverflow.com/a/57781047
+# Resizes an image and maintains aspect ratio
+def maintain_aspect_ratio_resize(image, width=None, height=None, inter=cv2.INTER_AREA):
+    # Grab the image size and initialize dimensions
+    dim = None
+    (h, w) = image.shape[:2]
+
+    # Return original image if no need to resize
+    if width is None and height is None:
+        return image
+
+    # We are resizing height if width is none
+    if width is None:
+        # Calculate the ratio of the height and construct the dimensions
+        r = height / float(h)
+        dim = (int(w * r), height)
+    # We are resizing width if height is none
+    else:
+        # Calculate the ratio of the width and construct the dimensions
+        r = width / float(w)
+        dim = (width, int(h * r))
+
+    # Return the resized image
+    return cv2.resize(image, dim, interpolation=inter)
+
+def time_to_frame(time, fps):
+    '''
+    convert time in seconds into frame number
+    '''
+    return time * fps - 1
+
+def str2time(strtime):
+    strtime = strtime.strip('"')
+    hrs, mins, seconds = [float(c) for c in strtime.split(':')]
+
+    total_seconds = hrs * 60**2 + mins * 60 + seconds
+
+    return total_seconds
+
+def collate_fn(batch_list):
+    batch = {}
+    batch['input_ids'] = pad_sequence([encoding['input_ids'].squeeze(0) for encoding in batch_list], batch_first=True)
+    batch['attention_mask'] = pad_sequence([encoding['attention_mask'].squeeze(0) for encoding in batch_list], batch_first=True)
+    batch['pixel_values'] = torch.cat([encoding['pixel_values'] for encoding in batch_list], dim=0)
+    batch['pixel_mask'] = torch.cat([encoding['pixel_mask'] for encoding in batch_list], dim=0)
+    return batch
+
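collate_fn pads the tokenized captions to a common length and stacks the image tensors so several frame-caption pairs can go through the model in one call. A minimal sketch of the expected shapes, assuming processor is the BridgeTowerProcessor loaded above and img is any PIL image (the caption strings are made up):

enc_a = processor(img, "a short caption", return_tensors="pt")
enc_b = processor(img, "a much longer caption about the same frame", return_tensors="pt")
batch = collate_fn([enc_a, enc_b])
print(batch['input_ids'].shape)     # (2, max_len), the shorter caption zero-padded on the right
print(batch['pixel_values'].shape)  # (2, 3, H, W)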
+def extract_images_and_embeds(video_id, video_path, subtitles, output, expanded=False, batch_size=2):
+    if os.path.exists(os.path.join(output, 'embeddings.pkl')):
+        return
+
+    os.makedirs(output, exist_ok=True)
+    os.makedirs(os.path.join(output, 'frames'), exist_ok=True)
+    os.makedirs(os.path.join(output, 'frames_thumb'), exist_ok=True)
+
+    count = 0
+
+    vidcap = cv2.VideoCapture(video_path)
+
+    # Get the frames per second
+    fps = vidcap.get(cv2.CAP_PROP_FPS)
+
+    # Get the total number of frames in the video.
+    frame_count = vidcap.get(cv2.CAP_PROP_FRAME_COUNT)
+
+    print(fps, frame_count)
+
+    frame_number = 0
+
+    count = 0
+    anno = []
+
+    embeddings = []
+    batch_list = []
+
+    for idx, caption in enumerate(webvtt.read(subtitles)):
+        st_time = str2time(caption.start)
+        ed_time = str2time(caption.end)
+
+        mid_time = (ed_time + st_time) / 2
+        text = caption.text.replace('\n', ' ')
+
+        if expanded:
+            raise NotImplementedError
+
+        frame_no = time_to_frame(mid_time, fps)
+
+        print('Read a new frame: ', idx, mid_time, frame_no, text)
+        vidcap.set(1, frame_no)  # jump to the frame at the midpoint of the caption
+        success, image = vidcap.read()
+        if success:
+            img_fname = f'{video_id}_{idx:06d}'
+            img_fpath = os.path.join(output, 'frames', img_fname + '.jpg')
+            image = maintain_aspect_ratio_resize(image, height=350)  # resize frame, preserving aspect ratio
+            cv2.imwrite(img_fpath, image)  # save frame as JPEG file
+
+            count += 1
+            anno.append({
+                'image_id': idx,
+                'img_fname': img_fname,
+                'caption': text,
+                'time': mid_time,
+                'frame_no': frame_no
+            })
+
        else:
            break
+
+        encoding = processor(image, text, return_tensors="pt").to(device)
+        encoding['text'] = text
+        encoding['image_filepath'] = img_fpath
+        encoding['start_time'] = caption.start
+
+        batch_list.append(encoding)
+
+        if len(batch_list) == batch_size:
+            batch = collate_fn(batch_list)
+            with torch.no_grad():
+                outputs = model(**batch, output_hidden_states=True)
+
+            for i in range(batch_size):
+                embeddings.append({
+                    'embeddings': outputs.logits[i, 2, :].detach().cpu().numpy(),
+                    'text': batch_list[i]['text'],
+                    'image_filepath': batch_list[i]['image_filepath'],
+                    'start_time': batch_list[i]['start_time'],
+                })
+            batch_list = []
+
+    if batch_list:
+        batch = collate_fn(batch_list)
+        with torch.no_grad():
+            outputs = model(**batch, output_hidden_states=True)
+
+        for i in range(len(batch_list)):
+            embeddings.append({
+                'embeddings': outputs.logits[i, 2, :].detach().cpu().numpy(),
+                'text': batch_list[i]['text'],
+                'image_filepath': batch_list[i]['image_filepath'],
+                'start_time': batch_list[i]['start_time'],
+            })
+
+    with open(os.path.join(output, 'annotations.json'), 'w') as fh:
+        json.dump(anno, fh)
+
+    with open(os.path.join(output, 'embeddings.pkl'), 'wb') as fh:
+        pickle.dump(embeddings, fh)
+
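Each entry written to embeddings.pkl pairs one subtitle segment with its frame and its joint embedding: outputs.logits has shape (batch, 3, contrastive_hidden_size), stacking the text, image and cross-modal vectors, and index 2 selects the cross-modal one. A quick sanity check of the pickle afterwards (the cache path is illustrative):

import pickle

with open('cache/VIDEO_ID/embeddings.pkl', 'rb') as fh:  # hypothetical cache path
    embs = pickle.load(fh)

print(len(embs))                    # one entry per subtitle caption
print(embs[0]['embeddings'].shape)  # (contrastive_hidden_size,), L2-normalized
print(embs[0]['text'], embs[0]['start_time'], embs[0]['image_filepath'])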
+def run_query(video_id, text_query, path='/tmp'):
+
+    embeddings_filepath = os.path.join(path, 'embeddings.pkl')
+    faiss_filepath = os.path.join(path, 'faiss_index.pkl')
+
+    embeddings = pickle.load(open(embeddings_filepath, 'rb'))
+
+    if os.path.exists(faiss_filepath):
+        faiss_index = pickle.load(open(faiss_filepath, 'rb'))
+    else:
+        embs = [emb['embeddings'] for emb in embeddings]
+        vectors = np.stack(embs, axis=0)
+        num_vectors, vector_dim = vectors.shape
+        faiss_index = faiss.IndexFlatIP(vector_dim)
+        faiss_index.add(vectors)
+        pickle.dump(faiss_index, open(faiss_filepath, 'wb'))
+
+    print('Processing query')
+    encoding = processor.tokenizer(text_query, return_tensors="pt").to(device)
+    with torch.no_grad():
+        outputs = text_model(**encoding)
+    emb_query = outputs.cpu().numpy()
+    print('Running FAISS search')
+    _, I = faiss_index.search(emb_query, 6)
+
+    clip_images = [embeddings[idx]['image_filepath'] for idx in I[0]]
+    transcripts = [f"({embeddings[idx]['start_time']}) {embeddings[idx]['text']}" for idx in I[0]]
+    return clip_images, transcripts
+
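Since both the stored frame embeddings and the query embedding are L2-normalized, the inner-product search of IndexFlatIP ranks frames by cosine similarity. Once extract_images_and_embeds has filled the cache directory, run_query can also be called outside the UI; a sketch with a placeholder video id and query:

frames, transcripts = run_query('VIDEO_ID', 'a dog catching a frisbee', path='./cache/VIDEO_ID')
for frame_path, caption in zip(frames, transcripts):
    print(caption, '->', frame_path)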
+
+def get_video_id_from_url(video_url):
+    """
+    Examples:
+    - http://youtu.be/SA2iWivDJiE
+    - http://www.youtube.com/watch?v=_oPAwA_Udwc&feature=feedu
+    - http://www.youtube.com/embed/SA2iWivDJiE
+    - http://www.youtube.com/v/SA2iWivDJiE?version=3&hl=en_US
+    """
+    url = urllib.parse.urlparse(video_url)
+    if url.hostname == 'youtu.be':
+        return url.path[1:]
+    if url.hostname in ('www.youtube.com', 'youtube.com'):
+        if url.path == '/watch':
+            p = urllib.parse.parse_qs(url.query)
+            return p['v'][0]
+        if url.path[:7] == '/embed/':
+            return url.path.split('/')[2]
+        if url.path[:3] == '/v/':
+            return url.path.split('/')[2]
+
+    return None
+
+
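The URL forms listed in the docstring all reduce to the bare video id; anything else falls through to None:

get_video_id_from_url('http://youtu.be/SA2iWivDJiE')                               # 'SA2iWivDJiE'
get_video_id_from_url('http://www.youtube.com/watch?v=_oPAwA_Udwc&feature=feedu')  # '_oPAwA_Udwc'
get_video_id_from_url('http://www.youtube.com/embed/SA2iWivDJiE')                  # 'SA2iWivDJiE'
get_video_id_from_url('https://example.com/some-video')                            # None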
+def process(video_url, text_query):
+    tmp_dir = os.path.join(os.getcwd(), 'cache')
+    video_id = get_video_id_from_url(video_url)
+    output_dir = os.path.join(tmp_dir, video_id)
+    video_file = download_video(video_url, path=output_dir)
+    subtitles = get_transcript_vtt(video_id, path=output_dir)
+    extract_images_and_embeds(video_id=video_id,
+                              video_path=video_file,
+                              subtitles=subtitles,
+                              output=output_dir,
+                              expanded=False,
+                              batch_size=8,
+                              )
+    frame_paths, transcripts = run_query(video_id, text_query, path=output_dir)
+    return video_file, [(image, caption) for image, caption in zip(frame_paths, transcripts)]

description = "This Space lets you run semantic search on a video."

+with gr.Blocks() as demo:
+    gr.Markdown(description)
+    with gr.Row():
+        with gr.Column():
+            video_url = gr.Text(label="Youtube url")
+            text_query = gr.Text(label="Text query")
+            btn = gr.Button("Run query")
+        video_player = gr.Video(label="Video")
+
+    with gr.Row():
+        gallery = gr.Gallery(label="Images").style(grid=6)
+
+    gr.Examples(
+        examples=[
+            ['https://www.youtube.com/watch?v=CvjoXdC-WkM', 'wedding'],
+            ['https://www.youtube.com/watch?v=fWs2dWcNGu0', 'cheesecake on floor'],
+            ['https://www.youtube.com/watch?v=rmPpNsx4yAk', 'cat woman'],
+            ['https://www.youtube.com/watch?v=KCFYf4TJdN0', 'sandwich'],
+        ],
+        inputs=[video_url, text_query],
+    )
+
+    btn.click(fn=process,
+              inputs=[video_url, text_query],
+              outputs=[video_player, gallery],
+              )
+
+demo.launch(share=True, server_port=25566)
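The Blocks UI is only a thin front end over process(); the pipeline can also be driven directly, for example from a notebook, once the module is loaded. The URL/query pair below is the first example wired into the demo:

video_file, results = process('https://www.youtube.com/watch?v=CvjoXdC-WkM', 'wedding')
print(video_file)
for frame_path, caption in results:
    print(caption, frame_path)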
bridgetower_custom.py
ADDED
@@ -0,0 +1,183 @@
+from collections import OrderedDict
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from torchvision import transforms
+from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+
+from transformers.modeling_outputs import SequenceClassifierOutput
+
+from transformers import BridgeTowerPreTrainedModel, BridgeTowerModel
+from transformers.models.bridgetower.modeling_bridgetower import BridgeTowerTextModel
+
+class LayerNorm(nn.LayerNorm):
+    """Subclass torch's LayerNorm to handle fp16."""
+
+    def forward(self, x: torch.Tensor):
+        orig_type = x.dtype
+        ret = super().forward(x.type(torch.float32))
+        return ret.type(orig_type)
+
+class BridgeTowerImageFeatureExtractor(nn.Module):
+    def __init__(
+        self,
+        patch_size=14,
+        width=1024,
+        resolution_after=294,
+        ckpt_path=None,
+    ):
+        super().__init__()
+
+        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
+
+        scale = width ** -0.5
+        self.class_embedding = nn.Parameter(scale * torch.randn(width))
+        self.positional_embedding = nn.Parameter(scale * torch.randn((resolution_after // patch_size) ** 2 + 1, width))
+        self.ln_pre = LayerNorm(width)
+
+        if ckpt_path is not None:
+            sd = torch.load(ckpt_path)
+            if 'state_dict' in sd:
+                sd = sd["state_dict"]
+            print(f'Loading feature extractor checkpoint from {ckpt_path}')
+            self.load_state_dict(sd)
+
+    def forward(self, x: torch.Tensor):
+        x = self.conv1(x)  # shape = [*, width, grid, grid]
+        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+        t = self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
+        x = torch.cat([t, x], dim=1)  # shape = [*, grid ** 2 + 1, width]
+        x = x + self.positional_embedding.to(x.dtype)
+        x = self.ln_pre(x)
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        return x
+
+
+class BridgeTowerITCHead(nn.Module):
+    def __init__(self, hidden_size, embed_size):
+        super().__init__()
+        self.fc = nn.Linear(hidden_size, embed_size)
+
+    def forward(self, x):
+        x = self.fc(x)
+        return x
+
+
+class _BridgeTowerTextModelWrapper(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.text_model = BridgeTowerTextModel(config)
+
+    def forward(self, **kwargs):
+        return self.text_model(**kwargs)
+
+
+class BridgeTowerTextFeatureExtractor(BridgeTowerPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.bridgetower = _BridgeTowerTextModelWrapper(config.text_config)
+        self.itc_text_head = BridgeTowerITCHead(config.hidden_size, config.contrastive_hidden_size)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[torch.LongTensor] = None,
+    ):
+
+        outputs = self.bridgetower(input_ids=input_ids, attention_mask=attention_mask)
+        final_hidden_cls = outputs.last_hidden_state[:, 0, :]
+        final_hidden_cls = F.normalize(self.itc_text_head(final_hidden_cls), dim=-1, p=2)
+
+        return final_hidden_cls
+
+
+class BridgeTowerForITC(BridgeTowerPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.bridgetower = BridgeTowerModel(config)
+
+        self.itc_text_head = BridgeTowerITCHead(config.hidden_size, config.contrastive_hidden_size)
+        self.itc_image_head = BridgeTowerITCHead(config.hidden_size, config.contrastive_hidden_size)
+        self.itc_cross_modal_head = BridgeTowerITCHead(config.hidden_size * 2, config.contrastive_hidden_size)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        pixel_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        image_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[torch.LongTensor] = None,
+    ) -> Union[SequenceClassifierOutput, Tuple[torch.FloatTensor]]:
+
+        assert output_hidden_states, 'output_hidden_states should be set to True for BridgeTowerForITC'
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.bridgetower(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            pixel_values=pixel_values,
+            pixel_mask=pixel_mask,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            image_embeds=image_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        pooler_output = outputs.pooler_output if return_dict else outputs[2]
+
+        hidden_states_txt, hidden_states_img, hidden_states_cross_modal = outputs.hidden_states
+
+        final_hidden_txt = hidden_states_txt[-1]
+        final_hidden_img = hidden_states_img[-1]
+
+        image_embeds_with_ln = self.bridgetower.vision_model.visual.forward_post(final_hidden_img)
+        image_token_type_embeddings = self.bridgetower.token_type_embeddings(
+            torch.full((1,), 1, dtype=torch.long, device=self.bridgetower.token_type_embeddings.weight.device)
+        ).expand_as(image_embeds_with_ln)
+
+        final_hidden_img = (
+            self.bridgetower.cross_modal_image_transform(image_embeds_with_ln)
+            + image_token_type_embeddings
+        )
+
+        final_hidden_txt = F.normalize(self.itc_text_head(final_hidden_txt[:, 0, :]), dim=-1, p=2)
+        final_hidden_img = F.normalize(self.itc_image_head(final_hidden_img[:, 0, :]), dim=-1, p=2)
+        final_hidden_cross = F.normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2)
+
+        logits = torch.stack([final_hidden_txt, final_hidden_img, final_hidden_cross], dim=-2)
+
+        if not return_dict:
+            return tuple(logits)
+
+        return SequenceClassifierOutput(
+            loss=None,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
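A minimal sketch of how these classes are used, mirroring what app.py does (the frame path and captions are placeholders):

import torch
from PIL import Image
from transformers import BridgeTowerProcessor
from bridgetower_custom import BridgeTowerForITC, BridgeTowerTextFeatureExtractor

name = 'BridgeTower/bridgetower-large-itm-mlm-itc'
processor = BridgeTowerProcessor.from_pretrained(name)
model = BridgeTowerForITC.from_pretrained(name)
text_model = BridgeTowerTextFeatureExtractor.from_pretrained(name)

image = Image.open('frame.jpg')  # placeholder frame extracted from a video
enc = processor(image, 'a person on stage', return_tensors='pt')
with torch.no_grad():
    out = model(**enc, output_hidden_states=True)
cross_emb = out.logits[:, 2, :]  # (1, contrastive_hidden_size) cross-modal embedding

with torch.no_grad():
    query_emb = text_model(**processor.tokenizer('a person on stage', return_tensors='pt'))

# Both vectors are L2-normalized, so the dot product is a cosine similarity.
score = (cross_emb @ query_emb.T).item()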
requirements.txt
CHANGED
@@ -1,4 +1,8 @@
git+https://github.com/huggingface/transformers
torch
requests
-Pillow
+Pillow
+youtube-transcript-api
+faiss-cpu
+webvtt-py
+pytube