Bastien Dechamps
[ADD] First demo with 2500 images
d78053e
raw
history blame
1.24 kB
from typing import Dict, List
import numpy as np
import faiss
class Retriever:
    """Nearest-neighbour image retrieval over precomputed embeddings.

    Embeddings are loaded from a ``.npy`` file holding a pickled
    ``{image_name: vector}`` dict and indexed with a Faiss exact L2 index
    (``faiss.IndexFlatL2``).

    Attributes:
        embeddings: ``(n_images, dim)`` float32 matrix; row ``i`` is the
            embedding of ``index_to_image[i]``.
        n_neighbors: Number of neighbours returned per query by ``retrieve``.
        image_to_index: Image name -> row in ``embeddings`` / the Faiss index.
        index_to_image: Row in ``embeddings`` / the Faiss index -> image name.
        dim: Embedding dimensionality.
        index: The Faiss index holding ``embeddings``.
    """

    def __init__(self, embeddings_path: str, n_neighbors: int = 5):
        """Load embeddings from *embeddings_path* and build the search index.

        Args:
            embeddings_path: Path to a ``.npy`` file containing a pickled
                ``{image_name: vector}`` dict (see ``load_embeddings``).
            n_neighbors: Number of neighbours ``retrieve`` returns per query.

        Raises:
            ValueError: If the file holds no embeddings, or the vectors are
                not all 1-D with the same length.
        """
        embeddings_by_name: Dict[str, np.ndarray] = self.load_embeddings(embeddings_path)
        if not embeddings_by_name:
            raise ValueError(f"No embeddings found in {embeddings_path!r}")
        self.n_neighbors = n_neighbors

        # Row i of the index corresponds to the i-th key of the dict (dict
        # iteration order), so both mappings come from the same name list.
        names = list(embeddings_by_name)
        self.image_to_index = {name: i for i, name in enumerate(names)}
        self.index_to_image = dict(enumerate(names))

        # Faiss works on C-contiguous float32 matrices; np.stack also rejects
        # vectors of mismatched length with a ValueError.
        self.embeddings = np.ascontiguousarray(
            np.stack([embeddings_by_name[name] for name in names]),
            dtype=np.float32,
        )
        if self.embeddings.ndim != 2:
            raise ValueError(
                f"Expected 1-D embedding vectors, got stacked shape {self.embeddings.shape}"
            )

        # Build Faiss index
        self.dim = self.embeddings.shape[1]
        self.index = faiss.IndexFlatL2(self.dim)
        self.index.add(self.embeddings)

    @staticmethod
    def load_embeddings(embeddings_path: str) -> Dict[str, np.ndarray]:
        """Load the ``{image_name: vector}`` dict stored in a ``.npy`` file.

        The file is expected to hold a 0-d object array wrapping the dict,
        as produced by ``np.save(path, embeddings_dict)``.
        """
        # SECURITY: allow_pickle=True is required to unpickle the dict, but a
        # crafted file can execute arbitrary code on load. Only load trusted files.
        return np.load(embeddings_path, allow_pickle=True).item()

    def retrieve(self, queries: np.ndarray) -> List[List[str]]:
        """Return the names of the nearest indexed images for each query.

        Args:
            queries: Query embeddings of shape ``(n_queries, dim)``; a single
                ``(dim,)`` vector is treated as one query.

        Returns:
            One list per query holding up to ``n_neighbors`` image names in
            the order Faiss returns them (nearest first). A list is shorter
            than ``n_neighbors`` when the index holds fewer images.
        """
        # Faiss expects a 2-D, C-contiguous float32 matrix.
        queries = np.ascontiguousarray(np.atleast_2d(queries), dtype=np.float32)
        _, indexes = self.index.search(queries, self.n_neighbors)
        return self._indexes_to_images(indexes)

    def _indexes_to_images(self, indexes: np.ndarray) -> List[List[str]]:
        """Map each row of Faiss result ids to image names, dropping -1 padding."""
        # Faiss fills result slots with id -1 when fewer than k vectors are
        # available; looking those up in index_to_image would raise KeyError.
        return [
            [self.index_to_image[int(i)] for i in row if i >= 0]
            for row in indexes
        ]