|
import torch |
|
import numpy as np |
|
from transformers import XCLIPProcessor, XCLIPModel |
|
from decord import VideoReader, cpu |
|
|
|
class XCLIPEmbedder:
    """Embed videos and text with X-CLIP, returning L2-normalized NumPy vectors.

    The model and processor are loaded once in ``__init__``. Every embedding
    call runs under ``torch.no_grad()`` on ``self.device`` and returns a 1-D
    array that has been moved back to the CPU.
    """

    def __init__(self, model_name="microsoft/xclip-large-patch14"):
        """Load the X-CLIP model and processor for ``model_name``.

        Args:
            model_name: Hugging Face model id (or local path) of an X-CLIP
                checkpoint; the same name is used for model and processor.
        """
        # Prefer the GPU when one is available; otherwise run on the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = XCLIPModel.from_pretrained(model_name).to(self.device)
        self.processor = XCLIPProcessor.from_pretrained(model_name)

    @staticmethod
    def _sample_frame_indices(total_frames, num_frames):
        """Return ``num_frames`` evenly spaced frame indices in ``[0, total_frames - 1]``.

        Videos shorter than ``num_frames`` yield repeated indices rather than
        fewer frames, so the output length is always ``num_frames``.

        Raises:
            ValueError: If ``total_frames`` is not positive.
        """
        if total_frames <= 0:
            # linspace(0, -1, ...) would otherwise silently produce a -1 index.
            raise ValueError("cannot sample frames from a video with no frames")
        return np.linspace(0, total_frames - 1, num=num_frames, dtype=int)

    @staticmethod
    def _l2_normalize(vector, eps=1e-12):
        """Return ``vector`` scaled to unit L2 norm.

        The norm is floored at ``eps`` so an all-zero vector comes back as
        zeros instead of NaNs from a 0/0 division.
        """
        return vector / max(np.linalg.norm(vector), eps)

    def embed_video(self, video_path, num_frames=None):
        """Return a unit-norm X-CLIP embedding for the video at ``video_path``.

        Args:
            video_path: Path to a video file readable by decord's ``VideoReader``.
            num_frames: How many evenly spaced frames to sample. Defaults to the
                checkpoint's ``vision_config.num_frames`` (8 if the config has no
                such field). NOTE(review): X-CLIP's video encoder appears to be
                built for the checkpoint's frame count, so an override that
                differs from it will likely fail — confirm before relying on it.

        Returns:
            1-D ``np.ndarray``: the video embedding with L2 norm 1.

        Raises:
            ValueError: If the video contains no frames.
        """
        if num_frames is None:
            # Follow the checkpoint instead of hard-coding 8, so models such as
            # 16-frame variants passed via ``model_name`` also work.
            num_frames = getattr(self.model.config.vision_config, "num_frames", 8)

        reader = VideoReader(video_path, ctx=cpu(0))
        indices = self._sample_frame_indices(len(reader), num_frames)
        frames = reader.get_batch(indices).asnumpy()

        # `padding` is a tokenizer (text) option, so it is not passed for videos.
        inputs = self.processor(videos=list(frames), return_tensors="pt").to(self.device)
        with torch.no_grad():
            features = self.model.get_video_features(**inputs)
        # squeeze(0) drops the batch axis of the single clip.
        return self._l2_normalize(features.squeeze(0).cpu().numpy())

    def embed_text(self, text):
        """Return a unit-norm X-CLIP embedding for a single string ``text``.

        Text longer than the tokenizer's maximum length is truncated rather
        than overrunning the text encoder's position embeddings.

        NOTE(review): X-CLIP's full forward pass additionally conditions text
        embeddings on the video (prompt generator); ``get_text_features`` does
        not, so similarities built from ``embed_text``/``embed_video`` may not
        match the model's own ``logits_per_video`` — confirm this is intended.

        Returns:
            1-D ``np.ndarray``: the text embedding with L2 norm 1.
        """
        inputs = self.processor(
            text=[text], return_tensors="pt", padding=True, truncation=True
        ).to(self.device)
        with torch.no_grad():
            features = self.model.get_text_features(**inputs)
        return self._l2_normalize(features.squeeze(0).cpu().numpy())
|
|