shaoyent committed
Commit a1ebdce
1 Parent(s): 54597df

First update

Files changed (3):
  1. app.py +306 -80
  2. bridgetower_custom.py +183 -0
  3. requirements.txt +5 -1
app.py CHANGED
@@ -1,90 +1,316 @@
 import cv2
 import gradio as gr
 from PIL import Image
 from transformers import BridgeTowerForImageAndTextRetrieval, BridgeTowerProcessor

-model_id = "BridgeTower/bridgetower-large-itm-mlm-gaudi"
-processor = BridgeTowerProcessor.from_pretrained(model_id)
-model = BridgeTowerForImageAndTextRetrieval.from_pretrained(model_id)
-
-# Process a frame
-def process_frame(image, texts):
-    scores = {}
-    texts = texts.split(",")
-    for t in texts:
-        encoding = processor(image, t, return_tensors="pt")
-        outputs = model(**encoding)
-        scores[t] = "{:.2f}".format(outputs.logits[0, 1].item())
-    # sort scores in descending order
-    scores = dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))
-    return scores
-
-
-# Process a video
-def process(video, text, sample_rate, min_score):
-    video = cv2.VideoCapture(video)
-    fps = round(video.get(cv2.CAP_PROP_FPS))
-    frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
-    length = frames // fps
-    print(f"{fps} fps, {frames} frames, {length} seconds")
-
-    frame_count = 0
-    clips = []
-    clip_images = []
-    clip_started = False
-    while True:
-        ret, frame = video.read()
-        if not ret:
             break

-        if frame_count % (fps * sample_rate) == 0:
-            frame = Image.fromarray(frame)
-            score = process_frame(frame, text)
-            # print(f"{frame_count} {scores}")
-
-            if float(score[text]) > min_score:
-                if clip_started:
-                    end_time = frame_count / fps
-                else:
-                    clip_started = True
-                    start_time = frame_count / fps
-                    end_time = start_time
-                    start_score = score[text]
-                clip_images.append(frame)
-            elif clip_started:
-                clip_started = False
-                end_time = frame_count / fps
-                clips.append((start_score, start_time, end_time))
-        frame_count += 1
-    return clip_images, clips
-
-
-# Inputs
-video = gr.Video(label="Video")
-text = gr.Text(label="Text query")
-sample_rate = gr.Number(value=5, label="Sample rate (1 frame every 'n' seconds)")
-min_score = gr.Number(value=3, label="Minimum score")
-
-# Output
-gallery = gr.Gallery(label="Images")
-clips = gr.Text(label="Clips (score, start time, end time)")

 description = "This Space lets you run semantic search on a video."

-iface = gr.Interface(
-    description=description,
-    fn=process,
-    inputs=[video, text, sample_rate, min_score],
-    outputs=[gallery, clips],
-    examples=[
-        [
-            "video.mp4",
-            "wild bears",
-            5,
-            3,
-        ]
-    ],
-    allow_flagging="never",
-)
-
-iface.launch()
+# In[]:
+import sys
+import os
 import cv2
 import gradio as gr
 from PIL import Image
+import numpy as np
+
+from torch.nn.utils.rnn import pad_sequence
 from transformers import BridgeTowerForImageAndTextRetrieval, BridgeTowerProcessor

+from bridgetower_custom import BridgeTowerTextFeatureExtractor, BridgeTowerForITC
+
+import pickle
+from tqdm import tqdm
+from PIL import Image
+
+import torch
+import re
+import urllib.parse
+import faiss
+
+import webvtt
+import json
+
+from pytube import YouTube
+from youtube_transcript_api import YouTubeTranscriptApi
+from youtube_transcript_api.formatters import WebVTTFormatter
+
+device = 'cpu'
+model_name = 'BridgeTower/bridgetower-large-itm-mlm-itc'
+model = BridgeTowerForITC.from_pretrained(model_name).to(device)
+text_model = BridgeTowerTextFeatureExtractor.from_pretrained(model_name).to(device)
+
+processor = BridgeTowerProcessor.from_pretrained(model_name)
+
+
+def download_video(video_url, path='/tmp/'):
+
+    yt = YouTube(video_url)
+    yt = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first()
+    if not os.path.exists(path):
+        os.makedirs(path)
+    filepath = os.path.join(path, yt.default_filename)
+    if not os.path.exists(filepath):
+        print('Downloading video from YouTube...')
+        yt.download(path)
+    return filepath
+
+
+# Get transcript in WebVTT format
+def get_transcript_vtt(video_id, path='/tmp'):
+    filepath = os.path.join(path, 'test_vm.vtt')
+    if os.path.exists(filepath):
+        return filepath
+
+    transcript = YouTubeTranscriptApi.get_transcript(video_id)
+    formatter = WebVTTFormatter()
+    webvtt_formatted = formatter.format_transcript(transcript)
+
+    with open(filepath, 'w', encoding='utf-8') as webvtt_file:
+        webvtt_file.write(webvtt_formatted)
+        webvtt_file.close()
+
+    return filepath
+
+# https://stackoverflow.com/a/57781047
+# Resizes an image and maintains the aspect ratio
+def maintain_aspect_ratio_resize(image, width=None, height=None, inter=cv2.INTER_AREA):
+    # Grab the image size and initialize dimensions
+    dim = None
+    (h, w) = image.shape[:2]
+
+    # Return original image if no need to resize
+    if width is None and height is None:
+        return image
+
+    # We are resizing height if width is None
+    if width is None:
+        # Calculate the ratio of the height and construct the dimensions
+        r = height / float(h)
+        dim = (int(w * r), height)
+    # We are resizing width if height is None
+    else:
+        # Calculate the ratio of the width and construct the dimensions
+        r = width / float(w)
+        dim = (width, int(h * r))
+
+    # Return the resized image
+    return cv2.resize(image, dim, interpolation=inter)
+
+def time_to_frame(time, fps):
+    '''
+    Convert a time in seconds into a frame number.
+    '''
+    return time * fps - 1
+
+def str2time(strtime):
+    strtime = strtime.strip('"')
+    hrs, mins, seconds = [float(c) for c in strtime.split(':')]
+
+    total_seconds = hrs * 60**2 + mins * 60 + seconds
+
+    return total_seconds
+
+def collate_fn(batch_list):
+    batch = {}
+    batch['input_ids'] = pad_sequence([encoding['input_ids'].squeeze(0) for encoding in batch_list], batch_first=True)
+    batch['attention_mask'] = pad_sequence([encoding['attention_mask'].squeeze(0) for encoding in batch_list], batch_first=True)
+    batch['pixel_values'] = torch.cat([encoding['pixel_values'] for encoding in batch_list], dim=0)
+    batch['pixel_mask'] = torch.cat([encoding['pixel_mask'] for encoding in batch_list], dim=0)
+    return batch
+
+def extract_images_and_embeds(video_id, video_path, subtitles, output, expanded=False, batch_size=2):
+    if os.path.exists(os.path.join(output, 'embeddings.pkl')):
+        return
+
+    os.makedirs(output, exist_ok=True)
+    os.makedirs(os.path.join(output, 'frames'), exist_ok=True)
+    os.makedirs(os.path.join(output, 'frames_thumb'), exist_ok=True)
+
+    count = 0
+
+    vidcap = cv2.VideoCapture(video_path)
+
+    # Get the frames per second
+    fps = vidcap.get(cv2.CAP_PROP_FPS)
+
+    # Get the total number of frames in the video.
+    frame_count = vidcap.get(cv2.CAP_PROP_FRAME_COUNT)
+
+    print(fps, frame_count)
+
+    frame_number = 0
+
+    count = 0
+    anno = []
+
+    embeddings = []
+    batch_list = []
+
+    for idx, caption in enumerate(webvtt.read(subtitles)):
+        st_time = str2time(caption.start)
+        ed_time = str2time(caption.end)
+
+        mid_time = (ed_time + st_time) / 2
+        text = caption.text.replace('\n', ' ')
+
+        if expanded:
+            raise NotImplementedError
+
+        frame_no = time_to_frame(mid_time, fps)
+
+        print('Read a new frame: ', idx, mid_time, frame_no, text)
+        vidcap.set(1, frame_no)  # added this line
+        success, image = vidcap.read()
+        if success:
+            img_fname = f'{video_id}_{idx:06d}'
+            img_fpath = os.path.join(output, 'frames', img_fname + '.jpg')
+            image = maintain_aspect_ratio_resize(image, height=350)  # resize frame
+            cv2.imwrite(img_fpath, image)  # save frame as JPEG file
+
+            count += 1
+            anno.append({
+                'image_id': idx,
+                'img_fname': img_fname,
+                'caption': text,
+                'time': mid_time,
+                'frame_no': frame_no
+            })
+
+        else:
             break
+
+        encoding = processor(image, text, return_tensors="pt").to(device)
+        encoding['text'] = text
+        encoding['image_filepath'] = img_fpath
+        encoding['start_time'] = caption.start
+
+        batch_list.append(encoding)
+
+        if len(batch_list) == batch_size:
+            batch = collate_fn(batch_list)
+            with torch.no_grad():
+                outputs = model(**batch, output_hidden_states=True)
+
+            for i in range(batch_size):
+                embeddings.append({
+                    'embeddings': outputs.logits[i, 2, :].detach().cpu().numpy(),
+                    'text': batch_list[i]['text'],
+                    'image_filepath': batch_list[i]['image_filepath'],
+                    'start_time': batch_list[i]['start_time'],
+                })
+            batch_list = []
+
+    if batch_list:
+        batch = collate_fn(batch_list)
+        with torch.no_grad():
+            outputs = model(**batch, output_hidden_states=True)
+
+        for i in range(len(batch_list)):
+            embeddings.append({
+                'embeddings': outputs.logits[i, 2, :].detach().cpu().numpy(),
+                'text': batch_list[i]['text'],
+                'image_filepath': batch_list[i]['image_filepath'],
+                'start_time': batch_list[i]['start_time'],
+            })
+
+    with open(os.path.join(output, 'annotations.json'), 'w') as fh:
+        json.dump(anno, fh)
+
+    with open(os.path.join(output, 'embeddings.pkl'), 'wb') as fh:
+        pickle.dump(embeddings, fh)
+
+def run_query(video_id, text_query, path='/tmp'):
+
+    embeddings_filepath = os.path.join(path, 'embeddings.pkl')
+    faiss_filepath = os.path.join(path, 'faiss_index.pkl')
+
+    embeddings = pickle.load(open(embeddings_filepath, 'rb'))
+
+    if os.path.exists(faiss_filepath):
+        faiss_index = pickle.load(open(faiss_filepath, 'rb'))
+    else:
+        embs = [emb['embeddings'] for emb in embeddings]
+        vectors = np.stack(embs, axis=0)
+        num_vectors, vector_dim = vectors.shape
+        faiss_index = faiss.IndexFlatIP(vector_dim)
+        faiss_index.add(vectors)
+        pickle.dump(faiss_index, open(faiss_filepath, 'wb'))
+
+    print('Processing query')
+    encoding = processor.tokenizer(text_query, return_tensors="pt").to(device)
+    with torch.no_grad():
+        outputs = text_model(**encoding)
+    emb_query = outputs.cpu().numpy()
+    print('Running FAISS search')
+    _, I = faiss_index.search(emb_query, 6)
+
+    clip_images = [embeddings[idx]['image_filepath'] for idx in I[0]]
+    transcripts = [f"({embeddings[idx]['start_time']}) {embeddings[idx]['text']}" for idx in I[0]]
+    return clip_images, transcripts
+
+
+def get_video_id_from_url(video_url):
+    """
+    Examples:
+    - http://youtu.be/SA2iWivDJiE
+    - http://www.youtube.com/watch?v=_oPAwA_Udwc&feature=feedu
+    - http://www.youtube.com/embed/SA2iWivDJiE
+    - http://www.youtube.com/v/SA2iWivDJiE?version=3&hl=en_US
+    """
+    import urllib.parse
+    url = urllib.parse.urlparse(video_url)
+    if url.hostname == 'youtu.be':
+        return url.path[1:]
+    if url.hostname in ('www.youtube.com', 'youtube.com'):
+        if url.path == '/watch':
+            p = urllib.parse.parse_qs(url.query)
+            return p['v'][0]
+        if url.path[:7] == '/embed/':
+            return url.path.split('/')[2]
+        if url.path[:3] == '/v/':
+            return url.path.split('/')[2]
+
+
+    return None
+
+
+def process(video_url, text_query):
+    tmp_dir = os.path.join(os.getcwd(), 'cache')
+    video_id = get_video_id_from_url(video_url)
+    output_dir = os.path.join(tmp_dir, video_id)
+    video_file = download_video(video_url, path=output_dir)
+    subtitles = get_transcript_vtt(video_id, path=output_dir)
+    extract_images_and_embeds(video_id=video_id,
+                              video_path=video_file,
+                              subtitles=subtitles,
+                              output=output_dir,
+                              expanded=False,
+                              batch_size=8,
+                              )
+    frame_paths, transcripts = run_query(video_id, text_query, path=output_dir)
+    return video_file, [(image, caption) for image, caption in zip(frame_paths, transcripts)]


 description = "This Space lets you run semantic search on a video."

+with gr.Blocks() as demo:
+    gr.Markdown(description)
+    with gr.Row():
+        with gr.Column():
+            video_url = gr.Text(label="Youtube url")
+            text_query = gr.Text(label="Text query")
+            btn = gr.Button("Run query")
+        video_player = gr.Video(label="Video")
+
+    with gr.Row():
+        gallery = gr.Gallery(label="Images").style(grid=6)
+
+    gr.Examples(
+        examples=[
+            ['https://www.youtube.com/watch?v=CvjoXdC-WkM', 'wedding'],
+            ['https://www.youtube.com/watch?v=fWs2dWcNGu0', 'cheesecake on floor'],
+            ['https://www.youtube.com/watch?v=rmPpNsx4yAk', 'cat woman'],
+            ['https://www.youtube.com/watch?v=KCFYf4TJdN0', 'sandwich'],
+        ],
+        inputs=[video_url, text_query],
+    )
+
+    btn.click(fn=process,
+              inputs=[video_url, text_query],
+              outputs=[video_player, gallery],
+              )
+
+demo.launch(share=True, server_port=25566)
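For reference, a minimal usage sketch (not part of the commit) showing how the new `process` entry point can be driven outside the Gradio UI, reusing one of the example pairs wired into `gr.Examples` above. It assumes the pipeline functions can be imported without launching the demo, for instance if `demo.launch(...)` were guarded by an `if __name__ == "__main__":` block:

```python
# Hypothetical sketch: call the new pipeline directly.
# Assumes app.py's functions are importable without triggering demo.launch().
from app import process

# Downloads the video and transcript, caches frame embeddings under ./cache/<video_id>,
# then runs a FAISS search for the text query over the cross-modal embeddings.
video_file, results = process(
    "https://www.youtube.com/watch?v=CvjoXdC-WkM",  # example pair from gr.Examples
    "wedding",
)
for image_path, caption in results:
    # caption is formatted as "(start_time) transcript text" by run_query
    print(image_path, caption)
```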
bridgetower_custom.py ADDED
@@ -0,0 +1,183 @@
+from collections import OrderedDict
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from torchvision import transforms
+from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+
+from transformers.modeling_outputs import SequenceClassifierOutput
+
+from transformers import BridgeTowerPreTrainedModel, BridgeTowerModel
+from transformers.models.bridgetower.modeling_bridgetower import BridgeTowerTextModel
+
+class LayerNorm(nn.LayerNorm):
+    """Subclass torch's LayerNorm to handle fp16."""
+
+    def forward(self, x: torch.Tensor):
+        orig_type = x.dtype
+        ret = super().forward(x.type(torch.float32))
+        return ret.type(orig_type)
+
+class BridgeTowerImageFeatureExtractor(nn.Module):
+    def __init__(
+        self,
+        patch_size=14,
+        width=1024,
+        resolution_after=294,
+        ckpt_path=None,
+    ):
+        super().__init__()
+
+        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
+
+        scale = width ** -0.5
+        self.class_embedding = nn.Parameter(scale * torch.randn(width))
+        self.positional_embedding = nn.Parameter(scale * torch.randn((resolution_after // patch_size) ** 2 + 1, width))
+        self.ln_pre = LayerNorm(width)
+
+        if ckpt_path is not None:
+            sd = torch.load(ckpt_path)
+            if 'state_dict' in sd:
+                sd = sd["state_dict"]
+            print(f'Loading feature extractor checkpoint from {ckpt_path}')
+            self.load_state_dict(sd)
+
+    def forward(self, x: torch.Tensor):
+        x = self.conv1(x)  # shape = [*, width, grid, grid]
+        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+        t = self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
+        x = torch.cat([t, x], dim=1)  # shape = [*, grid ** 2 + 1, width]
+        x = x + self.positional_embedding.to(x.dtype)
+        x = self.ln_pre(x)
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        return x
+
+
+class BridgeTowerITCHead(nn.Module):
+    def __init__(self, hidden_size, embed_size):
+        super().__init__()
+        self.fc = nn.Linear(hidden_size, embed_size)
+
+    def forward(self, x):
+        x = self.fc(x)
+        return x
+
+
+class _BridgeTowerTextModelWrapper(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.text_model = BridgeTowerTextModel(config)
+
+    def forward(self, **kwargs):
+        return self.text_model(**kwargs)
+
+
+class BridgeTowerTextFeatureExtractor(BridgeTowerPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.bridgetower = _BridgeTowerTextModelWrapper(config.text_config)
+        self.itc_text_head = BridgeTowerITCHead(config.hidden_size, config.contrastive_hidden_size)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[torch.LongTensor] = None,
+    ):
+
+        outputs = self.bridgetower(input_ids=input_ids, attention_mask=attention_mask)
+        final_hidden_cls = outputs.last_hidden_state[:, 0, :]
+        final_hidden_cls = F.normalize(self.itc_text_head(final_hidden_cls), dim=-1, p=2)
+
+        return final_hidden_cls
+
+
+class BridgeTowerForITC(BridgeTowerPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.bridgetower = BridgeTowerModel(config)
+
+        self.itc_text_head = BridgeTowerITCHead(config.hidden_size, config.contrastive_hidden_size)
+        self.itc_image_head = BridgeTowerITCHead(config.hidden_size, config.contrastive_hidden_size)
+        self.itc_cross_modal_head = BridgeTowerITCHead(config.hidden_size * 2, config.contrastive_hidden_size)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        pixel_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        image_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[torch.LongTensor] = None,
+    ) -> Union[SequenceClassifierOutput, Tuple[torch.FloatTensor]]:
+
+        assert output_hidden_states, 'output_hidden_states should be set to True for BridgeTowerForITC'
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.bridgetower(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            pixel_values=pixel_values,
+            pixel_mask=pixel_mask,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            image_embeds=image_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        pooler_output = outputs.pooler_output if return_dict else outputs[2]
+
+        hidden_states_txt, hidden_states_img, hidden_states_cross_modal = outputs.hidden_states
+
+        final_hidden_txt = hidden_states_txt[-1]
+        final_hidden_img = hidden_states_img[-1]
+
+        image_embeds_with_ln = self.bridgetower.vision_model.visual.forward_post(final_hidden_img)
+        image_token_type_embeddings = self.bridgetower.token_type_embeddings(
+            torch.full((1,), 1, dtype=torch.long, device=self.bridgetower.token_type_embeddings.weight.device)
+        ).expand_as(image_embeds_with_ln)
+
+        final_hidden_img = (
+            self.bridgetower.cross_modal_image_transform(image_embeds_with_ln)
+            + image_token_type_embeddings
+        )
+
+        final_hidden_txt = F.normalize(self.itc_text_head(final_hidden_txt[:, 0, :]), dim=-1, p=2)
+        final_hidden_img = F.normalize(self.itc_image_head(final_hidden_img[:, 0, :]), dim=-1, p=2)
+        final_hidden_cross = F.normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2)
+
+        logits = torch.stack([final_hidden_txt, final_hidden_img, final_hidden_cross], dim=-2)
+
+        if not return_dict:
+            return tuple(logits)
+
+        return SequenceClassifierOutput(
+            loss=None,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
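For orientation, a minimal sketch (not part of the commit) of how these custom classes are consumed by app.py above: `BridgeTowerForITC` stacks the text, image, and cross-modal ITC embeddings along `logits`, and `BridgeTowerTextFeatureExtractor` produces a matching L2-normalized text embedding for querying. The checkpoint name mirrors the one loaded in app.py; the image URL and captions are only illustrative stand-ins:

```python
# Hypothetical sketch: embed one frame/caption pair and one text query.
import torch
import requests
from PIL import Image
from transformers import BridgeTowerProcessor
from bridgetower_custom import BridgeTowerForITC, BridgeTowerTextFeatureExtractor

model_name = "BridgeTower/bridgetower-large-itm-mlm-itc"  # same checkpoint as app.py
processor = BridgeTowerProcessor.from_pretrained(model_name)
model = BridgeTowerForITC.from_pretrained(model_name)
text_model = BridgeTowerTextFeatureExtractor.from_pretrained(model_name)

# Any RGB image works; this COCO image stands in for an extracted video frame.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

encoding = processor(image, "two cats lying on a couch", return_tensors="pt")
with torch.no_grad():
    outputs = model(**encoding, output_hidden_states=True)
# logits stacks [text, image, cross-modal] embeddings; app.py stores index 2 (cross-modal).
frame_embedding = outputs.logits[0, 2, :]

with torch.no_grad():
    query_embedding = text_model(**processor.tokenizer("cats", return_tensors="pt"))[0]

# Inner-product similarity, matching the IndexFlatIP FAISS index built in app.py.
score = torch.dot(query_embedding, frame_embedding)
print(float(score))
```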
requirements.txt CHANGED
@@ -1,4 +1,8 @@
 git+https://github.com/huggingface/transformers
 torch
 requests
-Pillow
 git+https://github.com/huggingface/transformers
 torch
 requests
+Pillow
+youtube-transcript-api
+faiss-cpu
+webvtt
+pytube