import logging

import numpy as np
import pandas as pd
import faiss

from common.constants import COLUMN_EMBEDDING, DO_NORMALIZATION
from common.configuration import DataBaseConfiguration
from components.embedding_extraction import EmbeddingExtractor

logger = logging.getLogger(__name__)


class FaissVectorSearch:
    """Exact vector search over a flat L2 FAISS index built from a dataframe of embeddings."""

    def __init__(
        self, model: EmbeddingExtractor, df: pd.DataFrame, config: DataBaseConfiguration
    ):
        self.model = model
        self.config = config
        self.path_to_metadata = config.faiss.path_to_metadata
        # When re-ranking is enabled, retrieve candidates for the ranker;
        # otherwise use the plain vector-search neighbor count.
        if self.config.ranker.use_ranging:
            self.k_neighbors = config.ranker.k_neighbors
        else:
            self.k_neighbors = config.search.vector_search.k_neighbors
        self.__create_index(df)

    def __create_index(self, df: pd.DataFrame):
        """Load the metadata file."""
        if len(df) == 0:
            self.index = None
            return
        df = df.where(pd.notna(df), None)
        # FAISS expects a contiguous float32 matrix of shape (n_vectors, dim).
        embeddings = np.array(df[COLUMN_EMBEDDING].tolist(), dtype=np.float32)
        dim = embeddings.shape[1]
        # Flat index with exact L2 (Euclidean) distance search.
        self.index = faiss.IndexFlatL2(dim)
        self.index.add(embeddings)

    def search_vectors(self, query: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Поиск векторов в индексе.
        """
        logger.info(f"Searching vectors in index for query: {query}")
        if self.index is None:
            return (np.array([]), np.array([]), np.array([]))
        # The extractor returns a (1, dim) batch; FAISS searches batches of queries.
        query_embeds = self.model.query_embed_extraction(query, DO_NORMALIZATION)
        scores, indexes = self.index.search(query_embeds, self.k_neighbors)
        return query_embeds[0], scores[0], indexes[0]
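

# Usage sketch (illustrative only; how EmbeddingExtractor, DataBaseConfiguration
# and the metadata dataframe are constructed is project-specific and assumed here):
#
#     model = EmbeddingExtractor(...)       # embedding model wrapper
#     config = DataBaseConfiguration(...)   # holds faiss / ranker / search settings
#     df = ...                              # dataframe with a COLUMN_EMBEDDING column
#     searcher = FaissVectorSearch(model, df, config)
#     query_embed, scores, indexes = searcher.search_vectors("example query")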