import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

from .abstract_embedder import AbstractImageEmbedder


class DinoV2Embedder(AbstractImageEmbedder):
    """Image embedder backed by the DINOv2 ViT-S/14 model.

    Model weights are fetched via ``torch.hub`` on instantiation, so the
    first construction requires network access (or a warm hub cache).
    """

    def __init__(self, device: str = "cpu"):
        """Load the DINOv2 model and build the preprocessing pipeline.

        Args:
            device: torch device string (e.g. ``"cpu"``, ``"cuda"``) that
                the model and all inputs are moved to.
        """
        super().__init__(device)
        self.model = torch.hub.load(
            'facebookresearch/dinov2', 'dinov2_vits14'
        ).to(self.device)
        # Inference mode: freezes dropout / batch-norm statistics.
        self.model.eval()
        # Standard ImageNet-style preprocessing: resize, center-crop to
        # the 224x224 input expected by ViT-S/14, then normalize with
        # ImageNet channel statistics.
        self.transforms = T.Compose([
            T.Resize((256, 256), interpolation=T.InterpolationMode.BICUBIC),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    def embed(self, image: Image.Image) -> np.ndarray:
        """Compute the DINOv2 embedding for a single PIL image.

        Args:
            image: input image; converted to RGB first, so grayscale or
                RGBA inputs are handled.

        Returns:
            1-D numpy array holding the model's embedding for the image
            (presumably 384 dims for ViT-S/14 — confirm against the
            loaded checkpoint).
        """
        image = image.convert("RGB")
        # Add a batch dimension: model expects (N, C, H, W).
        batch = self.transforms(image).unsqueeze(0).to(self.device)
        with torch.no_grad():  # no autograd graph needed at inference
            return self.model(batch)[0].cpu().numpy()