video-qa / embed.py
from sentence_transformers import SentenceTransformer
from transformers import AutoImageProcessor, AutoModel
from tqdm import tqdm
from PIL import Image
from typing import List
import numpy as np
import torch
import uuid

from utils import load_json_file


def embed_texts(text_ls: List[str], text_model=None):
    """Embed a list of strings with a SentenceTransformer model."""
    if text_model is None:
        text_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    text_embeddings = []
    for text in tqdm(text_ls, desc="Embedding text"):
        embeds = text_model.encode(text)
        text_embeddings.append(embeds)
    return np.array(text_embeddings)
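
# Example usage (hypothetical strings; all-MiniLM-L6-v2 produces 384-dim vectors):
#   embs = embed_texts(["a person opens a laptop", "the narrator explains the setup"])
#   print(embs.shape)  # -> (2, 384)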


def embed_images(image_path_ls: List[str], vision_model=None, vision_model_processor=None):
    """Embed a list of image file paths with a DINOv2 vision model."""
    if vision_model is None or vision_model_processor is None:
        vision_model_processor = AutoImageProcessor.from_pretrained('facebook/dinov2-small')
        vision_model = AutoModel.from_pretrained('facebook/dinov2-small')
    image_embeds_ls = []
    for frame_path in tqdm(image_path_ls, desc="Embedding image"):
        frame = Image.open(frame_path)
        # TODO: add device here
        inputs = vision_model_processor(images=frame, return_tensors="pt")
        with torch.no_grad():  # inference only; no gradients needed
            outputs = vision_model(**inputs)
        image_embeds_ls.append(outputs.pooler_output)
    return np.array([elem.squeeze().detach().numpy() for elem in image_embeds_ls])
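
# Example usage (hypothetical frame paths; dinov2-small pooler outputs are 384-dim):
#   frame_embs = embed_images(["frames/frame_0001.jpg", "frames/frame_0002.jpg"])
#   print(frame_embs.shape)  # -> (2, 384)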


def indexing(index, model_stack, vid_metadata_path):
    """Embed transcripts, captions, and frames, then upsert everything into the index."""
    text_model, vision_model, vision_model_processor, _, _ = model_stack
    # Read the per-frame metadata file
    vid_metadata = load_json_file(vid_metadata_path)
    # Embed transcripts
    vid_trans = [frame['transcript'] for frame in vid_metadata]
    transcript_embeddings = embed_texts(text_ls=vid_trans, text_model=text_model)
    # Embed captions
    vid_captions = [frame['caption'] for frame in vid_metadata]
    caption_embeddings = embed_texts(text_ls=vid_captions, text_model=text_model)
    # Embed extracted frames
    vid_img_paths = [vid['extracted_frame_path'] for vid in vid_metadata]
    frame_embeddings = embed_images(vid_img_paths, vision_model, vision_model_processor)
    # Upsert each embedding type, attaching the frame metadata to every vector
    for embeddings in [transcript_embeddings, caption_embeddings, frame_embeddings]:
        vectors = [
            (str(uuid.uuid4()), emb.tolist(), meta)  # unique ID, vector values, metadata
            for emb, meta in zip(embeddings, vid_metadata)
        ]
        # Upsert vectors into Pinecone
        index.upsert(vectors)
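
# Minimal end-to-end sketch (assumptions, not part of this repo: a Pinecone index
# created elsewhere with dimension 384, a metadata JSON at ./video_metadata.json,
# and a 5-element model_stack whose last two entries are unused by indexing()):
#
#   from pinecone import Pinecone
#
#   pc = Pinecone(api_key="YOUR_API_KEY")
#   index = pc.Index("video-qa")
#   text_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
#   vision_processor = AutoImageProcessor.from_pretrained('facebook/dinov2-small')
#   vision_model = AutoModel.from_pretrained('facebook/dinov2-small')
#   model_stack = (text_model, vision_model, vision_processor, None, None)
#   indexing(index, model_stack, "video_metadata.json")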