# blip/src/embedder.py
# Commit fb3036a (amezi): "changing to x-clip"
import torch
import numpy as np
from transformers import XCLIPProcessor, XCLIPModel
from decord import VideoReader, cpu
class XCLIPEmbedder:
    """Produce L2-normalised X-CLIP embeddings for videos and text.

    A single X-CLIP checkpoint is loaded once; ``embed_video`` and
    ``embed_text`` both return 1-D NumPy vectors scaled to unit length.
    """

    # Frame count used when the loaded checkpoint's config does not expose one.
    DEFAULT_NUM_FRAMES = 8

    def __init__(self, model_name="microsoft/xclip-large-patch14"):
        """Load the model and processor for *model_name*.

        The model is placed on CUDA when it is available, otherwise on CPU.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = XCLIPModel.from_pretrained(model_name).to(self.device)
        self.processor = XCLIPProcessor.from_pretrained(model_name)

    @staticmethod
    def _unit_vector(features):
        """Return *features* divided by its L2 norm.

        An all-zero vector is returned unchanged so callers never receive
        NaNs produced by a division by zero.
        """
        norm = np.linalg.norm(features)
        if norm == 0:
            return features
        return features / norm

    def _frames_per_clip(self):
        """Return the frame count from the model's vision config, else the default."""
        model_config = getattr(self.model, "config", None)
        vision_config = getattr(model_config, "vision_config", None)
        return getattr(vision_config, "num_frames", None) or self.DEFAULT_NUM_FRAMES

    def embed_video(self, video_path, num_frames=None):
        """Embed the video at *video_path* as a unit-length feature vector.

        Args:
            video_path: Path handed to decord's ``VideoReader``.
            num_frames: Number of evenly spaced frames to sample. When
                ``None`` (the default) the count is taken from the loaded
                model's vision config, falling back to
                ``DEFAULT_NUM_FRAMES``.

        Returns:
            The video features as a NumPy array, L2-normalised.

        Raises:
            ValueError: If the video reports zero frames.
        """
        if num_frames is None:
            num_frames = self._frames_per_clip()

        reader = VideoReader(video_path, ctx=cpu(0))
        total_frames = len(reader)
        if total_frames == 0:
            raise ValueError(f"Video contains no decodable frames: {video_path!r}")

        # Evenly spaced indices from the first to the last frame; videos
        # shorter than num_frames simply repeat some indices.
        frame_indices = np.linspace(0, total_frames - 1, num=num_frames, dtype=int)
        video_frames = reader.get_batch(frame_indices).asnumpy()

        inputs = self.processor(
            videos=list(video_frames), return_tensors="pt", padding=True
        ).to(self.device)
        with torch.no_grad():
            features = self.model.get_video_features(**inputs)
        # squeeze(0) drops the batch axis of the single-video batch.
        return self._unit_vector(features.squeeze(0).cpu().numpy())

    def embed_text(self, text):
        """Embed the string *text* as a unit-length feature vector.

        Args:
            text: A single string to encode.

        Returns:
            The text features as a NumPy array, L2-normalised.
        """
        # truncation=True asks the tokenizer to cut over-long prompts to its
        # maximum length instead of passing them to the model untruncated.
        inputs = self.processor(
            text=[text], return_tensors="pt", padding=True, truncation=True
        ).to(self.device)
        with torch.no_grad():
            features = self.model.get_text_features(**inputs)
        # squeeze(0) drops the batch axis of the single-prompt batch.
        return self._unit_vector(features.squeeze(0).cpu().numpy())