Thao Pham committed on
Commit d50ce1c · 1 Parent(s): 1d5379f

First commit

Files changed (7)
  1. .gitignore +5 -0
  2. app.py +213 -0
  3. embed.py +62 -0
  4. rag.py +270 -0
  5. requirements.txt +17 -0
  6. utils.py +91 -0
  7. video_utils.py +156 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
+ .vscode
+ .env
+ __pycache__
+ tmp
+ uploads
app.py ADDED
@@ -0,0 +1,213 @@
+ import gradio as gr
+ import time
+ import re
+ import video_utils
+ import utils
+ import embed
+ import rag
+ import os
+ import uuid
+ import numpy as np
+ from pinecone import Pinecone, ServerlessSpec
+ from sentence_transformers import SentenceTransformer
+ from transformers import AutoImageProcessor, AutoModel
+ from transformers import BlipProcessor, BlipForConditionalGeneration
+ from dotenv import load_dotenv
+
+ load_dotenv()  # Load from .env
+
+ UPLOAD_FOLDER = 'uploads'
+ video_name = None
+
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+ PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
+
+ # init models
+ TEXT_MODEL = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+ VISION_MODEL_PROCESSOR = AutoImageProcessor.from_pretrained('facebook/dinov2-small')
+ VISION_MODEL = AutoModel.from_pretrained('facebook/dinov2-small')
+
+ VLM_PROCESSOR = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+ VLM = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
+
+ # init index
+ pc = Pinecone(
+     api_key=PINECONE_API_KEY
+ )
+ # Connect to an index
+ index_name = "mutlimodal-minilm"
+ INDEX = pc.Index(index_name)
+ MODEL_STACK = [TEXT_MODEL, VISION_MODEL, VISION_MODEL_PROCESSOR, VLM, VLM_PROCESSOR]
+
+
+ def is_valid_youtube_url(url):
+     """
+     Checks if the given URL is a valid YouTube video URL.
+
+     Returns True if valid, False otherwise.
+     """
+     youtube_regex = re.compile(
+         r"^(https?://)?(www\.)?(youtube\.com|youtu\.be)/"
+         r"(watch\?v=|embed/|v/|shorts/)?([a-zA-Z0-9_-]{11})"
+     )
+
+     match = youtube_regex.match(url)
+     return bool(match)
+
+
+ def check_exist_before_upsert(index, video_path):
+     # each extracted frame yields three vectors: image embeds, caption embeds, transcript embeds
+     video_folder = os.path.dirname(video_path)
+     threshold = len([elem for elem in os.listdir(video_folder) if elem.endswith('.jpg')]) * 3
+
+     dimension = 384
+     res = index.query(
+         vector=[0]*dimension,  # Dummy vector (not used for filtering)
+         top_k=10000,  # Set a high value to retrieve as many matches as possible
+         filter={"video_path": video_path}  # Filter by video_path
+     )
+
+     # Count the number of matching vectors
+     num_existing_vectors = len(res["matches"])
+
+     if num_existing_vectors >= threshold:
+         return True
+     return False
+
+
+ def chat(message, history):
+     image_input_path = None
+     if len(message['files']) > 0:
+         assert len(message['files']) == 1
+         image_input_path = message['files'][0]
+
+     message = message['text']
+
+     if history is None:
+         history = []
+
+     if message.startswith("https://"):
+         # Check valid URL
+         history.append((message, f"Checking if your provided URL at {message} is valid..."))
+         yield history
+
+         valid = is_valid_youtube_url(message)
+         if not valid:
+             history.append((None, "❌ Invalid YouTube URL. Please try again."))
+             yield history
+             return
+
+         # Check metadata
+         history.append((None, "✅ URL is valid! Fetching video metadata..."))
+         yield history
+
+         video_metadata = video_utils.get_video_metadata(message)
+         history.append((None, f"The video you want to process is: \nTitle: {video_metadata['title']} published by {video_metadata['author']} on {video_metadata['publish_date']}."))
+         yield history
+
+         history.append((None, "⏳ Downloading video..."))
+         yield history
+
+         output_folder_path = os.path.join(UPLOAD_FOLDER, video_metadata['title'])
+         path_to_video = os.path.join(output_folder_path, "video.mp4")
+         if not os.path.exists(path_to_video):
+             path_to_video = utils.download_video(message, path=output_folder_path)
+
+         history.append((None, "⏳ Transcribing video..."))
+         yield history
+         path_to_audio_file = os.path.join(output_folder_path, "audio.mp3")
+         if not os.path.exists(path_to_audio_file):
+             path_to_audio_file = video_utils.extract_audio(path_to_video, output_folder_path)
+
+         path_to_generated_transcript = os.path.join(output_folder_path, "transcript.vtt")
+         if not os.path.exists(path_to_generated_transcript):
+             path_to_generated_transcript = video_utils.transcribe_video(path_to_audio_file, output_folder_path)
+
+         # extract frames and metadata
+         metadatas_path = os.path.join(output_folder_path, 'metadatas.json')
+         if not os.path.exists(metadatas_path):
+             metadatas = video_utils.extract_and_save_frames_and_metadata(path_to_video=path_to_video,
+                                                                          path_to_transcript=path_to_generated_transcript,
+                                                                          path_to_save_extracted_frames=output_folder_path,
+                                                                          path_to_save_metadatas=output_folder_path)
+         else:
+             # frames were extracted on a previous run; reload their metadata for the captioning step
+             metadatas = utils.load_json_file(metadatas_path)
+
+         history.append((None, "⏳ Captioning video..."))
+         yield history
+
+         caption_path = os.path.join(output_folder_path, 'captions.json')
+         if not os.path.exists(caption_path):
+             video_frames = [os.path.join(output_folder_path, elem) for elem in os.listdir(output_folder_path) if elem.endswith('.jpg')]
+             metadatas_path = video_utils.get_video_caption(video_frames, metadatas, output_folder_path, vlm=VLM, vlm_processor=VLM_PROCESSOR)
+
+         history.append((None, "⏳ Indexing..."))
+         yield history
+         index_exist = check_exist_before_upsert(INDEX, path_to_video)
+         print(index_exist)
+         if not index_exist:
+             embed.indexing(INDEX, MODEL_STACK, metadatas_path)
+
+         # summarizing video
+         video_summary = rag.summarize_video(metadatas_path)
+         with open(os.path.join(output_folder_path, "summary.txt"), "w") as f:
+             f.write(video_summary)
+
+         history.append((None, f"Video processing complete! You can now ask me questions about the video {video_metadata['title']}!"))
+         yield history
+
+         global video_name
+         video_name = video_metadata['title']
+     else:
+         history.append((message, None))
+         yield history
+
+         if video_name is None:
+             history.append((None, "You need to provide a video URL before asking questions."))
+             yield history
+             return
+
+         output_folder_path = f"{UPLOAD_FOLDER}/{video_name}"
+         metadatas_path = os.path.join(output_folder_path, 'metadatas.json')
+
+         video_summary = ''
+         with open(f'./{output_folder_path}/summary.txt') as f:
+             while True:
+                 ln = f.readline()
+                 if ln == '':
+                     break
+                 video_summary += ln.strip()
+         video_path = os.path.join(output_folder_path, 'video.mp4')
+         answer = rag.answer_question(INDEX, MODEL_STACK, metadatas_path, video_summary, video_path, message, image_input_path)
+
+         history.append((None, answer))
+         yield history
+
+ def clear_chat(history):
+     history = []
+     history.append((None, "Please input a YouTube URL to get started!"))
+     return history
+
+ def main():
+     initial_messages = [(None, "Please input a YouTube URL to get started!")]
+
+     with gr.Blocks() as demo:
+         chatbot = gr.Chatbot(value=initial_messages)
+         msg = gr.MultimodalTextbox(file_types=['image'], sources=['upload'])
+
+         with gr.Row():
+             with gr.Column():
+                 submit = gr.Button("Send")
+                 submit.click(chat, [msg, chatbot], chatbot)
+
+             with gr.Column():
+                 clear = gr.Button("Clear")  # Clear button
+                 # Clear chat history when clear button is clicked
+                 clear.click(clear_chat, [], chatbot)
+
+     global video_name
+     video_name = None
+
+     demo.launch()
+
+ if __name__ == "__main__":
+     main()
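
Note that app.py only connects to the "mutlimodal-minilm" index; nothing in this commit creates it. A one-off setup sketch is shown below, where cloud, region, and metric are assumptions, while the dimension must be 384 to match all-MiniLM-L6-v2 and dinov2-small:

# One-off index setup sketch (not part of this commit); cloud/region/metric are assumptions.
import os
from dotenv import load_dotenv
from pinecone import Pinecone, ServerlessSpec

load_dotenv()
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
index_name = "mutlimodal-minilm"  # same (typo'd) name that app.py connects to

if index_name not in pc.list_indexes().names():
    pc.create_index(
        name=index_name,
        dimension=384,  # all-MiniLM-L6-v2 and dinov2-small both produce 384-d vectors
        metric="cosine",
        spec=ServerlessSpec(cloud="aws", region="us-east-1"),
    )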
embed.py ADDED
@@ -0,0 +1,62 @@
+ from sentence_transformers import SentenceTransformer
+ from transformers import AutoImageProcessor, AutoModel
+ from tqdm import tqdm
+ from PIL import Image
+ from typing import Iterator, TextIO, List, Dict, Any, Optional, Sequence, Union
+ import numpy as np
+ import uuid
+ from utils import load_json_file
+
+ def embed_texts(text_ls:List[str], text_model=None):
+     if text_model is None:
+         text_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+
+     text_embeddings = []
+     for i, text in enumerate(tqdm(text_ls, desc="Embedding text")):
+         embeds = text_model.encode(text)
+         text_embeddings.append(embeds)
+     return np.array(text_embeddings)
+
+
+ def embed_images(image_path_ls:List[str], vision_model=None, vision_model_processor=None):
+     if vision_model is None or vision_model_processor is None:
+         vision_model_processor = AutoImageProcessor.from_pretrained('facebook/dinov2-small')
+         vision_model = AutoModel.from_pretrained('facebook/dinov2-small')
+
+     image_embeds_ls = []
+     for i, frame in enumerate(tqdm(image_path_ls, desc="Embedding image")):
+         frame = Image.open(frame)
+         # TODO: add device here
+         inputs = vision_model_processor(images=frame, return_tensors="pt")
+         outputs = vision_model(**inputs)
+         image_embeds_ls.append(outputs.pooler_output)
+     return np.array([elem.squeeze().detach().numpy() for elem in image_embeds_ls])
+
+
+ def indexing(index, model_stack, vid_metadata_path):
+     text_model, vision_model, vision_model_processor, _, _ = model_stack
+
+     # read metadata file
+     vid_metadata = load_json_file(vid_metadata_path)
+
+     # embed transcripts
+     vid_trans = [frame['transcript'] for frame in vid_metadata]
+     transcript_embeddings = embed_texts(text_ls=vid_trans, text_model=text_model)
+
+     # embed captions
+     vid_captions = [frame['caption'] for frame in vid_metadata]
+     caption_embeddings = embed_texts(text_ls=vid_captions, text_model=text_model)
+
+     # embed frames
+     vid_img_paths = [vid['extracted_frame_path'] for vid in vid_metadata]
+     frame_embeddings = embed_images(vid_img_paths, vision_model, vision_model_processor)
+
+     for ls in [transcript_embeddings, caption_embeddings, frame_embeddings]:
+         # Prepare metadata
+         vectors = [
+             (str(uuid.uuid4()), emb.tolist(), meta)  # Generate unique IDs
+             for emb, meta in zip(ls, vid_metadata)
+         ]
+         # Upsert vectors into Pinecone
+         index.upsert(vectors)
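
For reference, indexing() can also be driven outside the Gradio app once metadatas.json contains transcripts, captions, and frame paths. A rough sketch with a placeholder path (the VLM slots of the model stack are unused by this function):

# Standalone indexing sketch (placeholder path); mirrors how app.py calls embed.indexing.
import os
from pinecone import Pinecone
from sentence_transformers import SentenceTransformer
from transformers import AutoImageProcessor, AutoModel
import embed

index = Pinecone(api_key=os.getenv("PINECONE_API_KEY")).Index("mutlimodal-minilm")
text_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
vision_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-small")
vision_model = AutoModel.from_pretrained("facebook/dinov2-small")

# indexing() only uses the first three entries, so the VLM slots can stay None here.
model_stack = [text_model, vision_model, vision_processor, None, None]
embed.indexing(index, model_stack, "uploads/Some Video Title/metadatas.json")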
rag.py ADDED
@@ -0,0 +1,270 @@
+ from langchain.prompts import PromptTemplate
+ from langchain_community.chat_models import ChatOpenAI
+ from langchain.chains import LLMChain
+ from PIL import Image
+ import os
+ from utils import load_json_file, str2time
+ from openai import OpenAI
+ import base64
+
+ def get_smallest_timestamp(timestamps):
+     assert len(timestamps) > 0
+
+     timestamps_in_ms = [str2time(elem) for elem in timestamps]
+
+     smallest_timestamp_in_ms = timestamps_in_ms[0]
+     smallest_timestamp = timestamps[0]
+     for i, elem in enumerate(timestamps_in_ms):
+         if elem < smallest_timestamp_in_ms:
+             smallest_timestamp_in_ms = elem
+             smallest_timestamp = timestamps[i]
+     return smallest_timestamp
+
+ def generate(query, context, relevant_timestamps=None):
+     prompt = PromptTemplate(input_variables=["question", "context"], template="You're a helpful LLM assistant in answering questions regarding a video. The given contexts are segments relevant to the question; please answer the question. Do not refer to segments. Context: {context}, question: {question} \nA:")
+
+     llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
+     chain = LLMChain(llm=llm, prompt=prompt)
+     response = chain.run(question=query, context=context)
+
+     if relevant_timestamps is not None and len(relevant_timestamps) > 0:
+         # get smallest timestamp = earliest mention
+         smallest_timestamp = get_smallest_timestamp(relevant_timestamps)
+         response += f' {smallest_timestamp}'
+     return response
+
+
+ def check_relevance(query, relevant_metadatas):
+     transcripts = [frame['transcript'] for frame in relevant_metadatas]
+     captions = [frame['caption'] for frame in relevant_metadatas]
+     timestamps = [frame['start_time'] for frame in relevant_metadatas]
+
+     context = ""
+     for i in range(len(transcripts)):
+         context += f"Segment {i}: transcript={transcripts[i]} caption={captions[i]}\n"
+     # print(context)
+
+     prompt = PromptTemplate(input_variables=["question", "context"], template="""
+     You are a grader assessing the relevance of a retrieved video segment to a user question. \n
+     If the video segment contains keyword(s) or semantic meaning related to the question, grade it as relevant. \n
+     Give a binary 'yes' or 'no' score to indicate whether the video segment is relevant to the question. \n
+     Answer in a string, separated by commas. For example: if four segments are provided, answer: yes,no,no,yes. \n
+     Question: {question} Context: {context}\n A:""")
+
+     # query = "What are the books mentioned in the video?"
+     llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
+     chain = LLMChain(llm=llm, prompt=prompt)
+     response = chain.run(question=query, context=context)
+     # print(response)
+
+     relevance_response = response.split(',')
+
+     actual_relevant_context = ""
+     relevant_timestamps = []
+     for i, relevance_check in enumerate(relevance_response):
+         if relevance_check.strip() == 'yes':
+             actual_relevant_context += f"Segment {i}: transcript={transcripts[i]} caption={captions[i]}\n"
+             relevant_timestamps.append(timestamps[i])
+     return actual_relevant_context, relevant_timestamps
+
+
+ def retrieve_segments_from_timestamp(metadatas, timestamps):
+     relevant_segments = []
+
+     for timestamp in timestamps:
+         time_to_find_ms = str2time(timestamp)
+         buffer = 5000  # 5 seconds before and after
+
+         for segment in metadatas:
+             start = str2time(segment['start_time'])
+             end = str2time(segment['end_time'])
+             if start <= time_to_find_ms + buffer and end >= time_to_find_ms - buffer:
+                 relevant_segments.append(segment)
+
+     return relevant_segments
+
+
+ def check_timestamps(query):
+     prompt = PromptTemplate(input_variables=["question"], template="You're a helpful LLM assistant. You're good at detecting any timestamps provided in a query. Please detect the question and timestamps in the following question and separate them by commas, such as question,timestamp1,timestamp2 if timestamps are provided, else just the question. Question: {question} \nA:")
+
+     llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
+     chain = LLMChain(llm=llm, prompt=prompt)
+     response = chain.run(question=query)
+
+     timestamps = []
+     if len(response.split(',')) > 1:
+         query = response.split(',')[0].strip()
+         timestamps = [f"00:{elem.strip()}.00" for elem in response.split(',')[1:]]
+
+     return query, timestamps
+
+ def retrieve_by_embedding(index, video_path, query, text_model):
+     print(query)
+     query_embedding = text_model.encode(query)
+
+     res = index.query(vector=query_embedding.tolist(), top_k=5, filter={"video_path": {"$eq": video_path}})
+
+     metadatas = []
+     for id_, match_ in enumerate(res['matches']):
+         result = index.fetch(ids=[match_['id']])
+
+         # Extract the vector data
+         vector_data = result.vectors.get(match_['id'], {})
+
+         # Extract metadata
+         metadata = vector_data.metadata
+         metadatas.append(metadata)
+
+     return metadatas
+
+ def self_reflection(query, answer, summary):
+     prompt = PromptTemplate(input_variables=["summary", "question", "answer"], template="You're a helpful LLM assistant. You're good at determining if the provided answer is satisfactory for a question relating to a video. You have access to the video summary as follows: {summary}. Given a pair of question and answer, give the answer's satisfactory score as either yes or no. Question: {question}, Answer: {answer} \nA:")
+
+     llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
+     chain = LLMChain(llm=llm, prompt=prompt)
+     response = chain.run(summary=summary, question=query, answer=answer)
+     return response
+
+
+ def get_full_transcript(metadatas):
+     # metadatas = webvtt.read(path_to_transcript)
+     transcripts = [frame['transcript'] for frame in metadatas]
+
+     full_text = ''
+     for idx, transcript in enumerate(transcripts):
+         text = transcript.strip().replace("  ", " ")
+         full_text += f"{text} "
+
+     full_text = full_text.strip()
+     return full_text
+
+ def summarize_video(metadatas_path:str):
+     metadatas = load_json_file(metadatas_path)
+
+     # get full transcript
+     transcript = get_full_transcript(metadatas)
+     prompt = PromptTemplate(input_variables=["transcript"], template="You're a helpful LLM assistant. Please provide a summary of the video given its full transcript: {transcript} \nA:")
+
+     llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
+     chain = LLMChain(llm=llm, prompt=prompt)
+     response = chain.run(transcript=transcript)
+     return response
+
+ def answer_wrt_timestamp(query, context):
+     prompt = PromptTemplate(input_variables=["question", "context"], template="""
+     You're a helpful LLM assistant. Given a question and a timestamp, I have retrieved the relevant context as follows. Please answer the question using the information provided in the context. Question: {question}, context: {context} \n
+     For example: Question="What happens at 4:20?" Caption="a person is standing up" Transcript="I have to go" Appropriate Answer="At 4:20, a person is standing up and saying he has to go."
+     A:""")
+     llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
+     chain = LLMChain(llm=llm, prompt=prompt)
+     response = chain.run(question=query, context=context)
+     return response
+
+
+ def answer_question(index, model_stack, metadatas_path, video_summary:str, video_path:str, query:str, image_input_path:str=None):
+     metadatas = load_json_file(metadatas_path)
+     if image_input_path is not None:
+         return answer_image_question(index, model_stack, metadatas, video_summary, video_path, query, image_input_path)
+
+     # check if a timestamp was provided
+     query, timestamps = check_timestamps(query)
+
+     if len(timestamps) > 0:
+         # retrieve by timestamps
+         relevant_segments_metadatas = retrieve_segments_from_timestamp(metadatas, timestamps)
+         transcripts = [frame['transcript'] for frame in relevant_segments_metadatas]
+         captions = [frame['caption'] for frame in relevant_segments_metadatas]
+         context = ""
+         for i in range(len(transcripts)):
+             context += f"Segment {i}: transcript={transcripts[i]} caption={captions[i]}\n"
+         # print(context)
+         return answer_wrt_timestamp(query, context)
+     else:
+         # retrieve by embedding
+         relevant_segments_metadatas = retrieve_by_embedding(index, video_path, query, model_stack[0])
+
+         # check relevance
+         actual_relevant_context, relevant_timestamps = check_relevance(query, relevant_segments_metadatas)
+         # relevant_timestamps = [frame['start_time'] for frame in relevant_segments_metadatas]
+         # print(actual_relevant_context)
+
+         # generate
+         answer = generate(query, actual_relevant_context, relevant_timestamps)
+         # print(answer)
+
+         # self-reflection
+         reflect = self_reflection(query, answer, video_summary)
+
+         # print("Reflect", reflect)
+         if reflect.strip().lower() == 'no':
+             answer = generate(query, f"{actual_relevant_context}\nSummary={video_summary}")
+
+         return answer
+
+ def retrieve_segments_by_image_embedding(index, video_path, model_stack, image_query_path):
+     image_query = Image.open(image_query_path)
+     _, vision_model, vision_model_processor, _, _ = model_stack
+     inputs = vision_model_processor(images=image_query, return_tensors="pt")
+     outputs = vision_model(**inputs)
+     image_query_embeds = outputs.pooler_output
+
+     # pooler_output has shape (1, dim); flatten it so Pinecone receives a flat vector
+     res = index.query(vector=image_query_embeds.squeeze(0).detach().tolist(), top_k=5, filter={"video_path": {"$eq": video_path}})
+
+     metadatas = []
+     for id_, match_ in enumerate(res['matches']):
+         result = index.fetch(ids=[match_['id']])
+
+         # Extract the vector data
+         vector_data = result.vectors.get(match_['id'], {})
+
+         # Extract metadata
+         metadata = vector_data.metadata
+         metadatas.append(metadata)
+
+     return metadatas
+
+
+ def answer_image_question(index, model_stack, metadatas, video_summary:str, video_path:str, query:str, image_query_path:str=None):
+     # search segments by image
+     relevant_segments = retrieve_segments_by_image_embedding(index, video_path, model_stack, image_query_path)
+
+     # generate an answer using those segments
+     return generate_w_image(query, image_query_path, relevant_segments)
+
+
+ def encode_image(image_path):
+     with open(image_path, "rb") as image_file:
+         return base64.b64encode(image_file.read()).decode("utf-8")
+
+
+ def generate_w_image(query:str, image_query_path:str, relevant_metadatas):
+     base64_image = encode_image(image_query_path)
+     transcripts = [frame['transcript'] for frame in relevant_metadatas]
+     captions = [frame['caption'] for frame in relevant_metadatas]
+     # timestamps = [frame['start_time'] for frame in relevant_metadatas]
+
+     context = ""
+     for i in range(len(transcripts)):
+         context += f"Segment {i}: transcript={transcripts[i]} caption={captions[i]}\n"
+     # print(context)
+
+     client = OpenAI()
+     response = client.chat.completions.create(
+         model="gpt-4o-mini",
+         messages=[
+             {"role": "user", "content": [
+                 {"type": "text", "text": f"Here is some context about the image: {context}"},  # Add context here
+                 {"type": "text", "text": f"You are a helpful LLM assistant. You are good at answering questions about a video given an image. Given the image and the context surrounding the frames most correlated with it, please answer the question. Question: {query}"},
+                 {"type": "image_url", "image_url": {
+                     "url": f"data:image/png;base64,{base64_image}"
+                 }
+                 }
+             ]}
+         ],
+         temperature=0.0,
+         max_tokens=100,
+     )
+
+     response = response.choices[0].message.content
+     # print(response)
+     return response
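
Assuming a video has already been processed by app.py (so summary.txt and metadatas.json exist under its uploads folder and its vectors are in Pinecone), the retrieval pipeline can be exercised directly. A sketch with placeholder paths, with OPENAI_API_KEY and PINECONE_API_KEY loaded from .env:

# Direct usage sketch (placeholder folder); requires the video to have been processed already.
import os
from dotenv import load_dotenv
from pinecone import Pinecone
from sentence_transformers import SentenceTransformer
import rag

load_dotenv()
index = Pinecone(api_key=os.getenv("PINECONE_API_KEY")).Index("mutlimodal-minilm")
text_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
model_stack = [text_model, None, None, None, None]  # vision/VLM slots unused for text-only queries

folder = "uploads/Some Video Title"  # placeholder folder created by app.py
metadatas_path = os.path.join(folder, "metadatas.json")
with open(os.path.join(folder, "summary.txt")) as f:
    summary = f.read()

answer = rag.answer_question(index, model_stack, metadatas_path, summary,
                             os.path.join(folder, "video.mp4"),
                             "What is the video about?")
print(answer)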
requirements.txt ADDED
@@ -0,0 +1,17 @@
+ openai-whisper
+ webvtt-py
+ pytubefix
+ sentence-transformers
+ pinecone
+ gradio
+ moviepy
+ youtube-transcript-api
+ pytube
+ ffmpeg-python
+ ffmpeg
+ opencv-python
+ langchain_yt_dlp
+ langchain_community
+ langchain
+ python-dotenv
+ transformers==4.49.0
+ numpy==1.26.4
+ openai==1.68.2
utils.py ADDED
@@ -0,0 +1,91 @@
+ import os
+ from io import StringIO, BytesIO
+ from typing import Iterator, TextIO, List, Dict, Any, Optional, Sequence, Union
+ import base64
+ import glob
+ from tqdm import tqdm
+ from pytubefix import YouTube, Stream
+ import cv2
+ import json
+
+ # Taken from the course: https://www.deeplearning.ai/short-courses/multimodal-rag-chat-with-videos/
+ def getSubs(segments: Iterator[dict], format: str, maxLineWidth: int=-1) -> str:
+     # NOTE: write_vtt / write_srt are assumed to be provided by the course helpers
+     # (older openai-whisper releases exposed similar writers in whisper.utils); they are not defined in this commit.
+     segmentStream = StringIO()
+
+     if format == 'vtt':
+         write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
+     elif format == 'srt':
+         write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
+     else:
+         raise Exception("Unknown format " + format)
+
+     segmentStream.seek(0)
+     return segmentStream.read()
+
+ def download_video(video_url, path='/tmp/'):
+     print(f'Getting video information for {video_url}')
+     if not video_url.startswith('http'):
+         return os.path.join(path, video_url)
+
+     filepath = glob.glob(os.path.join(path, '*.mp4'))
+     if len(filepath) > 0:
+         return filepath[0]
+
+     def progress_callback(stream: Stream, data_chunk: bytes, bytes_remaining: int) -> None:
+         pbar.update(len(data_chunk))
+
+     yt = YouTube(video_url, on_progress_callback=progress_callback)
+     stream = yt.streams.filter(progressive=True, file_extension='mp4', res='720p').desc().first()
+     if stream is None:
+         stream = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first()
+     if not os.path.exists(path):
+         os.makedirs(path)
+
+     filepath = os.path.join(path, stream.default_filename)
+     if not os.path.exists(filepath):
+         print('Downloading video from YouTube...')
+         pbar = tqdm(desc='Downloading video from YouTube', total=stream.filesize, unit="bytes")
+         stream.download(path)
+         pbar.close()
+     return filepath
+
+ # a helper function that converts a time string in webvtt format into a time in milliseconds
+ def str2time(strtime):
+     # strip the character " if it exists
+     strtime = strtime.strip('"')
+     # get hours, minutes, seconds from the time string
+     hrs, mins, seconds = [float(c) for c in strtime.split(':')]
+     # get the corresponding time as total seconds
+     total_seconds = hrs * 60**2 + mins * 60 + seconds
+     total_milliseconds = total_seconds * 1000
+     return total_milliseconds
+
+ # Resizes an image while maintaining its aspect ratio
+ def maintain_aspect_ratio_resize(image, width=None, height=None, inter=cv2.INTER_AREA):
+     # Grab the image size and initialize dimensions
+     dim = None
+     (h, w) = image.shape[:2]
+
+     # Return original image if no need to resize
+     if width is None and height is None:
+         return image
+
+     # We are resizing height if width is None
+     if width is None:
+         # Calculate the ratio of the height and construct the dimensions
+         r = height / float(h)
+         dim = (int(w * r), height)
+     # We are resizing width if height is None
+     else:
+         # Calculate the ratio of the width and construct the dimensions
+         r = width / float(w)
+         dim = (width, int(h * r))
+
+     # Return the resized image
+     return cv2.resize(image, dim, interpolation=inter)
+
+ def load_json_file(file_path):
+     # Open the JSON file in read mode
+     with open(file_path, 'r') as file:
+         data = json.load(file)
+     return data
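
getSubs above calls write_vtt and write_srt, which this commit never defines or imports, so it raises a NameError as written. Below is a minimal sketch of helpers that could back it if added to utils.py above getSubs, modeled on the writers that older openai-whisper releases shipped in whisper.utils (an assumption; the maxLineWidth argument is accepted but line wrapping is not implemented):

# Minimal sketch (assumption): VTT/SRT writers compatible with getSubs above.
def _format_timestamp(seconds: float, decimal: str = '.') -> str:
    # hours are always included so that utils.str2time can parse the resulting timestamps
    ms = round(seconds * 1000)
    hrs, ms = divmod(ms, 3_600_000)
    mins, ms = divmod(ms, 60_000)
    secs, ms = divmod(ms, 1_000)
    return f"{hrs:02d}:{mins:02d}:{secs:02d}{decimal}{ms:03d}"

def write_vtt(transcript, file, maxLineWidth: int = -1):
    print("WEBVTT\n", file=file)
    for segment in transcript:
        print(f"{_format_timestamp(segment['start'])} --> {_format_timestamp(segment['end'])}\n"
              f"{segment['text'].strip().replace('-->', '->')}\n", file=file)

def write_srt(transcript, file, maxLineWidth: int = -1):
    for i, segment in enumerate(transcript, start=1):
        print(f"{i}\n"
              f"{_format_timestamp(segment['start'], decimal=',')} --> "
              f"{_format_timestamp(segment['end'], decimal=',')}\n"
              f"{segment['text'].strip().replace('-->', '->')}\n", file=file)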
video_utils.py ADDED
@@ -0,0 +1,156 @@
+ from typing import Iterator, TextIO, List, Dict, Any, Optional, Sequence, Union
+ from utils import getSubs, str2time, maintain_aspect_ratio_resize
+ from moviepy import VideoFileClip
+ import whisper
+ import os
+ import cv2
+ import webvtt
+ from PIL import Image
+ from tqdm import tqdm
+ import json
+ from langchain_yt_dlp.youtube_loader import YoutubeLoaderDL
+ from transformers import BlipProcessor, BlipForConditionalGeneration
+
+
+ # get video metadata
+ def get_video_metadata(video_url:str):
+     docs = YoutubeLoaderDL.from_youtube_url(video_url, add_video_info=True).load()
+     return docs[0].metadata
+
+ # extract audio
+ def extract_audio(path_to_video:str, output_folder:str):
+     video_name = os.path.basename(path_to_video).replace('.mp4', '')
+
+     # declare where to save the .mp3 audio
+     path_to_extracted_audio_file = os.path.join(output_folder, f'{video_name}.mp3')
+
+     # extract the mp3 audio track from the mp4 video file
+     clip = VideoFileClip(path_to_video)
+     clip.audio.write_audiofile(path_to_extracted_audio_file)
+     return path_to_extracted_audio_file
+
+
+ # Get video transcript
+ def transcribe_video(path_to_extracted_audio_file, output_folder, whisper_model=None):
+     # load model
+     if whisper_model is None:
+         whisper_model = whisper.load_model("small")
+     options = dict(task="translate", best_of=1, language='en')
+     results = whisper_model.transcribe(path_to_extracted_audio_file, **options)
+
+     vtt = getSubs(results["segments"], "vtt")
+     # path to save the generated transcript
+     video_name = os.path.basename(path_to_extracted_audio_file).replace('.mp3', '')
+     path_to_generated_transcript = os.path.join(output_folder, f'{video_name}.vtt')
+
+     # write transcription to file
+     with open(path_to_generated_transcript, 'w') as f:
+         f.write(vtt)
+     return path_to_generated_transcript
+
+
+ # get video frames & metadata
+ def extract_and_save_frames_and_metadata(
+         path_to_video,
+         path_to_transcript,
+         path_to_save_extracted_frames,
+         path_to_save_metadatas):
+
+     # metadatas will store the metadata of all extracted frames
+     metadatas = []
+
+     # load video using cv2
+     video = cv2.VideoCapture(path_to_video)
+     # load transcript using webvtt
+     trans = webvtt.read(path_to_transcript)
+
+     # iterate the transcript file
+     # for each video segment specified in the transcript file
+     for idx, transcript in enumerate(trans):
+         # get the start time and end time in milliseconds
+         start_time_ms = str2time(transcript.start)
+         end_time_ms = str2time(transcript.end)
+         # get the time in ms exactly
+         # in the middle of start time and end time
+         mid_time_ms = (end_time_ms + start_time_ms) / 2
+         # get the transcript, remove the next-line symbol
+         text = transcript.text.replace("\n", ' ')
+         # get the frame at the middle time
+         video.set(cv2.CAP_PROP_POS_MSEC, mid_time_ms)
+         success, frame = video.read()
+         if success:
+             # if the frame is extracted successfully, resize it
+             image = maintain_aspect_ratio_resize(frame, height=350)
+             # save frame as JPEG file
+             img_fname = f'frame_{idx}.jpg'
+             img_fpath = os.path.join(
+                 path_to_save_extracted_frames, img_fname
+             )
+             cv2.imwrite(img_fpath, image)
+
+             # prepare the metadata
+             metadata = {
+                 'extracted_frame_path': img_fpath,
+                 'transcript': text,
+                 'video_segment_id': idx,
+                 'video_path': path_to_video,
+                 'start_time': transcript.start,
+                 'end_time': transcript.end
+             }
+             metadatas.append(metadata)
+         else:
+             print(f"ERROR! Cannot extract frame: idx = {idx}")
+
+     # merge neighboring transcripts to mitigate disjointed segments
+     metadatas = update_transcript(metadatas)
+
+     # save metadata of all extracted frames
+     fn = os.path.join(path_to_save_metadatas, 'metadatas.json')
+     with open(fn, 'w') as outfile:
+         json.dump(metadatas, outfile)
+     return metadatas
+
+
+ def update_transcript(vid_metadata, n=7):
+     vid_trans = [frame['transcript'] for frame in vid_metadata]
+     updated_vid_trans = [
+         ' '.join(vid_trans[i-int(n/2) : i+int(n/2)]) if i-int(n/2) >= 0 else
+         ' '.join(vid_trans[0 : i + int(n/2)]) for i in range(len(vid_trans))
+     ]
+
+     # also write the updated transcripts back into the metadata
+     for i in range(len(updated_vid_trans)):
+         vid_metadata[i]['transcript'] = updated_vid_trans[i]
+     return vid_metadata
+
+
+ # get video captions
+ def get_video_caption(path_to_video_frames: List, metadatas, output_folder_path:str, vlm=None, vlm_processor=None):
+     if vlm is None or vlm_processor is None:
+         vlm_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+         vlm = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
+
+     frame_caption = {}
+     for i, frame_path in enumerate(tqdm(path_to_video_frames, desc="Captioning frames")):
+         frame = Image.open(frame_path)
+         inputs = vlm_processor(frame, return_tensors="pt")
+
+         out = vlm.generate(**inputs)
+         caption = vlm_processor.decode(out[0], skip_special_tokens=True)
+         frame_caption[frame_path] = caption
+
+     caption_out_path = os.path.join(output_folder_path, 'captions.json')
+     with open(caption_out_path, 'w') as outfile:
+         json.dump(frame_caption, outfile)
+
+     # save video captions to metadata
+     for frame_metadata in metadatas:
+         frame_metadata['caption'] = frame_caption[frame_metadata['extracted_frame_path']]
+
+     metadatas_out_path = os.path.join(output_folder_path, 'metadatas.json')
+     with open(metadatas_out_path, 'w') as outfile:
+         json.dump(metadatas, outfile)
+     return metadatas_out_path
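
Taken together with utils.download_video, the helpers above form the offline preprocessing pipeline that app.py drives from the chat handler. A condensed sketch of the call order, with a placeholder URL:

# Preprocessing order sketch (placeholder URL); app.py runs the same sequence per video.
import os
import utils
import video_utils

url = "https://www.youtube.com/watch?v=VIDEO_ID"  # placeholder
meta = video_utils.get_video_metadata(url)
folder = os.path.join("uploads", meta["title"])

video_path = utils.download_video(url, path=folder)
audio_path = video_utils.extract_audio(video_path, folder)
vtt_path = video_utils.transcribe_video(audio_path, folder)  # loads whisper "small" if no model is passed

metadatas = video_utils.extract_and_save_frames_and_metadata(
    path_to_video=video_path,
    path_to_transcript=vtt_path,
    path_to_save_extracted_frames=folder,
    path_to_save_metadatas=folder,
)
frames = [os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".jpg")]
video_utils.get_video_caption(frames, metadatas, folder)  # writes captions.json and updates metadatas.json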