# video-qa / app.py
# Thao Pham
# Change input to upload video
# 5de0912
# raw
# history blame
# 7.81 kB
import gradio as gr
import time
import re
import video_utils
import utils
import embed
import rag
import shutil
import os
import uuid
import numpy as np
import pinecone
from pinecone import Pinecone, ServerlessSpec
from sentence_transformers import SentenceTransformer
from transformers import AutoImageProcessor, AutoModel
from transformers import BlipProcessor, BlipForConditionalGeneration
from dotenv import load_dotenv
# Read settings (API keys, ...) from a local .env file into the environment.
load_dotenv()

# Root folder for processed videos; chat() creates one sub-folder per video.
UPLOAD_FOLDER = 'uploads'
# Name of the video that questions are answered against; set by chat()
# once an uploaded video has been processed.
global_video_name = None

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")

# --- Models (loaded once, at import time) ---------------------------------
TEXT_MODEL = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
VISION_MODEL_PROCESSOR = AutoImageProcessor.from_pretrained('facebook/dinov2-small')
VISION_MODEL = AutoModel.from_pretrained('facebook/dinov2-small')
VLM_PROCESSOR = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
VLM = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# --- Vector index ---------------------------------------------------------
pc = Pinecone(api_key=PINECONE_API_KEY)

# Create the index on first run, then connect to it.
index_name = "multimodal-minilm"
if index_name not in pc.list_indexes().names():
    pc.create_index(
        index_name,
        dimension=384,
        metric="cosine",
        spec=ServerlessSpec(cloud="aws", region="us-east-1"),
    )
INDEX = pc.Index(index_name)

# Bundle handed to the embed/rag helpers; the order of the entries matters
# to whoever unpacks it.
MODEL_STACK = [TEXT_MODEL, VISION_MODEL, VISION_MODEL_PROCESSOR, VLM, VLM_PROCESSOR]
def is_valid_youtube_url(url):
    """
    Check whether ``url`` looks like a YouTube video URL.

    Accepted shapes are an optional ``http(s)://`` scheme and ``www.``
    prefix, the host ``youtube.com`` or ``youtu.be``, an optional
    ``watch?v=``, ``embed/``, ``v/`` or ``shorts/`` path prefix, and an
    11-character video ID. Anything may follow the ID (e.g. ``?t=5``) as
    long as it is not another ID character.

    Returns True if valid, False otherwise. Non-string input (such as
    ``None``) is reported as invalid rather than raising.
    """
    if not isinstance(url, str):
        return False
    youtube_regex = re.compile(
        r"^(https?://)?(www\.)?(youtube\.com|youtu\.be)/"
        r"(watch\?v=|embed/|v/|shorts/)?([a-zA-Z0-9_-]{11})"
        # The ID must stop after exactly 11 characters; without this a
        # longer token would be accepted by matching only its prefix.
        r"(?![a-zA-Z0-9_-])"
    )
    return youtube_regex.match(url) is not None
def check_exist_before_upsert(index, video_path):
    """
    Tell whether the vectors for ``video_path`` already seem to be indexed.

    The expected number of vectors is three per extracted frame (image
    embedding, caption embedding, transcript embedding), where frames are
    the ``.jpg`` files stored in the same folder as the video file.

    Args:
        index: object exposing a Pinecone-style ``query(vector=, top_k=,
            filter=)`` method whose result has a ``"matches"`` sequence.
        video_path: path to the video file; also used as the value of the
            ``video_path`` metadata filter.

    Returns:
        True if the index holds at least the expected number of vectors
        for this video, False otherwise (including when no frames exist
        on disk, since nothing can then be confirmed).
    """
    dimension = 384
    top_k = 10000  # High value so as many matches as possible come back

    # Count frames in the video's own folder. Taking the first path
    # component instead would look in the shared uploads root, find no
    # frames, and make every video look "already indexed".
    frames_dir = os.path.dirname(video_path) or '.'
    num_frames = sum(1 for elem in os.listdir(frames_dir) if elem.endswith('.jpg'))

    threshold = num_frames * 3  # image embeds, caption embeds, transcript embeds
    if threshold == 0:
        return False
    # A single query can never return more than top_k matches, so do not
    # demand more than that or large videos would be re-indexed forever.
    threshold = min(threshold, top_k)

    res = index.query(
        vector=[0] * dimension,  # Dummy vector (not used for filtering)
        top_k=top_k,
        filter={"video_path": video_path},  # Filter by video_path
    )
    # Count the number of matching vectors
    return len(res["matches"]) >= threshold
