import base64
import os

from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.chat_models import ChatOpenAI
from openai import OpenAI
from PIL import Image

from utils import load_json_file, str2time

def get_smallest_timestamp(timestamps):
    assert len(timestamps) > 0
    timestamps_in_ms = [str2time(elem) for elem in timestamps]
    smallest_timestamp_in_ms = timestamps_in_ms[0]
    smallest_timestamp = timestamps[0]
    for i, elem in enumerate(timestamps_in_ms):
        if elem < smallest_timestamp_in_ms:
            smallest_timestamp_in_ms = elem
            smallest_timestamp = timestamps[i]
    return smallest_timestamp
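
# Illustrative example (assumes utils.str2time parses "HH:MM:SS.mmm"-style timestamps into milliseconds):
#   get_smallest_timestamp(["00:01:30.000", "00:00:45.000"]) -> "00:00:45.000"
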
def generate(query, context, relevant_timestamps=None):
    """Answer the question from the retrieved context, appending the earliest relevant timestamp if one exists."""
    prompt = PromptTemplate(input_variables=["question", "context"], template=(
        "You're a helpful LLM assistant in answering questions regarding a video. "
        "Given contexts are segments relevant to the question, please answer the question. "
        "Do not refer to segments. Context: {context}, question: {question} \nA:"))
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    chain = LLMChain(llm=llm, prompt=prompt)
    response = chain.run(question=query, context=context)
    if relevant_timestamps is not None and len(relevant_timestamps) > 0:
        # The smallest timestamp is the earliest mention in the video.
        smallest_timestamp = get_smallest_timestamp(relevant_timestamps)
        response += f' {smallest_timestamp}'
    return response

def check_relevance(query, relevant_metadatas):
    """Grade each retrieved segment for relevance and return only the relevant context and timestamps."""
    transcripts = [frame['transcript'] for frame in relevant_metadatas]
    captions = [frame['caption'] for frame in relevant_metadatas]
    timestamps = [frame['start_time'] for frame in relevant_metadatas]
    context = ""
    for i in range(len(transcripts)):
        context += f"Segment {i}: transcript={transcripts[i]} caption={captions[i]}\n"
    prompt = PromptTemplate(input_variables=["question", "context"], template="""
    You are a grader assessing the relevance of a retrieved video segment to a user question. \n
    If the video segment contains keyword(s) or semantic meaning related to the question, grade it as relevant. \n
    Give a binary score 'yes' or 'no' to indicate whether the video segment is relevant to the question. \n
    Answer in a string, separated by commas. For example: if there are 4 segments provided, answer: yes,no,no,yes. \n
    Question: {question} Context: {context}\n A:""")
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    chain = LLMChain(llm=llm, prompt=prompt)
    response = chain.run(question=query, context=context)
    relevance_response = response.split(',')
    actual_relevant_context = ""
    relevant_timestamps = []
    for i, relevance_check in enumerate(relevance_response):
        if relevance_check.strip() == 'yes':
            actual_relevant_context += f"Segment {i}: transcript={transcripts[i]} caption={captions[i]}\n"
            relevant_timestamps.append(timestamps[i])
    return actual_relevant_context, relevant_timestamps

def retrieve_segments_from_timestamp(metadatas, timestamps):
    """Return every segment whose time range overlaps a 5-second buffer around any requested timestamp."""
    relevant_segments = []
    buffer = 5000  # milliseconds of slack before and after the requested timestamp
    for timestamp in timestamps:
        time_to_find_ms = str2time(timestamp)
        for segment in metadatas:
            start = str2time(segment['start_time'])
            end = str2time(segment['end_time'])
            if start <= time_to_find_ms + buffer and end >= time_to_find_ms - buffer:
                relevant_segments.append(segment)
    return relevant_segments

def check_timestamps(query):
    """Ask the LLM to split a query into the question and any timestamps it mentions."""
    prompt = PromptTemplate(input_variables=["question"], template=(
        "You're a helpful LLM assistant. You're good at detecting any timestamps provided in a query. "
        "Please detect the question and timestamp in the following question and separate them by commas, "
        "such as question,timestamp1,timestamp2 if timestamps are provided, else just the question. "
        "Question: {question} \nA:"))
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    chain = LLMChain(llm=llm, prompt=prompt)
    response = chain.run(question=query)
    timestamps = []
    if len(response.split(',')) > 1:
        query = response.split(',')[0].strip()
        # Normalize "M:SS"-style timestamps to the "HH:MM:SS.ms" format used in the metadata.
        timestamps = [f"00:{elem.strip()}.00" for elem in response.split(',')[1:]]
    return query, timestamps
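
# Illustrative example (the exact split depends on the LLM response, so treat this output as hypothetical):
#   check_timestamps("What happens at 4:20 in the video?") -> ("What happens in the video?", ["00:4:20.00"])
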
def retrieve_by_embedding(index, video_path, query, text_model):
    """Embed the query text and retrieve the top-5 matching segments for this video from the vector index."""
    query_embedding = text_model.encode(query)
    res = index.query(vector=query_embedding.tolist(), top_k=5,
                      filter={"video_path": {"$eq": video_path}})
    metadatas = []
    for match_ in res['matches']:
        result = index.fetch(ids=[match_['id']])
        # Extract the stored vector data and its metadata.
        vector_data = result.vectors.get(match_['id'], {})
        metadata = vector_data.metadata
        metadatas.append(metadata)
    return metadatas

def self_reflection(query, answer, summary):
    """Grade whether the answer is satisfactory given the video summary; returns 'yes' or 'no'."""
    prompt = PromptTemplate(input_variables=["summary", "question", "answer"], template=(
        "You're a helpful LLM assistant. You're good at determining if the provided answer is satisfactory to a question relating to a video. "
        "You have access to the video summary as follows: {summary}. Given a pair of question and answer, give the answer's satisfactory score in either yes or no. "
        "Question: {question}, Answer: {answer} \nA:"))
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    chain = LLMChain(llm=llm, prompt=prompt)
    response = chain.run(summary=summary, question=query, answer=answer)
    return response

def get_full_transcript(metadatas):
    """Concatenate the per-segment transcripts into one string, collapsing double spaces."""
    transcripts = [frame['transcript'] for frame in metadatas]
    full_text = ''
    for transcript in transcripts:
        text = transcript.strip().replace("  ", " ")
        full_text += f"{text} "
    return full_text.strip()

def summarize_video(metadatas_path: str):
    """Generate a summary of the whole video from its full transcript."""
    metadatas = load_json_file(metadatas_path)
    transcript = get_full_transcript(metadatas)
    prompt = PromptTemplate(input_variables=["transcript"], template=(
        "You're a helpful LLM assistant. Please provide a summary for the video "
        "given its full transcript: {transcript} \nA:"))
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    chain = LLMChain(llm=llm, prompt=prompt)
    response = chain.run(transcript=transcript)
    return response

def answer_wrt_timestamp(query, context):
    prompt = PromptTemplate(input_variables=["question", "context"], template="""
    You're a helpful LLM assistant. Given a question and a timestamp, I have retrieved the relevant context as follows. Please answer the question using the information provided in the context. Question: {question}, context: {context} \n
    For example: Question="What happens at 4:20?" Caption="a person is standing up" Transcript="I have to go" Appropriate Answer="At 4:20, a person is standing up and saying he has to go."
    A:""")
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    chain = LLMChain(llm=llm, prompt=prompt)
    response = chain.run(question=query, context=context)
    return response

def answer_question(index, model_stack, metadatas_path, video_summary: str, video_path: str, query: str, image_input_path: str = None):
    """Main entry point: route the query to timestamp-based, embedding-based, or image-based answering."""
    metadatas = load_json_file(metadatas_path)
    if image_input_path is not None:
        return answer_image_question(index, model_stack, metadatas, video_summary, video_path, query, image_input_path)
    # Check whether the query mentions an explicit timestamp.
    query, timestamps = check_timestamps(query)
    if len(timestamps) > 0:
        # Retrieve the segments surrounding the mentioned timestamps.
        relevant_segments_metadatas = retrieve_segments_from_timestamp(metadatas, timestamps)
        transcripts = [frame['transcript'] for frame in relevant_segments_metadatas]
        captions = [frame['caption'] for frame in relevant_segments_metadatas]
        context = ""
        for i in range(len(transcripts)):
            context += f"Segment {i}: transcript={transcripts[i]} caption={captions[i]}\n"
        return answer_wrt_timestamp(query, context)
    else:
        # Retrieve candidate segments by text-embedding similarity.
        relevant_segments_metadatas = retrieve_by_embedding(index, video_path, query, model_stack[0])
        # Keep only the segments the grader marks as relevant.
        actual_relevant_context, relevant_timestamps = check_relevance(query, relevant_segments_metadatas)
        # Generate an answer from the relevant context.
        answer = generate(query, actual_relevant_context, relevant_timestamps)
        # Self-reflection: if the answer is judged unsatisfactory, retry with the video summary added to the context.
        reflect = self_reflection(query, answer, video_summary)
        if reflect.strip().lower() == 'no':
            answer = generate(query, f"{actual_relevant_context}\nSummary={video_summary}")
        return answer

def retrieve_segments_by_image_embedding(index, video_path, model_stack, image_query_path):
    """Embed the query image with the vision model and retrieve the top-5 matching segments for this video."""
    image_query = Image.open(image_query_path)
    _, vision_model, vision_model_processor, _, _ = model_stack
    inputs = vision_model_processor(images=image_query, return_tensors="pt")
    outputs = vision_model(**inputs)
    image_query_embeds = outputs.pooler_output
    res = index.query(vector=image_query_embeds.tolist(), top_k=5,
                      filter={"video_path": {"$eq": video_path}})
    metadatas = []
    for match_ in res['matches']:
        result = index.fetch(ids=[match_['id']])
        # Extract the stored vector data and its metadata.
        vector_data = result.vectors.get(match_['id'], {})
        metadata = vector_data.metadata
        metadatas.append(metadata)
    return metadatas

def answer_image_question(index, model_stack, metadatas, video_summary: str, video_path: str, query: str, image_query_path: str = None):
    """Answer a question about the video given a query image: retrieve segments by image similarity, then generate."""
    # Search for the segments most similar to the query image.
    relevant_segments = retrieve_segments_by_image_embedding(index, video_path, model_stack, image_query_path)
    # Generate an answer using those segments as context.
    return generate_w_image(query, image_query_path, relevant_segments)

def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")

def generate_w_image(query: str, image_query_path: str, relevant_metadatas):
    """Answer a question about the video using the query image and the metadata of its most similar segments."""
    base64_image = encode_image(image_query_path)
    transcripts = [frame['transcript'] for frame in relevant_metadatas]
    captions = [frame['caption'] for frame in relevant_metadatas]
    context = ""
    for i in range(len(transcripts)):
        context += f"Segment {i}: transcript={transcripts[i]} caption={captions[i]}\n"
    client = OpenAI()
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "user", "content": [
                {"type": "text", "text": f"Here is some context about the image: {context}"},
                {"type": "text", "text": f"You are a helpful LLM assistant. You are good at answering questions about a video given an image. Given the image and the context surrounding the frames most correlated with it, please answer the question. Question: {query}"},
                {"type": "image_url", "image_url": {
                    "url": f"data:image/png;base64,{base64_image}"
                }}
            ]}
        ],
        temperature=0.0,
        max_tokens=100,
    )
    return response.choices[0].message.content
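
# --- Example usage (illustrative sketch only; the index, model stack, and file paths below are assumptions,
# not defined in this module) ---
# Expected setup: `index` is a Pinecone index holding per-segment embeddings with a "video_path" metadata
# field; `model_stack[0]` is a text embedding model exposing `.encode()`, and `model_stack[1]`/`model_stack[2]`
# are a vision model and its processor; the metadata JSON lists segments with 'transcript', 'caption',
# 'start_time', and 'end_time' fields.
#
#   summary = summarize_video("metadatas/video1.json")
#   answer = answer_question(index, model_stack, "metadatas/video1.json", summary,
#                            "videos/video1.mp4", "What are the books mentioned in the video?")
#   print(answer)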