Thao Pham committed
Commit · d50ce1c
1 Parent(s): 1d5379f
First commit
Browse files
- .gitignore +5 -0
- app.py +213 -0
- embed.py +62 -0
- rag.py +270 -0
- requirements.txt +17 -0
- utils.py +91 -0
- video_utils.py +156 -0
.gitignore
ADDED
@@ -0,0 +1,5 @@
.vscode
.env
__pycache__
tmp
uploads
app.py
ADDED
@@ -0,0 +1,213 @@
import gradio as gr
import time
import re
import video_utils
import utils
import embed
import rag
import os
import uuid
import numpy as np
from pinecone import Pinecone, ServerlessSpec
from sentence_transformers import SentenceTransformer
from transformers import AutoImageProcessor, AutoModel
from transformers import BlipProcessor, BlipForConditionalGeneration
from dotenv import load_dotenv

load_dotenv()  # Load API keys from .env

UPLOAD_FOLDER = 'uploads'
video_name = None

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")

# init models
TEXT_MODEL = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
VISION_MODEL_PROCESSOR = AutoImageProcessor.from_pretrained('facebook/dinov2-small')
VISION_MODEL = AutoModel.from_pretrained('facebook/dinov2-small')

VLM_PROCESSOR = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
VLM = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# init index
pc = Pinecone(api_key=PINECONE_API_KEY)
# Connect to an existing index
index_name = "mutlimodal-minilm"
INDEX = pc.Index(index_name)
MODEL_STACK = [TEXT_MODEL, VISION_MODEL, VISION_MODEL_PROCESSOR, VLM, VLM_PROCESSOR]


def is_valid_youtube_url(url):
    """
    Checks if the given URL is a valid YouTube video URL.

    Returns True if valid, False otherwise.
    """
    youtube_regex = re.compile(
        r"^(https?://)?(www\.)?(youtube\.com|youtu\.be)/"
        r"(watch\?v=|embed/|v/|shorts/)?([a-zA-Z0-9_-]{11})"
    )

    match = youtube_regex.match(url)
    return bool(match)


def check_exist_before_upsert(index, video_path):
    # one vector each for the image, caption and transcript embeddings of every extracted frame
    # (frames are saved next to the video, so list the video's folder)
    threshold = [elem for elem in os.listdir(os.path.dirname(video_path)) if elem.endswith('.jpg')]
    threshold = len(threshold) * 3  # image embeds, caption embeds, transcript embeds

    dimension = 384
    res = index.query(
        vector=[0] * dimension,  # dummy vector (not used for filtering)
        top_k=10000,             # set a high value to retrieve as many matches as possible
        filter={"video_path": video_path}  # filter by video_path
    )

    # Count the number of matching vectors
    num_existing_vectors = len(res["matches"])

    if num_existing_vectors >= threshold:
        return True
    return False


def chat(message, history):
    global video_name

    image_input_path = None
    if len(message['files']) > 0:
        assert len(message['files']) == 1
        image_input_path = message['files'][0]

    message = message['text']

    if history is None:
        history = []

    if message.startswith("https://"):
        # Check that the URL is valid
        history.append((message, f"Checking if your provided URL at {message} is valid..."))
        yield history

        valid = is_valid_youtube_url(message)
        if not valid:
            history.append((None, "❌ Invalid YouTube URL. Please try again."))
            yield history
            return

        # Fetch metadata
        history.append((None, "✅ URL is valid! Fetching video metadata..."))
        yield history

        video_metadata = video_utils.get_video_metdata(message)
        history.append((None, f"The video you want to process is: \nTitle: {video_metadata['title']} published by {video_metadata['author']} on {video_metadata['publish_date']}."))
        yield history

        history.append((None, "⏳ Downloading video..."))
        yield history

        output_folder_path = os.path.join(UPLOAD_FOLDER, video_metadata['title'])
        path_to_video = os.path.join(output_folder_path, "video.mp4")
        if not os.path.exists(path_to_video):
            path_to_video = utils.download_video(message, path=output_folder_path)

        history.append((None, "⏳ Transcribing video..."))
        yield history
        path_to_audio_file = os.path.join(output_folder_path, "audio.mp3")
        if not os.path.exists(path_to_audio_file):
            path_to_audio_file = video_utils.extract_audio(path_to_video, output_folder_path)

        path_to_generated_transcript = os.path.join(output_folder_path, "transcript.vtt")
        if not os.path.exists(path_to_generated_transcript):
            path_to_generated_transcript = video_utils.transcribe_video(path_to_audio_file, output_folder_path)

        # extract frames and metadata
        metadatas_path = os.path.join(output_folder_path, 'metadatas.json')
        if not os.path.exists(metadatas_path):
            metadatas = video_utils.extract_and_save_frames_and_metadata(
                path_to_video=path_to_video,
                path_to_transcript=path_to_generated_transcript,
                path_to_save_extracted_frames=output_folder_path,
                path_to_save_metadatas=output_folder_path)
        else:
            metadatas = utils.load_json_file(metadatas_path)  # reuse previously extracted metadata

        history.append((None, "⏳ Captioning video..."))
        yield history

        caption_path = os.path.join(output_folder_path, 'captions.json')
        if not os.path.exists(caption_path):
            video_frames = [os.path.join(output_folder_path, elem) for elem in os.listdir(output_folder_path) if elem.endswith('.jpg')]
            metadata_path = video_utils.get_video_caption(video_frames, metadatas, output_folder_path, vlm=VLM, vlm_processor=VLM_PROCESSOR)

        history.append((None, "⏳ Indexing..."))
        yield history
        index_exist = check_exist_before_upsert(INDEX, path_to_video)
        print(index_exist)
        if not index_exist:
            embed.indexing(INDEX, MODEL_STACK, metadatas_path)

        # summarize the video
        video_summary = rag.summarize_video(metadatas_path)
        with open(os.path.join(output_folder_path, "summary.txt"), "w") as f:
            f.write(video_summary)

        history.append((None, f"Video processing complete! You can now ask me questions about the video {video_metadata['title']}!"))
        yield history

        video_name = video_metadata['title']
    else:
        history.append((message, None))
        yield history

        if video_name is None:
            history.append((None, "You need to insert a video URL before asking questions."))
            yield history
            return

        output_folder_path = f"{UPLOAD_FOLDER}/{video_name}"
        metadatas_path = os.path.join(output_folder_path, 'metadatas.json')

        video_summary = ''
        with open(f'./{output_folder_path}/summary.txt') as f:
            while True:
                ln = f.readline()
                if ln == '':
                    break
                video_summary += ln.strip()
        video_path = os.path.join(output_folder_path, 'video.mp4')
        answer = rag.answer_question(INDEX, MODEL_STACK, metadatas_path, video_summary, video_path, message, image_input_path)

        history.append((None, answer))
        yield history


def clear_chat(history=None):
    history = []
    history.append((None, "Please input a Youtube URL to get started!"))
    return history


def main():
    initial_messages = [(None, "Please input a Youtube URL to get started!")]

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot(value=initial_messages)
        msg = gr.MultimodalTextbox(file_types=['image'], sources=['upload'])

        with gr.Row():
            with gr.Column():
                submit = gr.Button("Send")
                submit.click(chat, [msg, chatbot], chatbot)

            with gr.Column():
                clear = gr.Button("Clear")  # Clear button
                # Clear chat history when the clear button is clicked
                clear.click(clear_chat, [], chatbot)
        global video_name
        video_name = None

    demo.launch()


if __name__ == "__main__":
    main()
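Note (not part of this commit): app.py connects to an existing Pinecone index named "mutlimodal-minilm" but never creates it. A minimal one-time setup sketch, assuming a 384-dimensional serverless index; the metric, cloud, and region values below are placeholders to adjust for your account:

import os
from dotenv import load_dotenv
from pinecone import Pinecone, ServerlessSpec

load_dotenv()
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))

index_name = "mutlimodal-minilm"  # must match the index name used in app.py
if index_name not in pc.list_indexes().names():
    pc.create_index(
        name=index_name,
        dimension=384,    # all-MiniLM-L6-v2 and dinov2-small both output 384-d vectors
        metric="cosine",  # assumption: the commit does not specify a metric
        spec=ServerlessSpec(cloud="aws", region="us-east-1"),  # placeholder cloud/region
    )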
embed.py
ADDED
@@ -0,0 +1,62 @@
from sentence_transformers import SentenceTransformer
from transformers import AutoImageProcessor, AutoModel
from tqdm import tqdm
from PIL import Image
from typing import Iterator, TextIO, List, Dict, Any, Optional, Sequence, Union
import numpy as np
import uuid
from utils import load_json_file

def embed_texts(text_ls: List[str], text_model=None):
    if text_model is None:
        text_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

    text_embeddings = []
    for i, text in enumerate(tqdm(text_ls, desc="Embedding text")):
        embeds = text_model.encode(text)
        text_embeddings.append(embeds)
    return np.array(text_embeddings)


def embed_images(image_path_ls: List[str], vision_model=None, vision_model_processor=None):
    if vision_model is None or vision_model_processor is None:
        vision_model_processor = AutoImageProcessor.from_pretrained('facebook/dinov2-small')
        vision_model = AutoModel.from_pretrained('facebook/dinov2-small')

    image_embeds_ls = []
    for i, frame in enumerate(tqdm(image_path_ls, desc="Embedding image")):
        frame = Image.open(frame)
        # TODO: add device here
        inputs = vision_model_processor(images=frame, return_tensors="pt")
        outputs = vision_model(**inputs)
        image_embeds_ls.append(outputs.pooler_output)
    return np.array([elem.squeeze().detach().numpy() for elem in image_embeds_ls])


def indexing(index, model_stack, vid_metadata_path):
    text_model, vision_model, vision_model_processor, _, _ = model_stack

    # read metadata file
    vid_metadata = load_json_file(vid_metadata_path)

    # embed transcripts
    vid_trans = [frame['transcript'] for frame in vid_metadata]
    transcript_embeddings = embed_texts(text_ls=vid_trans, text_model=text_model)

    # embed captions
    vid_captions = [frame['caption'] for frame in vid_metadata]
    caption_embeddings = embed_texts(text_ls=vid_captions, text_model=text_model)

    # embed frames
    vid_img_paths = [vid['extracted_frame_path'] for vid in vid_metadata]
    frame_embeddings = embed_images(vid_img_paths, vision_model, vision_model_processor)

    for ls in [transcript_embeddings, caption_embeddings, frame_embeddings]:
        # Prepare metadata
        vectors = [
            (str(uuid.uuid4()), emb.tolist(), meta)  # Generate unique IDs
            for emb, meta in zip(ls, vid_metadata)
        ]
        # Upsert vectors into Pinecone
        index.upsert(vectors)
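A usage sketch (not in the commit; the metadata path is hypothetical) showing how indexing() can be driven outside the Gradio app. Only the first three entries of the model stack are unpacked by indexing(), so the VLM slots can be left as None here:

import os
from pinecone import Pinecone
from sentence_transformers import SentenceTransformer
from transformers import AutoImageProcessor, AutoModel
import embed

pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
index = pc.Index("mutlimodal-minilm")

model_stack = [
    SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2'),  # text model
    AutoModel.from_pretrained('facebook/dinov2-small'),             # vision model
    AutoImageProcessor.from_pretrained('facebook/dinov2-small'),    # vision processor
    None,  # VLM slot, unused by indexing()
    None,  # VLM processor slot, unused by indexing()
]
embed.indexing(index, model_stack, "uploads/SomeVideo/metadatas.json")  # hypothetical path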
rag.py
ADDED
@@ -0,0 +1,270 @@
from langchain.prompts import PromptTemplate
from langchain_community.chat_models import ChatOpenAI
from langchain.chains import LLMChain
from PIL import Image
import os
from utils import load_json_file, str2time
from openai import OpenAI
import base64

def get_smallest_timestamp(timestamps):
    assert len(timestamps) > 0

    timestamps_in_ms = [str2time(elem) for elem in timestamps]

    smallest_timestamp_in_ms = timestamps_in_ms[0]
    smallest_timestamp = timestamps[0]
    for i, elem in enumerate(timestamps_in_ms):
        if elem < smallest_timestamp_in_ms:
            smallest_timestamp_in_ms = elem
            smallest_timestamp = timestamps[i]
    return smallest_timestamp

def generate(query, context, relevant_timestamps=None):
    prompt = PromptTemplate(input_variables=["question", "context"], template="You're a helpful LLM assistant in answering questions regarding a video. The given contexts are segments relevant to the question, please answer the question. Do not refer to segments. Context: {context}, question: {question} \nA:")

    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    chain = LLMChain(llm=llm, prompt=prompt)
    response = chain.run(question=query, context=context)

    if relevant_timestamps is not None and len(relevant_timestamps) > 0:
        # append the smallest timestamp = earliest mention
        smallest_timestamp = get_smallest_timestamp(relevant_timestamps)
        response += f' {smallest_timestamp}'
    return response


def check_relevance(query, relevant_metadatas):
    transcripts = [frame['transcript'] for frame in relevant_metadatas]
    captions = [frame['caption'] for frame in relevant_metadatas]
    timestamps = [frame['start_time'] for frame in relevant_metadatas]

    context = ""
    for i in range(len(transcripts)):
        context += f"Segment {i}: transcript={transcripts[i]} caption={captions[i]}\n"

    prompt = PromptTemplate(input_variables=["question", "context"], template="""
    You are a grader assessing relevance of a retrieved video segment to a user question. \n
    If the video segment contains keyword(s) or semantic meaning related to the question, grade it as relevant. \n
    Give a binary score 'yes' or 'no' to indicate whether the video segment is relevant to the question. \n
    Answer in a string, separated by commas. For example: if there are segments provided, answer: yes,no,no,yes. \n
    Question: {question} Context: {context}\n A:""")

    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    chain = LLMChain(llm=llm, prompt=prompt)
    response = chain.run(question=query, context=context)

    relevance_response = response.split(',')

    actual_relevant_context = ""
    relevant_timestamps = []
    for i, relevance_check in enumerate(relevance_response):
        if relevance_check.strip() == 'yes':
            actual_relevant_context += f"Segment {i}: transcript={transcripts[i]} caption={captions[i]}\n"
            relevant_timestamps.append(timestamps[i])
    return actual_relevant_context, relevant_timestamps


def retrieve_segments_from_timestamp(metadatas, timestamps):
    relevant_segments = []

    for timestamp in timestamps:
        time_to_find_ms = str2time(timestamp)
        buffer = 5000  # 5 seconds before and after

        for segment in metadatas:
            start = str2time(segment['start_time'])
            end = str2time(segment['end_time'])
            if start <= time_to_find_ms + buffer and end >= time_to_find_ms - buffer:
                relevant_segments.append(segment)

    return relevant_segments


def check_timestamps(query):
    prompt = PromptTemplate(input_variables=["question"], template="You're a helpful LLM assistant. You're good at detecting any timestamps provided in a query. Please detect the question and timestamp in the following question and separate them by commas, such as question,timestamp1,timestamp2 if timestamps are provided, else just question. Question: {question} \nA:")

    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    chain = LLMChain(llm=llm, prompt=prompt)
    response = chain.run(question=query)

    timestamps = []
    if len(response.split(',')) > 1:
        query = response.split(',')[0].strip()
        timestamps = [f"00:{elem.strip()}.00" for elem in response.split(',')[1:]]

    return query, timestamps

def retrieve_by_embedding(index, video_path, query, text_model):
    print(query)
    query_embedding = text_model.encode(query)

    res = index.query(vector=query_embedding.tolist(), top_k=5, filter={"video_path": {"$eq": video_path}})

    metadatas = []
    for id_, match_ in enumerate(res['matches']):
        result = index.fetch(ids=[match_['id']])

        # Extract the vector data
        vector_data = result.vectors.get(match_['id'], {})

        # Extract its metadata
        metadata = vector_data.metadata
        metadatas.append(metadata)

    return metadatas

def self_reflection(query, answer, summary):
    prompt = PromptTemplate(input_variables=["summary", "question", "answer"], template="You're a helpful LLM assistant. You're good at determining if the provided answer is satisfactory to a question relating to a video. You have access to the video summary as follows: {summary}. Given a pair of question and answer, give the answer's satisfactory score in either yes or no. Question: {question}, Answer: {answer} \nA:")

    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    chain = LLMChain(llm=llm, prompt=prompt)
    response = chain.run(summary=summary, question=query, answer=answer)
    return response


def get_full_transcript(metadatas):
    transcripts = [frame['transcript'] for frame in metadatas]

    full_text = ''
    for idx, transcript in enumerate(transcripts):
        text = transcript.strip().replace("  ", " ")  # collapse double spaces
        full_text += f"{text} "

    full_text = full_text.strip()
    return full_text

def summarize_video(metadatas_path: str):
    metadatas = load_json_file(metadatas_path)

    # get the full transcript
    transcript = get_full_transcript(metadatas)
    prompt = PromptTemplate(input_variables=["transcript"], template="You're a helpful LLM assistant. Please provide a summary for the video given its full transcript: {transcript} \nA:")

    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    chain = LLMChain(llm=llm, prompt=prompt)
    response = chain.run(transcript=transcript)
    return response

def answer_wrt_timestamp(query, context):
    prompt = PromptTemplate(input_variables=["question", "context"], template="""
    You're a helpful LLM assistant. Given a question and a timestamp, I have retrieved the relevant context as follows. Please answer the question using the information provided in the context. Question: {question}, context: {context} \n
    For example: Question="What happens at 4:20?" Caption="a person is standing up" Transcript="I have to go" Appropriate Answer="At 4:20, a person is standing up and saying he has to go."
    A:""")
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    chain = LLMChain(llm=llm, prompt=prompt)
    response = chain.run(question=query, context=context)
    return response


def answer_question(index, model_stack, metadatas_path, video_summary: str, video_path: str, query: str, image_input_path: str = None):
    metadatas = load_json_file(metadatas_path)
    if image_input_path is not None:
        return answer_image_question(index, model_stack, metadatas, video_summary, video_path, query, image_input_path)

    # check if a timestamp is provided
    query, timestamps = check_timestamps(query)

    if len(timestamps) > 0:
        # retrieve by timestamps
        relevant_segments_metadatas = retrieve_segments_from_timestamp(metadatas, timestamps)
        transcripts = [frame['transcript'] for frame in relevant_segments_metadatas]
        captions = [frame['caption'] for frame in relevant_segments_metadatas]
        context = ""
        for i in range(len(transcripts)):
            context += f"Segment {i}: transcript={transcripts[i]} caption={captions[i]}\n"
        return answer_wrt_timestamp(query, context)
    else:
        # retrieve by embedding
        relevant_segments_metadatas = retrieve_by_embedding(index, video_path, query, model_stack[0])

        # check relevance
        actual_relevant_context, relevant_timestamps = check_relevance(query, relevant_segments_metadatas)

        # generate
        answer = generate(query, actual_relevant_context, relevant_timestamps)

        # self-reflection: fall back to the video summary if the answer is judged unsatisfactory
        reflect = self_reflection(query, answer, video_summary)

        if reflect.lower() == 'no':
            answer = generate(query, f"{actual_relevant_context}\nSummary={video_summary}")

        return answer

def retrieve_segments_by_image_embedding(index, video_path, model_stack, image_query_path):
    image_query = Image.open(image_query_path)
    _, vision_model, vision_model_processor, _, _ = model_stack
    inputs = vision_model_processor(images=image_query, return_tensors="pt")
    outputs = vision_model(**inputs)
    image_query_embeds = outputs.pooler_output

    # flatten the (1, 384) pooled output into a plain list before querying Pinecone
    res = index.query(vector=image_query_embeds.squeeze(0).tolist(), top_k=5, filter={"video_path": {"$eq": video_path}})

    metadatas = []
    for id_, match_ in enumerate(res['matches']):
        result = index.fetch(ids=[match_['id']])

        # Extract the vector data
        vector_data = result.vectors.get(match_['id'], {})

        # Extract its metadata
        metadata = vector_data.metadata
        metadatas.append(metadata)

    return metadatas


def answer_image_question(index, model_stack, metadatas, video_summary: str, video_path: str, query: str, image_query_path: str = None):
    # search for segments matching the image
    relevant_segments = retrieve_segments_by_image_embedding(index, video_path, model_stack, image_query_path)

    # generate an answer using those segments
    return generate_w_image(query, image_query_path, relevant_segments)


def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


def generate_w_image(query: str, image_query_path: str, relevant_metadatas):
    base64_image = encode_image(image_query_path)
    transcripts = [frame['transcript'] for frame in relevant_metadatas]
    captions = [frame['caption'] for frame in relevant_metadatas]

    context = ""
    for i in range(len(transcripts)):
        context += f"Segment {i}: transcript={transcripts[i]} caption={captions[i]}\n"

    client = OpenAI()
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "user", "content": [
                {"type": "text", "text": f"Here is some context about the image: {context}"},
                {"type": "text", "text": f"You are a helpful LLM assistant. You are good at answering questions about a video given an image. Given the image and the context surrounding the frames most correlated with it, please answer the question. Question: {query}"},
                {"type": "image_url", "image_url": {
                    "url": f"data:image/png;base64,{base64_image}"
                }}
            ]}
        ],
        temperature=0.0,
        max_tokens=100,
    )

    response = response.choices[0].message.content
    return response
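A quick illustrative check (not in the commit) of the timestamp helper: get_smallest_timestamp() returns the earliest WebVTT-style timestamp by comparing their millisecond values.

from rag import get_smallest_timestamp

# "00:00:45.000" is 45,000 ms and "00:02:10.000" is 130,000 ms, so the former is returned
assert get_smallest_timestamp(["00:02:10.000", "00:00:45.000"]) == "00:00:45.000"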
requirements.txt
ADDED
@@ -0,0 +1,17 @@
openai-whisper
webvtt-py
pytubefix
sentence-transformers
pinecone
gradio
moviepy
youtube-transcript-api
pytube
ffmpeg-python
ffmpeg
opencv-python
langchain_yt_dlp
langchain_community
transformers==4.49.0
numpy==1.26.4
openai==1.68.2
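After pip install -r requirements.txt, app.py reads two keys from a local .env file via load_dotenv(). Note that app.py also imports python-dotenv and langchain, which may need to be installed separately if they are not pulled in transitively. A small sanity-check sketch (not part of the commit):

import os
from dotenv import load_dotenv

load_dotenv()
for key in ("OPENAI_API_KEY", "PINECONE_API_KEY"):
    # Both keys are read in app.py; fail early if either is missing from .env
    assert os.getenv(key), f"{key} is missing from your environment/.env file"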
utils.py
ADDED
@@ -0,0 +1,91 @@
import os
from io import StringIO, BytesIO
from typing import Iterator, TextIO, List, Dict, Any, Optional, Sequence, Union
import base64
import glob
from tqdm import tqdm
from pytubefix import YouTube, Stream
import cv2
import json

# Taken from the course: https://www.deeplearning.ai/short-courses/multimodal-rag-chat-with-videos/
# NOTE: write_vtt and write_srt are helper writers from the course utilities; they are not defined in this commit.
def getSubs(segments: Iterator[dict], format: str, maxLineWidth: int = -1) -> str:
    segmentStream = StringIO()

    if format == 'vtt':
        write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
    elif format == 'srt':
        write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
    else:
        raise Exception("Unknown format " + format)

    segmentStream.seek(0)
    return segmentStream.read()

def download_video(video_url, path='/tmp/'):
    print(f'Getting video information for {video_url}')
    if not video_url.startswith('http'):
        return os.path.join(path, video_url)

    filepath = glob.glob(os.path.join(path, '*.mp4'))
    if len(filepath) > 0:
        return filepath[0]

    def progress_callback(stream: Stream, data_chunk: bytes, bytes_remaining: int) -> None:
        pbar.update(len(data_chunk))

    yt = YouTube(video_url, on_progress_callback=progress_callback)
    stream = yt.streams.filter(progressive=True, file_extension='mp4', res='720p').desc().first()
    if stream is None:
        stream = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first()
    if not os.path.exists(path):
        os.makedirs(path)

    filepath = os.path.join(path, stream.default_filename)
    if not os.path.exists(filepath):
        print('Downloading video from YouTube...')
        pbar = tqdm(desc='Downloading video from YouTube', total=stream.filesize, unit="bytes")
        stream.download(path)
        pbar.close()
    return filepath

# a helper function that converts a time written as a `webvtt`-style string into a time in milliseconds
def str2time(strtime):
    # strip the character " if it exists
    strtime = strtime.strip('"')
    # get hour, minute, second from the time string
    hrs, mins, seconds = [float(c) for c in strtime.split(':')]
    # get the corresponding time as total seconds
    total_seconds = hrs * 60**2 + mins * 60 + seconds
    total_milliseconds = total_seconds * 1000
    return total_milliseconds

# Resizes an image and maintains the aspect ratio
def maintain_aspect_ratio_resize(image, width=None, height=None, inter=cv2.INTER_AREA):
    # Grab the image size and initialize dimensions
    dim = None
    (h, w) = image.shape[:2]

    # Return the original image if there is no need to resize
    if width is None and height is None:
        return image

    # We are resizing the height if width is None
    if width is None:
        # Calculate the ratio of the height and construct the dimensions
        r = height / float(h)
        dim = (int(w * r), height)
    # We are resizing the width if height is None
    else:
        # Calculate the ratio of the width and construct the dimensions
        r = width / float(w)
        dim = (width, int(h * r))

    # Return the resized image
    return cv2.resize(image, dim, interpolation=inter)

def load_json_file(file_path):
    # Open the JSON file in read mode
    with open(file_path, 'r') as file:
        data = json.load(file)
    return data
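For reference (not in the commit), str2time converts a WebVTT "HH:MM:SS.mmm" string into milliseconds:

from utils import str2time

# 1 minute 30.5 seconds -> 90,500 ms
assert str2time("00:01:30.500") == 90500.0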
video_utils.py
ADDED
@@ -0,0 +1,156 @@
from typing import Iterator, TextIO, List, Dict, Any, Optional, Sequence, Union
from utils import getSubs, str2time, maintain_aspect_ratio_resize
from moviepy import VideoFileClip
import whisper
import os
import cv2
import webvtt
from PIL import Image
from tqdm import tqdm
import json
from langchain_yt_dlp.youtube_loader import YoutubeLoaderDL
from transformers import BlipProcessor, BlipForConditionalGeneration


# get video metadata
def get_video_metdata(video_url: str):
    docs = YoutubeLoaderDL.from_youtube_url(video_url, add_video_info=True).load()
    return docs[0].metadata

# extract audio
def extract_audio(path_to_video: str, output_folder: str):
    video_name = os.path.basename(path_to_video).replace('.mp4', '')

    # declare where to save the .mp3 audio
    path_to_extracted_audio_file = os.path.join(output_folder, f'{video_name}.mp3')

    # extract an mp3 audio file from the mp4 video file
    clip = VideoFileClip(path_to_video)
    clip.audio.write_audiofile(path_to_extracted_audio_file)
    return path_to_extracted_audio_file


# Get video transcript
def transcribe_video(path_to_extracted_audio_file, output_folder, whisper_model=None):
    # load model
    if whisper_model is None:
        whisper_model = whisper.load_model("small")
    options = dict(task="translate", best_of=1, language='en')
    results = whisper_model.transcribe(path_to_extracted_audio_file, **options)

    vtt = getSubs(results["segments"], "vtt")
    # path to save the generated transcript
    # derive the output name from the audio file (the video path is not available here)
    video_name = os.path.basename(path_to_extracted_audio_file).replace('.mp3', '')
    path_to_generated_transcript = os.path.join(output_folder, f'{video_name}.vtt')

    # write transcription to file
    with open(path_to_generated_transcript, 'w') as f:
        f.write(vtt)
    return path_to_generated_transcript


# get video frames & metadata
def extract_and_save_frames_and_metadata(
        path_to_video,
        path_to_transcript,
        path_to_save_extracted_frames,
        path_to_save_metadatas):

    # metadatas will store the metadata of all extracted frames
    metadatas = []

    # load video using cv2
    video = cv2.VideoCapture(path_to_video)
    # load transcript using webvtt
    trans = webvtt.read(path_to_transcript)

    # iterate the transcript file
    # for each video segment specified in the transcript file
    for idx, transcript in enumerate(trans):
        # get the start time and end time in milliseconds
        start_time_ms = str2time(transcript.start)
        end_time_ms = str2time(transcript.end)
        # get the time in ms exactly in the middle of start time and end time
        mid_time_ms = (end_time_ms + start_time_ms) / 2
        # get the transcript, remove the next-line symbol
        text = transcript.text.replace("\n", ' ')
        # get the frame at the middle time
        video.set(cv2.CAP_PROP_POS_MSEC, mid_time_ms)
        success, frame = video.read()
        if success:
            # if the frame is extracted successfully, resize it
            image = maintain_aspect_ratio_resize(frame, height=350)
            # save the frame as a JPEG file
            img_fname = f'frame_{idx}.jpg'
            img_fpath = os.path.join(
                path_to_save_extracted_frames, img_fname
            )
            cv2.imwrite(img_fpath, image)

            # prepare the metadata
            metadata = {
                'extracted_frame_path': img_fpath,
                'transcript': text,
                'video_segment_id': idx,
                'video_path': path_to_video,
                'start_time': transcript.start,
                'end_time': transcript.end
            }
            metadatas.append(metadata)
        else:
            print(f"ERROR! Cannot extract frame: idx = {idx}")

    # merge neighbouring transcripts to reduce the problem of disjointed transcripts
    metadatas = update_transcript(metadatas)

    # save metadata of all extracted frames
    fn = os.path.join(path_to_save_metadatas, 'metadatas.json')
    with open(fn, 'w') as outfile:
        json.dump(metadatas, outfile)
    return metadatas


def update_transcript(vid_metadata, n=7):
    vid_trans = [frame['transcript'] for frame in vid_metadata]
    updated_vid_trans = [
        ' '.join(vid_trans[i - int(n / 2): i + int(n / 2)]) if i - int(n / 2) >= 0 else
        ' '.join(vid_trans[0: i + int(n / 2)]) for i in range(len(vid_trans))
    ]

    # also store the merged transcripts back into the metadata
    for i in range(len(updated_vid_trans)):
        vid_metadata[i]['transcript'] = updated_vid_trans[i]
    return vid_metadata


# get video captions
def get_video_caption(path_to_video_frames: List, metadatas, output_folder_path: str, vlm=None, vlm_processor=None):
    if vlm is None or vlm_processor is None:
        vlm_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        vlm = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

    frame_caption = {}
    for i, frame_path in enumerate(tqdm(path_to_video_frames, desc="Captioning frames")):
        frame = Image.open(frame_path)
        inputs = vlm_processor(frame, return_tensors="pt")

        out = vlm.generate(**inputs)
        caption = vlm_processor.decode(out[0], skip_special_tokens=True)
        frame_caption[frame_path] = caption

    caption_out_path = os.path.join(output_folder_path, 'captions.json')
    with open(caption_out_path, 'w') as outfile:
        json.dump(frame_caption, outfile)

    # save the video captions into the metadata
    for frame_metadata in metadatas:
        frame_metadata['caption'] = frame_caption[frame_metadata['extracted_frame_path']]

    metadatas_out_path = os.path.join(output_folder_path, 'metadatas.json')
    with open(metadatas_out_path, 'w') as outfile:
        json.dump(metadatas, outfile)
    return metadatas_out_path
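For orientation (not part of the commit), each entry in the generated metadatas.json has the following shape after captioning; the values below are illustrative only:

example_record = {
    "extracted_frame_path": "uploads/SomeVideo/frame_0.jpg",      # saved by extract_and_save_frames_and_metadata
    "transcript": "merged transcript text around this segment",   # merged over ~7 neighbouring segments by update_transcript
    "video_segment_id": 0,
    "video_path": "uploads/SomeVideo/video.mp4",
    "start_time": "00:00:01.000",
    "end_time": "00:00:04.000",
    "caption": "a person sitting at a desk",                      # added later by get_video_caption
}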