def chat(message, history):
    """
    Gradio chat handler: process an uploaded video, or answer a question.

    Args:
        message: multimodal textbox value, a dict with a ``'files'`` list
            of uploaded file paths and a ``'text'`` string.
        history: list of ``(user, bot)`` tuples shown in the chatbot, or
            None for an empty conversation.

    Yields:
        The updated history after every progress step, so the UI can show
        intermediate status messages.

    Behaviour:
        * If an ``.mp4`` is attached, the video is moved into its own
          folder under ``UPLOAD_FOLDER``, then transcribed, captioned,
          indexed and summarised. Steps whose output file already exists
          are skipped. The video then becomes the active one.
        * Otherwise the text is treated as a question about the active
          video; an attached ``.jpg`` is passed along as an image query.

    Raises:
        gr.Error: if more than one file is attached.

    NOTE(review): the active video is stored in a module-level global, so
    it is shared by every session of the app.
    """
    global global_video_name

    image_input_path = None
    video_name, video_input_path = None, None

    files = message['files']
    if files:
        # Validate explicitly: an assert would vanish under ``python -O``.
        if len(files) != 1:
            raise gr.Error("Please upload one file at a time.")
        uploaded = files[0]
        if uploaded.endswith('.jpg'):
            image_input_path = uploaded
        elif uploaded.endswith('.mp4'):
            video_input_path = uploaded
            video_name = os.path.basename(video_input_path).split('.mp4')[0]

    message = message['text']
    if history is None:
        history = []

    if video_name is not None:
        history.append((None, f"βœ… Video uploaded succesfully! Your video's title is {video_name}..."))
        yield history

        output_folder_path = os.path.join(UPLOAD_FOLDER, video_name)
        os.makedirs(output_folder_path, exist_ok=True)
        path_to_video = os.path.join(output_folder_path, "video.mp4")
        if not os.path.exists(path_to_video):
            shutil.move(video_input_path, path_to_video)

        history.append((None, "⏳ Transcribing video..."))
        yield history

        path_to_audio_file = os.path.join(output_folder_path, "audio.mp3")
        if not os.path.exists(path_to_audio_file):
            path_to_audio_file = video_utils.extract_audio(path_to_video, output_folder_path)

        path_to_generated_transcript = os.path.join(output_folder_path, "transcript.vtt")
        if not os.path.exists(path_to_generated_transcript):
            path_to_generated_transcript = video_utils.transcribe_video(path_to_audio_file, output_folder_path)

        # extract frames and metadata
        metadatas_path = os.path.join(output_folder_path, 'metadatas.json')
        caption_path = os.path.join(output_folder_path, 'captions.json')
        needs_captions = not os.path.exists(caption_path)
        # Captioning needs the in-memory metadatas, so extraction is also
        # run when metadatas.json is cached but captions are missing;
        # otherwise `metadatas` would be unbound below.
        if needs_captions or not os.path.exists(metadatas_path):
            metadatas = video_utils.extract_and_save_frames_and_metadata(
                path_to_video=path_to_video,
                path_to_transcript=path_to_generated_transcript,
                path_to_save_extracted_frames=output_folder_path,
                path_to_save_metadatas=output_folder_path,
            )

        history.append((None, "⏳ Captioning video..."))
        yield history

        if needs_captions:
            video_frames = [
                os.path.join(output_folder_path, elem)
                for elem in os.listdir(output_folder_path)
                if elem.endswith('.jpg')
            ]
            video_utils.get_video_caption(video_frames, metadatas, output_folder_path, vlm=VLM, vlm_processor=VLM_PROCESSOR)

        history.append((None, "⏳ Indexing..."))
        yield history

        index_exist = check_exist_before_upsert(INDEX, path_to_video)
        print(index_exist)
        if not index_exist:
            embed.indexing(INDEX, MODEL_STACK, metadatas_path)

        # summarizing video
        video_summary = rag.summarize_video(metadatas_path)
        with open(os.path.join(output_folder_path, "summary.txt"), "w") as f:
            f.write(video_summary)

        history.append((None, f"Video processing complete! You can now ask me questions about the video {video_name}!"))
        yield history

        # Remember which video later questions refer to.
        global_video_name = video_name
    else:
        history.append((message, None))
        yield history

        if global_video_name is None:
            history.append((None, "You need to upload a video before asking questions."))
            yield history
            return

        # Use the remembered video: the local `video_name` is always None
        # in this branch. Paths are built as in the upload branch so the
        # video path handed to the retriever is identical on any platform.
        output_folder_path = os.path.join(UPLOAD_FOLDER, global_video_name)
        metadatas_path = os.path.join(output_folder_path, 'metadatas.json')

        # The summary is flattened: each line is stripped, then all lines
        # are concatenated without a separator.
        with open(os.path.join(output_folder_path, 'summary.txt')) as f:
            video_summary = "".join(ln.strip() for ln in f)

        video_path = os.path.join(output_folder_path, 'video.mp4')
        answer = rag.answer_question(INDEX, MODEL_STACK, metadatas_path, video_summary, video_path, message, image_input_path)
        history.append((None, answer))
        yield history
def clear_chat(history=None):
    """
    Return a fresh chat history holding only the welcome prompt.

    Args:
        history: the current history; accepted but ignored and never
            mutated. It is optional so the function works as a Gradio
            callback whether or not the event passes any inputs.

    Returns:
        A new list with a single bot message asking for a video upload.
    """
    return [(None, "Please upload a video to get started!")]
def main():
    """
    Build the Gradio chat interface and launch the app.

    The UI is a chatbot, a multimodal textbox (images and .mp4 uploads)
    and two buttons: "Send" streams the output of ``chat`` into the
    chatbot, "Clear" replaces the conversation with the welcome prompt.
    """
    # Reset the active video on start-up. This must be the same global
    # that chat() reads; resetting any other name has no effect.
    global global_video_name
    global_video_name = None

    initial_messages = [(None, "Please upload a video to get started!")]
    with gr.Blocks() as demo:
        chatbot = gr.Chatbot(value=initial_messages)
        msg = gr.MultimodalTextbox(file_types=['image', '.mp4'], sources=['upload'])
        with gr.Row():
            with gr.Column():
                submit = gr.Button("Send")
                submit.click(chat, [msg, chatbot], chatbot)
            with gr.Column():
                clear = gr.Button("Clear")  # Clear button
                # Clear chat history when clear button is clicked. The
                # callback takes the history as its argument, so the
                # chatbot has to be passed as an input.
                clear.click(clear_chat, [chatbot], chatbot)
    demo.launch()


if __name__ == "__main__":
    main()