# Spaces: Running
import faiss
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
# Default compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Per-channel normalization statistics (the standard ImageNet mean/std values).
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)

# Preprocessing applied to every image before feature extraction:
#   1. ToTensor   - PIL image -> (C, H, W) float tensor scaled to [0, 1]
#   2. Resize     - shorter side resized to 256 px (aspect ratio kept)
#   3. CenterCrop - central 224 x 224 window
#   4. Normalize  - subtract mean / divide by std, channel by channel
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
def get_ft(
    extractor: torch.nn.Module,
    img: Image.Image,
) -> np.ndarray:
    """Compute the feature of a single image with ``extractor``.

    The image is converted to RGB if needed, preprocessed with the
    module-level ``transform``, turned into a batch of one, and passed
    through ``extractor`` with autograd disabled.

    Args:
        extractor: Feature-extraction model. This function does not call
            ``.eval()``; callers should put the model in eval mode themselves.
        img: Input PIL image of any mode. Non-RGB images are converted to
            RGB because the normalization step expects three channels.

    Returns:
        The extractor's output for the single-image batch, as a NumPy array
        on the CPU. The shape depends on ``extractor``. It is presumably
        ``(1, D)`` for an embedding backbone such as DINOv2; confirm this
        for other models.
    """
    # ToTensor keeps the image's own channel count (RGBA -> 4, L/P -> 1), which
    # would not match the 3-channel Normalize mean/std, so force RGB first.
    if img.mode != "RGB":
        img = img.convert("RGB")
    batch = transform(img).unsqueeze(0)

    # Put the input on the same device as the model's weights. Fall back to
    # the module-level default for modules that have no parameters.
    first_param = next(extractor.parameters(), None)
    target = first_param.device if first_param is not None else device

    # Feature extraction needs no autograd graph; skipping it saves memory.
    with torch.no_grad():
        ft = extractor(batch.to(target))
    return ft.detach().cpu().numpy()
def get_topk(
    index: faiss.Index,
    ft: np.ndarray,
    topk: int = 10,
) -> tuple[np.ndarray, np.ndarray]:
    """Get the top-k nearest neighbors of ``ft`` from a Faiss index.

    Args:
        index: Faiss index to search.
        ft: Query feature(s). Either a single vector of shape ``(d,)`` or a
            batch of shape ``(n, d)``, where ``d`` must equal the index
            dimension. The query is converted to a C-contiguous float32
            array, which is the layout Faiss searches on.
        topk: Number of nearest neighbors to return per query.

    Returns:
        Tuple ``(distances, indices)`` from ``index.search``. Each array has
        shape ``(n, topk)``, with ``n == 1`` for a single vector. If the index
        holds fewer than ``topk`` vectors, Faiss pads the missing slots with
        index ``-1``.
    """
    # Faiss expects a 2-D (n, d) float32 matrix. Promote a lone (d,) vector
    # to (1, d) and fix the dtype/layout; a float32 C-contiguous input is
    # passed through without a copy.
    queries = np.ascontiguousarray(np.atleast_2d(ft), dtype=np.float32)
    distances, indices = index.search(queries, topk)
    return distances, indices
# EXAMPLE:
# image = Image.open('path/to/your/image.jpg')
# # Do not apply `transform` here: get_ft applies it itself.
# extractor = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
# extractor.eval()
# extractor.to(device)
# ft = get_ft(extractor, image)
# index = faiss.read_index('path/to/your/index.faiss')
# distances, indices = get_topk(index, ft, topk=10)