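"""Retrieval-augmented question answering over a video.

Queries are optionally parsed for timestamps, relevant segments are retrieved
from a Pinecone index by text or image embedding, an LLM grader re-checks
segment relevance, and an answer is generated, with an optional self-reflection
pass against the video summary.
"""
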
from langchain.prompts import PromptTemplate
from langchain_community.chat_models import ChatOpenAI
from langchain.chains import LLMChain
from PIL import Image 
import os
from utils import load_json_file, str2time
from openai import OpenAI
import base64

def get_smallest_timestamp(timestamps):
    """Return the earliest timestamp string from a non-empty list.

    Timestamps are compared by their value in milliseconds (via str2time),
    but the original string is returned.
    """
    assert len(timestamps) > 0
    return min(timestamps, key=str2time)
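
# Illustrative usage (assuming str2time parses "HH:MM:SS.mmm"-style strings, as
# the rest of this module does):
#   get_smallest_timestamp(["00:01:15.000", "00:00:42.500"])  ->  "00:00:42.500"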

def generate(query, context, relevant_timestamps=None):
    """Answer the query from the retrieved context; append the earliest relevant timestamp if any."""
    prompt = PromptTemplate(
        input_variables=["question", "context"],
        template=(
            "You're a helpful LLM assistant answering questions about a video. "
            "The given context consists of segments relevant to the question; "
            "please answer the question without referring to the segments. "
            "Context: {context}, question: {question} \nA:"
        ),
    )

    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    chain = LLMChain(llm=llm, prompt=prompt)
    response = chain.run(question=query, context=context)

    if relevant_timestamps is not None and len(relevant_timestamps) > 0:
        # Append the earliest mention so the user knows where in the video to look.
        smallest_timestamp = get_smallest_timestamp(relevant_timestamps)
        response += f' {smallest_timestamp}'
    return response


def check_relevance(query, relevant_metadatas):
    transcripts = [frame['transcript'] for frame in relevant_metadatas]
    captions = [frame['caption'] for frame in relevant_metadatas]
    timestamps = [frame['start_time'] for frame in relevant_metadatas]

    context = ""
    for i in range(len(transcripts)):
        context += f"Segment {i}: transcript={transcripts[i]} caption={captions[i]}\n"
    # print(context)

    prompt = PromptTemplate(input_variables=["question", "context"], template="""
    You are a grader assessing the relevance of retrieved video segments to a user question. \n
    If a video segment contains keyword(s) or semantic meaning related to the question, grade it as relevant. \n
    Give a binary score 'yes' or 'no' to indicate whether each video segment is relevant to the question. \n
    Answer with one grade per segment, separated by commas. For example, for four segments: yes,no,no,yes. \n
    Question: {question} Context: {context}\n A:""")

    # query = "What are the books mentioned in the video?"
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    chain = LLMChain(llm=llm, prompt=prompt)
    response = chain.run(question=query, context=context)
    # print(response)

    # One 'yes'/'no' grade per segment; ignore any extra tokens the model may emit.
    relevance_response = [grade.strip().lower() for grade in response.split(',')]

    actual_relevant_context = ""
    relevant_timestamps = []
    for i, relevance_check in enumerate(relevance_response[:len(transcripts)]):
        if relevance_check == 'yes':
            actual_relevant_context += f"Segment {i}: transcript={transcripts[i]} caption={captions[i]}\n"
            relevant_timestamps.append(timestamps[i])
    return actual_relevant_context, relevant_timestamps


def retrieve_segments_from_timestamp(metadatas, timestamps):
    relevant_segments = []

    for timestamp in timestamps:
        time_to_find_ms = str2time(timestamp)
        buffer = 5000 # 5 seconds before and after

        for segment in metadatas:
            start = str2time(segment['start_time'])
            end = str2time(segment['end_time'])
            if start <= time_to_find_ms + buffer and end >= time_to_find_ms - buffer:
                relevant_segments.append(segment)

    return relevant_segments


def check_timestamps(query):
    prompt = PromptTemplate(
        input_variables=["question"],
        template=(
            "You're a helpful LLM assistant. You're good at detecting any timestamps provided in a query. "
            "Please detect the question and any timestamps in the following question and separate them by commas, "
            "such as question,timestamp1,timestamp2 if timestamps are provided; otherwise return just the question. "
            "Question: {question} \nA:"
        ),
    )

    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    chain = LLMChain(llm=llm, prompt=prompt)
    response = chain.run(question=query)

    timestamps = []
    if len(response.split(',')) > 1:
        query = response.split(',')[0].strip()
        # Detected timestamps are assumed to be in M:SS/MM:SS form; pad them so
        # str2time can parse them like the segment metadata timestamps.
        timestamps = [f"00:{elem.strip()}.00" for elem in response.split(',')[1:]]

    return query, timestamps
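
# Illustrative example (hypothetical LLM response; the exact wording depends on the model):
#   query    = "What happens at 4:20?"
#   response = "What happens?,4:20"
#   returns  ("What happens?", ["00:4:20.00"])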

def retrieve_by_embedding(index, video_path, query, text_model):
    """Retrieve the top-5 segment metadatas for a text query from the Pinecone index."""
    print(query)
    query_embedding = text_model.encode(query)

    res = index.query(vector=query_embedding.tolist(), top_k=5, filter={"video_path": {"$eq": video_path}})

    metadatas = []
    for match_ in res['matches']:
        # Fetch the full record to get its metadata.
        result = index.fetch(ids=[match_['id']])
        vector_data = result.vectors.get(match_['id'], {})
        metadata = vector_data.metadata
        metadatas.append(metadata)

    return metadatas

def self_reflection(query, answer, summary):
    prompt = PromptTemplate(
        input_variables=["summary", "question", "answer"],
        template=(
            "You're a helpful LLM assistant. You're good at determining whether a provided answer "
            "to a question about a video is satisfactory. You have access to the video summary as follows: {summary}. "
            "Given the question and answer pair, reply with 'yes' if the answer is satisfactory, otherwise 'no'. "
            "Question: {question}, Answer: {answer} \nA:"
        ),
    )

    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    chain = LLMChain(llm=llm, prompt=prompt)
    response = chain.run(summary=summary, question=query, answer=answer)
    return response


def get_full_transcript(metadatas):
    # metadatas = webvtt.read(path_to_transcript)
    transcripts = [frame['transcript'] for frame in metadatas]

    full_text = ''
    for transcript in transcripts:
        text = transcript.strip().replace("  ", " ")
        full_text += f"{text} "

    return full_text.strip()

def summarize_video(metadatas_path:str):
    metadatas = load_json_file(metadatas_path)

    # get full transcript 
    transcript = get_full_transcript(metadatas)
    prompt = PromptTemplate(input_variables=["transcript"], template="You're a helpful LLM assistant. Please provide a summary for the video given its full transcript: {transcript} \nA:")
 
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    chain = LLMChain(llm=llm, prompt=prompt)
    response = chain.run(transcript=transcript)
    return response

def answer_wrt_timestamp(query, context):
    prompt = PromptTemplate(input_variables=["question", "context"], template="""
    You're a helpful LLM assistant. Given a question and a timestamp, I have retrieved the relevant context as follows. Please answer the question using the information provided in the context. Question: {question}, context: {context} \n
    For example: Question="What happens at 4:20?" Caption="a person is standing up" Transcript="I have to go" Appropriate Answer="At 4:20, a person is standing up and saying he has to go."
    A:""")
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    chain = LLMChain(llm=llm, prompt=prompt)
    response = chain.run(question=query, context=context)
    return response


def answer_question(index, model_stack, metadatas_path, video_summary:str, video_path:str, query:str, image_input_path:str=None):
    metadatas = load_json_file(metadatas_path)
    if image_input_path is not None:
        return answer_image_question(index, model_stack, metadatas, video_summary, video_path, query, image_input_path)

    # check if timestamp provided 
    query, timestamps = check_timestamps(query)

    if len(timestamps) > 0:
        # retrieve by timestamps 
        relevant_segments_metadatas = retrieve_segments_from_timestamp(metadatas, timestamps)
        transcripts = [frame['transcript'] for frame in relevant_segments_metadatas]
        captions = [frame['caption'] for frame in relevant_segments_metadatas]
        context = ""
        for i in range(len(transcripts)):
            context += f"Segment {i}: transcript={transcripts[i]} caption={captions[i]}\n"
        # print(context)
        return answer_wrt_timestamp(query, context)
    else:
        # retrieve by embedding 
        relevant_segments_metadatas = retrieve_by_embedding(index, video_path, query, model_stack[0])

    # check relevance 
    actual_relevant_context, relevant_timestamps = check_relevance(query, relevant_segments_metadatas)
    # relevant_timestamps = [frame['start_time'] for frame in relevant_segments_metadatas]
    # print(actual_relevant_context)

    # generate 
    answer = generate(query, actual_relevant_context, relevant_timestamps)
    # print(answer)

    # self-reflection 
    reflect = self_reflection(query, answer, video_summary)

    # print("Reflect", reflect)
    if reflect.strip().lower() == 'no':
        # The answer was judged unsatisfactory; retry with the video summary as extra context.
        answer = generate(query, f"{actual_relevant_context}\nSummary={video_summary}")
    
    return answer

def retrieve_segments_by_image_embedding(index, video_path, model_stack, image_query_path):
    image_query = Image.open(image_query_path)
    # model_stack: index 0 is the text embedding model (see retrieve_by_embedding);
    # indices 1 and 2 are the vision model and its processor.
    _, vision_model, vision_model_processor, _, _ = model_stack
    inputs = vision_model_processor(images=image_query, return_tensors="pt")
    outputs = vision_model(**inputs)
    image_query_embeds = outputs.pooler_output

    # pooler_output has shape (1, hidden_dim); Pinecone expects a flat list of floats.
    res = index.query(vector=image_query_embeds[0].tolist(), top_k=5, filter={"video_path": {"$eq": video_path}})

    metadatas = []
    for match_ in res['matches']:
        # Fetch the full record to get its metadata.
        result = index.fetch(ids=[match_['id']])
        vector_data = result.vectors.get(match_['id'], {})
        metadata = vector_data.metadata
        metadatas.append(metadata)

    return metadatas


def answer_image_question(index, model_stack, metadatas, video_summary:str, video_path:str, query:str, image_query_path:str=None):
    # search segment by image 
    relevant_segments = retrieve_segments_by_image_embedding(index, video_path, model_stack, image_query_path)

    # generate answer using those segments 
    return generate_w_image(query, image_query_path, relevant_segments)


def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


def generate_w_image(query:str, image_query_path:str, relevant_metadatas):
    base64_image = encode_image(image_query_path)
    transcripts = [frame['transcript'] for frame in relevant_metadatas]
    captions = [frame['caption'] for frame in relevant_metadatas]
    # timestamps = [frame['start_time'] for frame in relevant_metadatas]

    context = ""
    for i in range(len(transcripts)):
        context += f"Segment {i}: transcript={transcripts[i]} caption={captions[i]}\n"
    # print(context)

    client = OpenAI()
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "user", "content": [
                {"type": "text", "text": f"Here is some context about the image: {context}"},
                # Note: this must be an f-string so the user's question is actually interpolated.
                {"type": "text", "text": f"You are a helpful LLM assistant. You are good at answering questions about a video given an image. Given the image and the context surrounding the frames most correlated with it, please answer the question. Question: {query}"},
                {"type": "image_url", "image_url": {
                    "url": f"data:image/png;base64,{base64_image}"
                    }
                }
            ]}
        ],
        temperature=0.0,
        max_tokens=100,
    )

    response = response.choices[0].message.content
    # print(response)
    return response
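

# Illustrative usage: a minimal sketch of how these helpers could be wired
# together. The index name, embedding model, and file paths below are
# assumptions for illustration only; model_stack packs the text embedding model
# at index 0 and (for image queries) the vision model and processor at indices 1 and 2.
if __name__ == "__main__":
    from pinecone import Pinecone
    from sentence_transformers import SentenceTransformer

    pc = Pinecone()  # reads PINECONE_API_KEY from the environment
    index = pc.Index("video-rag")  # hypothetical index name
    # Hypothetical text embedder; its dimension must match the Pinecone index.
    text_model = SentenceTransformer("all-MiniLM-L6-v2")
    model_stack = (text_model, None, None, None, None)  # vision slots unused for text-only queries

    metadatas_path = "metadatas.json"  # hypothetical path to the segment metadata
    video_path = "videos/demo.mp4"     # hypothetical video path used as the index filter

    summary = summarize_video(metadatas_path)
    answer = answer_question(
        index, model_stack, metadatas_path, summary, video_path,
        "What are the books mentioned in the video?",
    )
    print(answer)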