File size: 1,245 Bytes
1791df2
944c93a
 
 
 
 
4388025
1791df2
4388025
 
 
 
 
 
 
 
 
 
 
 
 
 
077dc3f
4388025
077dc3f
4388025
1791df2
4388025
 
1791df2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from typing import Dict, List, Tuple

import numpy as np
import faiss


class Retriever:
    """Nearest-neighbour lookup of image names over a flat L2 Faiss index.

    The embeddings file is expected to hold a pickled mapping from image name
    to embedding vector (see ``load_embeddings``). Insertion order of that
    mapping defines the row position of each image in the index.
    """

    def __init__(self, embeddings_path: str):
        """Load the embeddings and build the Faiss index.

        Args:
            embeddings_path: Path to the ``.npy`` file with the name -> vector
                mapping.

        Raises:
            ValueError: If the mapping is empty or the vectors do not stack
                into a 2-D ``(n_images, dim)`` matrix.
        """
        loaded: Dict[str, np.ndarray] = self.load_embeddings(embeddings_path)
        if not loaded:
            raise ValueError(f"No embeddings found in {embeddings_path!r}")

        # Keep track of image names (row position in the index <-> name)
        image_names = list(loaded)
        self.image_to_index = {image_name: i for i, image_name in enumerate(image_names)}
        self.index_to_image = dict(enumerate(image_names))

        # Build Faiss index. Faiss only accepts C-contiguous float32 matrices,
        # while saved embeddings are frequently float64, so cast explicitly.
        self.embeddings: np.ndarray = np.ascontiguousarray(
            np.array([loaded[image_name] for image_name in image_names], dtype=np.float32)
        )
        if self.embeddings.ndim != 2:
            raise ValueError(
                "Embeddings must stack into a 2-D (n_images, dim) array, "
                f"got shape {self.embeddings.shape}"
            )
        self.dim = self.embeddings.shape[1]
        self.index = faiss.IndexFlatL2(self.dim)
        self.index.add(self.embeddings)

    @staticmethod
    def load_embeddings(embeddings_path: str) -> Dict[str, np.ndarray]:
        """Load embeddings from a .npy file

        The file must contain a 0-d object array wrapping a dict, which is
        what ``np.save(path, some_dict)`` produces; ``.item()`` unwraps it.
        """
        # SECURITY: allow_pickle=True executes arbitrary pickled code on load.
        # Only use this with embedding files from a trusted source.
        return np.load(embeddings_path, allow_pickle=True).item()

    def retrieve(self, queries: np.ndarray, n_neighbors: int = 5) -> Tuple[List[List[str]], np.ndarray]:
        """Retrieve nearest neighbors image names from queries

        Args:
            queries: Query vectors of shape ``(n_queries, dim)``; a single 1-D
                vector is treated as a batch of one.
            n_neighbors: Number of neighbours per query. Capped at the number
                of indexed vectors.

        Returns:
            A tuple ``(names, distances)``: ``names[q]`` lists the image names
            of the neighbours of query ``q`` in the order returned by the
            index, and ``distances`` is the matching distance array returned
            by the index search.

        Raises:
            ValueError: If ``n_neighbors`` is smaller than 1.
        """
        if n_neighbors < 1:
            raise ValueError(f"n_neighbors must be >= 1, got {n_neighbors}")

        # Faiss requires a contiguous float32 2-D batch.
        batch = np.ascontiguousarray(np.atleast_2d(queries), dtype=np.float32)

        # Asking for more neighbours than indexed vectors makes Faiss pad the
        # result with index -1, which has no image name; cap k instead.
        k = min(n_neighbors, self.index.ntotal)

        distances, indexes = self.index.search(batch, k)
        names = [[self.index_to_image[int(i)] for i in row] for row in indexes]
        return names, distances