"""Helpers for extracting image features and querying a Faiss index."""

import faiss
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

# Device on which input tensors are placed before being fed to the extractor.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Preprocessing applied to every image inside `get_ft`.  The order
# (tensor conversion first, then resize/crop) is kept exactly as is:
# changing it would alter the numerical features that are produced.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])


def get_ft(
    extractor: torch.nn.Module,
    img: Image.Image,
) -> np.ndarray:
    """
    Extract a feature array from a single image.

    Args:
        extractor: Model called on the preprocessed image batch.
            NOTE(review): assumed to already live on `device` and to be in
            eval mode -- confirm at the call site.
        img: Input PIL image. It must NOT be pre-transformed; `transform`
            is applied here.

    Returns:
        The extractor output for a batch of one image, moved to the CPU
        and converted to a NumPy array.
    """
    # `Normalize` is configured with 3-channel statistics, so grayscale,
    # palette or RGBA images must be brought to RGB first.
    if img.mode != 'RGB':
        img = img.convert('RGB')

    batch = transform(img).unsqueeze(0).to(device)

    # Inference only: skip building the autograd graph for the forward pass.
    with torch.no_grad():
        ft = extractor(batch)

    return ft.detach().cpu().numpy()


def get_topk(
    index: faiss.Index,
    ft: np.ndarray,
    topk: int = 10,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Get top-k nearest neighbors from the index

    Args:
        index: Faiss index
        ft: Input feature; a single 1-D vector is treated as a batch of one
        topk: Number of nearest neighbors to return

    Returns:
        Tuple of (distances, indices) for top-k matches
    """
    # Faiss searches a 2-D batch of float32, C-contiguous query vectors.
    # Inputs that already satisfy this are passed through without a copy.
    queries = np.ascontiguousarray(np.atleast_2d(ft), dtype=np.float32)

    # Search index for nearest neighbors
    distances, indices = index.search(queries, topk)
    return distances, indices


# EXAMPLE:
# image = Image.open('path/to/your/image.jpg')
# # Do not call transform(image) yourself: get_ft applies it internally.
# extractor = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
# extractor.eval()
# extractor.to(device)
# ft = get_ft(extractor, image)
# index = faiss.read_index('path/to/your/index.faiss')
# distances, indices = get_topk(index, ft, topk=10)