geoguessr-bot / geoguessr_bot /retriever /dino_v2_embedder.py
Bastien Dechamps
dinoV2
077dc3f
raw
history blame
1.02 kB
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from .abstract_embedder import AbstractImageEmbedder
class DinoV2Embedder(AbstractImageEmbedder):
    """Image embedder backed by a DINOv2 model loaded through ``torch.hub``.

    Preprocessing: convert to RGB, resize to 256x256 (bicubic; the aspect
    ratio is NOT preserved), center-crop to 224x224, convert to a tensor and
    normalize with the ImageNet mean/std. The model runs in eval mode without
    gradient tracking.
    """

    def __init__(self, device: str = "cpu", model_name: str = "dinov2_vits14"):
        """Load the DINOv2 model and build the preprocessing pipeline.

        Args:
            device: Torch device string the model and input tensors are moved
                to (stored by the base class; presumably exposed as
                ``self.device`` — verify against ``AbstractImageEmbedder``).
            model_name: Entry point in the ``facebookresearch/dinov2`` torch
                hub repository (e.g. ``"dinov2_vits14"``, ``"dinov2_vitb14"``).
                Defaults to ``"dinov2_vits14"``, the model this class always
                used before the parameter existed.
        """
        super().__init__(device)
        # NOTE: torch.hub.load may fetch the repository and weights over the
        # network when they are not already in the local hub cache.
        self.model = torch.hub.load("facebookresearch/dinov2", model_name).to(self.device)
        self.model.eval()
        self.transforms = T.Compose([
            # Squash to a fixed square; kept as-is so embeddings stay
            # comparable with any previously built index.
            T.Resize((256, 256), interpolation=T.InterpolationMode.BICUBIC),
            # 224 = 16 * 14, a whole number of 14-pixel ViT patches.
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    def embed(self, image: Image.Image) -> np.ndarray:
        """Compute the embedding of a single image.

        Args:
            image: A PIL image. It is converted to RGB first, so grayscale,
                RGBA and palette images are accepted.

        Returns:
            The model output for this image as a NumPy array on the CPU, with
            the batch dimension added for inference removed. For the DINOv2
            hub backbones this is presumably the 1-D pooled feature vector
            (384-d for ``dinov2_vits14``) — confirm against the loaded model.
        """
        rgb = image.convert("RGB")
        # Single-image batch of shape (1, 3, 224, 224).
        batch = self.transforms(rgb).unsqueeze(0).to(self.device)
        with torch.no_grad():
            # [0] drops the batch dimension added by unsqueeze(0) above.
            output = self.model(batch)[0].cpu().numpy()
        return output