|
"""Embed video transcripts, captions, and frames, and upsert them into a vector index."""
import uuid
from typing import List

import numpy as np
import torch
from PIL import Image
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from transformers import AutoImageProcessor, AutoModel

from utils import load_json_file
|
|
def embed_texts(text_ls: List[str], text_model=None) -> np.ndarray:
    """Embed a list of strings, defaulting to a small SentenceTransformer."""
    if text_model is None:
        text_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

    text_embeddings = []
    for text in tqdm(text_ls, desc="Embedding text"):
        text_embeddings.append(text_model.encode(text))
    return np.array(text_embeddings)
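# Usage sketch (the strings are illustrative, not from the pipeline):
#   embs = embed_texts(["a person opens the door", "two dogs play outside"])
#   embs.shape  # (2, 384) -- all-MiniLM-L6-v2 yields 384-dim vectors
# Note: encode() also accepts the whole list and batches internally, which
# is typically faster than encoding one string per iteration as above.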
|
|
def embed_images(image_path_ls: List[str], vision_model=None, vision_model_processor=None) -> np.ndarray:
    """Embed a list of image files, defaulting to DINOv2-small."""
    if vision_model is None or vision_model_processor is None:
        vision_model_processor = AutoImageProcessor.from_pretrained('facebook/dinov2-small')
        vision_model = AutoModel.from_pretrained('facebook/dinov2-small')
    vision_model.eval()

    image_embeds_ls = []
    for image_path in tqdm(image_path_ls, desc="Embedding image"):
        # Convert to RGB so grayscale/RGBA frames don't trip up the processor.
        frame = Image.open(image_path).convert("RGB")
        inputs = vision_model_processor(images=frame, return_tensors="pt")
        with torch.no_grad():  # inference only; skip gradient tracking
            outputs = vision_model(**inputs)
        image_embeds_ls.append(outputs.pooler_output)
    return np.array([emb.squeeze().numpy() for emb in image_embeds_ls])
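# Usage sketch (hypothetical frame paths):
#   embs = embed_images(["frames/frame_0001.jpg", "frames/frame_0002.jpg"])
#   embs.shape  # (2, 384) -- dinov2-small's pooled output is 384-dim,
#   conveniently matching the text-embedding dimension above, so all
#   modalities can live in a single 384-dim index.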
|
|
def indexing(index, model_stack, vid_metadata_path):
    """Embed transcripts, captions, and frames for every video frame and
    upsert all three sets of vectors into the given index."""
    text_model, vision_model, vision_model_processor, _, _ = model_stack

    vid_metadata = load_json_file(vid_metadata_path)

    # Embed each frame's transcript segment.
    vid_trans = [frame['transcript'] for frame in vid_metadata]
    transcript_embeddings = embed_texts(text_ls=vid_trans, text_model=text_model)

    # Embed each frame's generated caption.
    vid_captions = [frame['caption'] for frame in vid_metadata]
    caption_embeddings = embed_texts(text_ls=vid_captions, text_model=text_model)

    # Embed the extracted frame images themselves.
    vid_img_paths = [frame['extracted_frame_path'] for frame in vid_metadata]
    frame_embeddings = embed_images(vid_img_paths, vision_model, vision_model_processor)

    # Upsert one vector per (modality, frame) pair, each under a fresh id
    # and carrying the frame's full metadata record.
    for embeddings in [transcript_embeddings, caption_embeddings, frame_embeddings]:
        vectors = [
            (str(uuid.uuid4()), emb.tolist(), meta)
            for emb, meta in zip(embeddings, vid_metadata)
        ]
        index.upsert(vectors)
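# Usage sketch: the (id, values, metadata) tuple format matches, e.g., a
# Pinecone index's upsert(); the actual backend is an assumption here, since
# `index` is supplied by the caller. Key, index name, and path are placeholders:
#
#   from pinecone import Pinecone
#   pc = Pinecone(api_key="YOUR_API_KEY")
#   index = pc.Index("video-frames")  # a 384-dim index
#   model_stack = (
#       SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2'),
#       AutoModel.from_pretrained('facebook/dinov2-small'),
#       AutoImageProcessor.from_pretrained('facebook/dinov2-small'),
#       None, None,  # the two unused slots in model_stack
#   )
#   indexing(index, model_stack, "metadata/video1.json")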
|
|