import os
import cv2
import gradio as gr
import numpy as np
import json
import pickle
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import BridgeTowerProcessor
from bridgetower_custom import BridgeTowerTextFeatureExtractor, BridgeTowerForITC
import faiss
import webvtt
from pytube import YouTube
from youtube_transcript_api import YouTubeTranscriptApi
from youtube_transcript_api.formatters import WebVTTFormatter

device = 'cpu'
model_name = 'BridgeTower/bridgetower-large-itm-mlm-itc'

model = BridgeTowerForITC.from_pretrained(model_name).to(device)
text_model = BridgeTowerTextFeatureExtractor.from_pretrained(model_name).to(device)
processor = BridgeTowerProcessor.from_pretrained(model_name)
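
# BridgeTowerForITC exposes the image-text contrastive (ITC) head; its joint
# image+text embedding (logits[:, 2, :] below) is what gets indexed per frame.
# BridgeTowerTextFeatureExtractor projects a text-only query into the same
# embedding space at search time, so frames and queries are directly comparable.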

def download_video(video_url, path='/tmp/'):
    yt = YouTube(video_url)
    yt = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first()
    if not os.path.exists(path):
        os.makedirs(path)
    filepath = os.path.join(path, yt.default_filename)
    if not os.path.exists(filepath):
        print('Downloading video from YouTube...')
        yt.download(path)
    return filepath
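
# The highest-resolution progressive MP4 stream is selected, and the file is
# cached on disk, so repeated queries against the same video skip the download.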

# Get the transcript in WebVTT format
def get_transcript_vtt(video_id, path='/tmp'):
    filepath = os.path.join(path, 'test_vm.vtt')
    if os.path.exists(filepath):
        return filepath
    transcript = YouTubeTranscriptApi.get_transcript(video_id)
    formatter = WebVTTFormatter()
    webvtt_formatted = formatter.format_transcript(transcript)
    with open(filepath, 'w', encoding='utf-8') as webvtt_file:
        webvtt_file.write(webvtt_formatted)
    return filepath
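
# The formatter emits standard WebVTT cues, e.g.:
#
#   WEBVTT
#
#   00:00:00.000 --> 00:00:02.500
#   first caption line
#
# webvtt.read() below parses these cues back into (start, end, text) records.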

# https://stackoverflow.com/a/57781047
# Resizes an image while maintaining its aspect ratio
def maintain_aspect_ratio_resize(image, width=None, height=None, inter=cv2.INTER_AREA):
    # Grab the image size and initialize dimensions
    dim = None
    (h, w) = image.shape[:2]

    # Return the original image if there is no need to resize
    if width is None and height is None:
        return image

    # We are resizing height if width is None
    if width is None:
        # Calculate the ratio of the height and construct the dimensions
        r = height / float(h)
        dim = (int(w * r), height)
    # We are resizing width if height is None
    else:
        # Calculate the ratio of the width and construct the dimensions
        r = width / float(w)
        dim = (width, int(h * r))

    # Return the resized image
    return cv2.resize(image, dim, interpolation=inter)
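
# Example: maintain_aspect_ratio_resize(frame, height=350) turns a 1920x1080
# frame into a 622x350 one.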

def time_to_frame(time, fps):
    '''
    Convert a time in seconds into a (non-negative, integer) frame number.
    '''
    return max(int(time * fps) - 1, 0)

def str2time(strtime):
    strtime = strtime.strip('"')
    hrs, mins, seconds = [float(c) for c in strtime.split(':')]
    total_seconds = hrs * 60**2 + mins * 60 + seconds
    return total_seconds
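
# e.g. str2time('00:01:30.500') == 90.5 (WebVTT timestamps are HH:MM:SS.mmm)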

def collate_fn(batch_list):
    batch = {}
    batch['input_ids'] = pad_sequence([encoding['input_ids'].squeeze(0) for encoding in batch_list], batch_first=True)
    batch['attention_mask'] = pad_sequence([encoding['attention_mask'].squeeze(0) for encoding in batch_list], batch_first=True)
    batch['pixel_values'] = torch.cat([encoding['pixel_values'] for encoding in batch_list], dim=0)
    batch['pixel_mask'] = torch.cat([encoding['pixel_mask'] for encoding in batch_list], dim=0)
    return batch
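
# pad_sequence right-pads shorter sequences with 0 by default; since the same
# zeros land in attention_mask (where 0 means "ignore"), the padded token
# positions are never attended to.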

def extract_images_and_embeds(video_id, video_path, subtitles, output, expanded=False, batch_size=2):
    if os.path.exists(os.path.join(output, 'embeddings.pkl')):
        return

    os.makedirs(output, exist_ok=True)
    os.makedirs(os.path.join(output, 'frames'), exist_ok=True)
    os.makedirs(os.path.join(output, 'frames_thumb'), exist_ok=True)

    vidcap = cv2.VideoCapture(video_path)

    # Get the frames per second
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    # Get the total number of frames in the video
    frame_count = vidcap.get(cv2.CAP_PROP_FRAME_COUNT)
    print(fps, frame_count)

    count = 0
    anno = []
    embeddings = []
    batch_list = []
    for idx, caption in enumerate(webvtt.read(subtitles)):
        st_time = str2time(caption.start)
        ed_time = str2time(caption.end)
        mid_time = (ed_time + st_time) / 2
        text = caption.text.replace('\n', ' ')

        if expanded:
            raise NotImplementedError

        frame_no = time_to_frame(mid_time, fps)
        print('Read a new frame: ', idx, mid_time, frame_no, text)

        # Seek to the frame at the midpoint of the caption
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_no)
        success, image = vidcap.read()
        if success:
            img_fname = f'{video_id}_{idx:06d}'
            img_fpath = os.path.join(output, 'frames', img_fname + '.jpg')
            # image = maintain_aspect_ratio_resize(image, height=350) # save frame as JPEG file
            # cv2.imwrite(img_fpath, image) # save frame as JPEG file
            count += 1
            anno.append({
                'image_id': idx,
                'img_fname': img_fname,
                'caption': text,
                'time': mid_time,
                'frame_no': frame_no
            })
        else:
            break

        # OpenCV decodes frames as BGR; convert to RGB before handing the frame
        # to the BridgeTower processor
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        encoding = processor(image, text, return_tensors="pt").to(device)
        # Stash per-frame metadata on the encoding so it travels with the batch
        encoding['text'] = text
        encoding['image_filepath'] = img_fpath
        encoding['start_time'] = caption.start
        encoding['frame_no'] = frame_no
        batch_list.append(encoding)

        if len(batch_list) == batch_size:
            batch = collate_fn(batch_list)
            with torch.no_grad():
                outputs = model(**batch, output_hidden_states=True)
            for i in range(batch_size):
                embeddings.append({
                    'embeddings': outputs.logits[i, 2, :].detach().cpu().numpy(),
                    'text': batch_list[i]['text'],
                    'image_filepath': batch_list[i]['image_filepath'],
                    'start_time': batch_list[i]['start_time'],
                    'frame_no': batch_list[i]['frame_no'],
                })
            batch_list = []

    # Flush the final partial batch
    if batch_list:
        batch = collate_fn(batch_list)
        with torch.no_grad():
            outputs = model(**batch, output_hidden_states=True)
        for i in range(len(batch_list)):
            embeddings.append({
                'embeddings': outputs.logits[i, 2, :].detach().cpu().numpy(),
                'text': batch_list[i]['text'],
                'image_filepath': batch_list[i]['image_filepath'],
                'start_time': batch_list[i]['start_time'],
                'frame_no': batch_list[i]['frame_no'],
            })

    with open(os.path.join(output, 'annotations.json'), 'w') as fh:
        json.dump(anno, fh)
    with open(os.path.join(output, 'embeddings.pkl'), 'wb') as fh:
        pickle.dump(embeddings, fh)
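
# embeddings.pkl holds a list of dicts, one per caption, with keys 'embeddings'
# (a 1D numpy vector from the ITC head), 'text', 'image_filepath', 'start_time',
# and 'frame_no'; run_query below builds a FAISS index over those vectors.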

def run_query(video_path, text_query, path='/tmp'):
    vidcap = cv2.VideoCapture(video_path)
    embeddings_filepath = os.path.join(path, 'embeddings.pkl')
    faiss_filepath = os.path.join(path, 'faiss_index.pkl')

    with open(embeddings_filepath, 'rb') as fh:
        embeddings = pickle.load(fh)

    if os.path.exists(faiss_filepath):
        with open(faiss_filepath, 'rb') as fh:
            faiss_index = pickle.load(fh)
    else:
        embs = [emb['embeddings'] for emb in embeddings]
        # FAISS expects float32 vectors; IndexFlatIP does exact inner-product search
        vectors = np.stack(embs, axis=0).astype(np.float32)
        num_vectors, vector_dim = vectors.shape
        faiss_index = faiss.IndexFlatIP(vector_dim)
        faiss_index.add(vectors)
        with open(faiss_filepath, 'wb') as fh:
            pickle.dump(faiss_index, fh)

    print('Processing query')
    encoding = processor.tokenizer(text_query, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = text_model(**encoding)
    emb_query = outputs.cpu().numpy()

    print('Running FAISS search')
    _, I = faiss_index.search(emb_query, 6)

    clip_images = []
    for idx in I[0]:
        frame_no = embeddings[idx]['frame_no']
        # Seek to the stored frame and re-read it for display
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_no)
        success, image = vidcap.read()
        if success:
            # Convert BGR (OpenCV) to RGB so the gallery renders correct colors
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        clip_images.append(image)
    # clip_images = [embeddings[idx]['image_filepath'] for idx in I[0]]
    transcripts = [f"({embeddings[idx]['start_time']}) {embeddings[idx]['text']}" for idx in I[0]]
    return clip_images, transcripts
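
# Example (paths follow the layout process() sets up below): once
# extract_images_and_embeds has populated /tmp/<video_id>,
# run_query('/tmp/<video_id>/<video>.mp4', 'a wedding cake', path='/tmp/<video_id>')
# returns the six best-matching frames plus their timestamped transcript lines,
# ranked by inner-product similarity.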

def get_video_id_from_url(video_url):
    """
    Examples:
    - http://youtu.be/SA2iWivDJiE
    - http://www.youtube.com/watch?v=_oPAwA_Udwc&feature=feedu
    - http://www.youtube.com/embed/SA2iWivDJiE
    - http://www.youtube.com/v/SA2iWivDJiE?version=3&hl=en_US
    """
    import urllib.parse
    url = urllib.parse.urlparse(video_url)
    if url.hostname == 'youtu.be':
        return url.path[1:]
    if url.hostname in ('www.youtube.com', 'youtube.com'):
        if url.path == '/watch':
            p = urllib.parse.parse_qs(url.query)
            return p['v'][0]
        if url.path[:7] == '/embed/':
            return url.path.split('/')[2]
        if url.path[:3] == '/v/':
            return url.path.split('/')[2]
    return None
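
# e.g. get_video_id_from_url('http://youtu.be/SA2iWivDJiE') -> 'SA2iWivDJiE'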

def process(video_url, text_query):
    tmp_dir = os.environ.get('TMPDIR', '/tmp')
    video_id = get_video_id_from_url(video_url)
    output_dir = os.path.join(tmp_dir, video_id)
    video_file = download_video(video_url, path=output_dir)
    subtitles = get_transcript_vtt(video_id, path=output_dir)
    extract_images_and_embeds(video_id=video_id,
                              video_path=video_file,
                              subtitles=subtitles,
                              output=output_dir,
                              expanded=False,
                              batch_size=8,
                              )
    # run_query needs the video file itself, not the bare id: it re-opens the
    # file with cv2.VideoCapture to extract the matching frames
    clip_images, transcripts = run_query(video_file, text_query, path=output_dir)
    return video_file, [(image, caption) for image, caption in zip(clip_images, transcripts)]
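
# process() is the end-to-end pipeline: download the video, fetch the transcript,
# extract one frame + ITC embedding per caption (cached per video id under
# TMPDIR), then run the FAISS search for the query.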

description = "This Space lets you run semantic search on a video."

with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            video_url = gr.Text(label="Youtube url")
            text_query = gr.Text(label="Text query")
            btn = gr.Button("Run query")
        video_player = gr.Video(label="Video")
    with gr.Row():
        gallery = gr.Gallery(label="Images").style(grid=6)
    gr.Examples(
        examples=[
            ['https://www.youtube.com/watch?v=CvjoXdC-WkM', 'wedding'],
            ['https://www.youtube.com/watch?v=fWs2dWcNGu0', 'cheesecake on floor'],
            ['https://www.youtube.com/watch?v=rmPpNsx4yAk', 'cat woman'],
            ['https://www.youtube.com/watch?v=KCFYf4TJdN0', 'sandwich'],
        ],
        inputs=[video_url, text_query],
    )
    btn.click(fn=process,
              inputs=[video_url, text_query],
              outputs=[video_player, gallery],
              )

demo.launch(share=True)