import logging

import numpy as np
import torch
from sklearn.preprocessing import MinMaxScaler
# from FlagEmbedding import FlagReranker  # needed only for the optional re-ranking stage


class Hybrid_search:
    """Hybrid lexical/semantic retrieval.

    Fuses min-max-normalized BM25 scores with inverted FAISS distances via a
    query-length-dependent weighted sum; an optional FlagReranker re-ranking
    stage is present but disabled by default.
    """

    def __init__(self, bm25_search, faiss_search, reranker_model_name="BAAI/bge-reranker-v2-gemma", initial_bm25_weight=0.5):
        self.bm25_search = bm25_search
        self.faiss_search = faiss_search
        self.bm25_weight = initial_bm25_weight
        # self.reranker = FlagReranker(reranker_model_name, use_fp16=True)
        self.logger = logging.getLogger(__name__)

    async def advanced_search(self, query, keywords, top_n=5, threshold=0.53, prefixes=None):
        # Adjust the BM25/FAISS weighting based on query length
        self._dynamic_weighting(len(query.split()))
        keywords = " ".join(keywords)
        self.logger.info(f"Query: {query}")
        self.logger.info(f"Keywords: {keywords}")

        # Get BM25 scores and doc_ids
        bm25_scores, bm25_doc_ids = self._get_bm25_results(keywords, top_n=top_n)
        # self.logger.info(f"BM25 Scores: {bm25_scores}, BM25 doc_ids: {bm25_doc_ids}")

        # Get FAISS distances, indices, and doc_ids; bail out early if the
        # vector search fails rather than continuing with undefined scores
        try:
            faiss_distances, faiss_indices, faiss_doc_ids = self._get_faiss_results(query, top_n=top_n)
        except Exception as e:
            self.logger.error(f"FAISS search failed: {e}")
            return []

        # Map doc_ids to scores
        bm25_scores_dict, faiss_scores_dict = self._map_scores_to_doc_ids(
            bm25_doc_ids, bm25_scores, faiss_doc_ids, faiss_distances
        )
        # Create a unified set of doc IDs
        all_doc_ids = sorted(set(bm25_doc_ids).union(faiss_doc_ids))
        # print(f"All doc_ids: {all_doc_ids}, BM25 doc_ids: {bm25_doc_ids}, FAISS doc_ids: {faiss_doc_ids}")

        # Filter doc_ids based on prefixes
        filtered_doc_ids = self._filter_doc_ids_by_prefixes(all_doc_ids, prefixes)
        # self.logger.info(f"Filtered doc_ids: {filtered_doc_ids}")

        if not filtered_doc_ids:
            self.logger.info("No documents match the prefixes.")
            return []

        # Prepare score lists
        filtered_bm25_scores, filtered_faiss_scores = self._get_filtered_scores(
            filtered_doc_ids, bm25_scores_dict, faiss_scores_dict
        )
        # self.logger.info(f"Filtered BM25 scores: {filtered_bm25_scores}")
        # self.logger.info(f"Filtered FAISS scores: {filtered_faiss_scores}")
       

        # Normalize scores
        bm25_scores_normalized, faiss_scores_normalized = self._normalize_scores(
            filtered_bm25_scores, filtered_faiss_scores
        )

        # Calculate hybrid scores
        hybrid_scores = self._calculate_hybrid_scores(bm25_scores_normalized, faiss_scores_normalized)

        # Log hybrid scores for inspection
        for idx, doc_id in enumerate(filtered_doc_ids):
            self.logger.debug(f"Hybrid Score: {hybrid_scores[idx]:.4f}, Doc ID: {doc_id}")

        # Apply threshold and get top_n results
        results = self._get_top_n_results(filtered_doc_ids, hybrid_scores, top_n, threshold)
        self.logger.info(f"Results before reranking: {results}")

        # If results exist, apply re-ranking
        # if results:
        #     re_ranked_results = self._rerank_results(query, results)
        #     self.logger.info(f"Results after reranking: {re_ranked_results}")
        #     return re_ranked_results

        return results
    

    def _dynamic_weighting(self, query_length):
        if query_length <= 5:
            self.bm25_weight = 0.7
        else:
            self.bm25_weight = 0.5
        self.logger.info(f"Dynamic BM25 weight set to: {self.bm25_weight}")

    def _get_bm25_results(self, keywords, top_n: int = None):
        # Score every document in the BM25 index against the keywords
        bm25_scores = np.array(self.bm25_search.get_scores(keywords))
        bm25_doc_ids = np.array(self.bm25_search.doc_ids)  # Assuming doc_ids is a list of document IDs

        # If top_n is not specified, keep all documents (slicing with None would fail)
        if top_n is None:
            top_n = len(bm25_scores)

        # Get the top_n indices by BM25 score, highest first
        # e.g. scores [0.1, 0.9, 0.4] with top_n=2 -> indices [1, 2], scores [0.9, 0.4]
        top_k_indices = np.argsort(bm25_scores)[-top_n:][::-1]

        # Retrieve the top_n scores and corresponding document IDs
        top_k_scores = bm25_scores[top_k_indices]
        top_k_doc_ids = bm25_doc_ids[top_k_indices]

        return top_k_scores, top_k_doc_ids

    def _get_faiss_results(self, query, top_n: int = None) -> tuple[np.ndarray, np.ndarray, list[str]]:
        try:
            # If top_n is not specified, search across all documents
            if top_n is None:
                top_n = len(self.faiss_search.doc_ids)

            # The wrapper's search method handles embedding the query
            distances, indices = self.faiss_search.search(query, k=top_n)
            
            if len(distances) == 0 or len(indices) == 0:
                # Handle case where FAISS returns empty results
                self.logger.info("FAISS search returned no results.")
                return np.array([]), np.array([]), []
            
            # Filter out invalid indices (-1)
            valid_mask = indices != -1
            filtered_distances = distances[valid_mask]
            filtered_indices = indices[valid_mask]
            
            # Map indices to doc_ids
            doc_ids = [self.faiss_search.doc_ids[idx] for idx in filtered_indices 
                    if 0 <= idx < len(self.faiss_search.doc_ids)]
            
            # self.logger.info(f"FAISS distances: {filtered_distances}")
            # self.logger.info(f"FAISS indices: {filtered_indices}")
            # self.logger.info(f"FAISS doc_ids: {doc_ids}")
            
            return filtered_distances, filtered_indices, doc_ids
            
        except Exception as e:
            self.logger.error(f"Error in FAISS search: {str(e)}")
            raise

    def _map_scores_to_doc_ids(self, bm25_doc_ids, bm25_scores, faiss_doc_ids, faiss_scores):
        bm25_scores_dict = dict(zip(bm25_doc_ids, bm25_scores))
        faiss_scores_dict = dict(zip(faiss_doc_ids, faiss_scores))
        # self.logger.info(f"BM25 scores dict: {bm25_scores_dict}")
        # self.logger.info(f"FAISS scores dict: {faiss_scores_dict}")
        return bm25_scores_dict, faiss_scores_dict

    def _filter_doc_ids_by_prefixes(self, all_doc_ids, prefixes):
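        # e.g. prefixes=["kb/"] keeps "kb/reset" and drops "faq/billing"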
        if prefixes:
            filtered_doc_ids = [
                doc_id
                for doc_id in all_doc_ids
                if any(doc_id.startswith(prefix) for prefix in prefixes)
            ]
        else:
            filtered_doc_ids = list(all_doc_ids)
        return filtered_doc_ids

    def _get_filtered_scores(self, filtered_doc_ids, bm25_scores_dict, faiss_scores_dict):
        # Initialize lists to hold scores in the unified doc ID order
        bm25_aligned_scores = []
        faiss_aligned_scores = []

        # Distance assigned to documents missing from the FAISS results: worse
        # than any observed distance; `default` guards against an empty dict
        missing_distance = max(faiss_scores_dict.values(), default=0.0) + 1.0

        # Populate aligned score lists, filling missing scores with neutral values
        for doc_id in filtered_doc_ids:
            bm25_aligned_scores.append(bm25_scores_dict.get(doc_id, 0))  # Use 0 if not found in BM25
            faiss_aligned_scores.append(faiss_scores_dict.get(doc_id, missing_distance))

        # Convert distances to similarities; 1 / (1 + d) keeps a zero distance
        # (an exact match) as the best score rather than the worst
        faiss_aligned_scores = [1.0 / (1.0 + score) for score in faiss_aligned_scores]

        return bm25_aligned_scores, faiss_aligned_scores

    def _normalize_scores(self, filtered_bm25_scores, filtered_faiss_scores):
        scaler_bm25 = MinMaxScaler()
        bm25_scores_normalized = self._normalize_array(filtered_bm25_scores, scaler_bm25)

        scaler_faiss = MinMaxScaler()
        faiss_scores_normalized = self._normalize_array(filtered_faiss_scores, scaler_faiss)

        # self.logger.info(f"Normalized BM25 scores: {bm25_scores_normalized}")
        # self.logger.info(f"Normalized FAISS scores: {faiss_scores_normalized}")
        return bm25_scores_normalized, faiss_scores_normalized

    def _normalize_array(self, scores, scaler):
        scores_array = np.array(scores)
        if np.ptp(scores_array) > 0:
            normalized_scores = scaler.fit_transform(scores_array.reshape(-1, 1)).flatten()
        else:
            # Handle identical scores with a fallback to uniform 0.5
            normalized_scores = np.full_like(scores_array, 0.5, dtype=float)
        return normalized_scores

    def _calculate_hybrid_scores(self, bm25_scores_normalized, faiss_scores_normalized):
        hybrid_scores = self.bm25_weight * bm25_scores_normalized + (1 - self.bm25_weight) * faiss_scores_normalized
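        # e.g. (assumed values) with bm25_weight = 0.7,
        # bm25 = [1.0, 0.0] and faiss = [0.5, 1.0]:
        # 0.7 * [1.0, 0.0] + 0.3 * [0.5, 1.0] = [0.85, 0.30]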
        # self.logger.info(f"Hybrid scores: {hybrid_scores}")
        return hybrid_scores

    def _get_top_n_results(self, filtered_doc_ids, hybrid_scores, top_n, threshold):
        hybrid_scores = np.array(hybrid_scores)
        threshold_indices = np.where(hybrid_scores >= threshold)[0]
        if len(threshold_indices) == 0:
            self.logger.info("No documents meet the threshold.")
            return []

        sorted_indices = threshold_indices[np.argsort(hybrid_scores[threshold_indices])[::-1]]
        top_indices = sorted_indices[:top_n]

        results = [(filtered_doc_ids[idx], hybrid_scores[idx]) for idx in top_indices]
        self.logger.info(f"Top {top_n} results: {results}")
        return results

    def _rerank_results(self, query, results):
        """
        Re-rank the retrieved documents using FlagReranker with normalized scores.

        Parameters:
        - query (str): The search query.
        - results (List[Tuple[str, float]]): A list of (doc_id, score) tuples.

        Returns:
        - List[Tuple[str, float]]: Re-ranked list of (doc_id, score) tuples with normalized scores.
        """
        # Guard: the FlagReranker is commented out in __init__; fail loudly
        # instead of raising an opaque AttributeError below
        if not hasattr(self, "reranker"):
            raise RuntimeError("Re-ranking is disabled; uncomment the FlagReranker in __init__ first.")

        # Prepare input for the re-ranker
        document_texts = [self.bm25_search.get_document(doc_id) for doc_id, _ in results]
        doc_ids = [doc_id for doc_id, _ in results]

        # Generate pairwise scores using the FlagReranker
        rerank_inputs = [[query, doc] for doc in document_texts]
        with torch.no_grad():
            rerank_scores = self.reranker.compute_score(rerank_inputs, normalize=True)

        # Combine doc_ids with normalized re-rank scores and sort by scores
        reranked_results = sorted(
            zip(doc_ids, rerank_scores),
            key=lambda x: x[1],
            reverse=True
        )

        # self.logger.info(f"Re-ranked results with normalized scores: {reranked_results}")
        return reranked_results
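
# ---------------------------------------------------------------------------
# Minimal usage sketch. `_StubBM25` and `_StubFaiss` below are hypothetical
# in-memory stand-ins that illustrate the duck-typed interface Hybrid_search
# relies on: the BM25 side needs get_scores(keywords), doc_ids, and
# get_document(doc_id); the FAISS side needs search(query, k) and doc_ids.
# A real deployment would back these with e.g. rank_bm25 and a FAISS index.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    logging.basicConfig(level=logging.INFO)

    class _StubBM25:
        doc_ids = ["kb/reset", "kb/login", "faq/billing"]
        _docs = {
            "kb/reset": "how to reset a forgotten password",
            "kb/login": "troubleshooting login failures",
            "faq/billing": "updating billing details",
        }

        def get_scores(self, keywords):
            # Crude lexical-overlap stand-in for real BM25 scoring
            terms = set(keywords.split())
            return [len(terms & set(self._docs[d].split())) for d in self.doc_ids]

        def get_document(self, doc_id):
            return self._docs[doc_id]

    class _StubFaiss:
        doc_ids = ["kb/reset", "kb/login", "faq/billing"]

        def search(self, query, k):
            # Fixed fake distances (smaller = closer); a real wrapper would
            # embed the query and search a FAISS index
            k = min(k, len(self.doc_ids))
            return np.array([0.2, 0.8, 1.5][:k]), np.arange(k)

    searcher = Hybrid_search(_StubBM25(), _StubFaiss())
    hits = asyncio.run(searcher.advanced_search(
        query="reset my password",
        keywords=["reset", "password"],
        top_n=3,
        threshold=0.4,
        prefixes=["kb/"],  # keep only doc_ids under the "kb/" namespace
    ))
    for doc_id, score in hits:
        print(f"{score:.4f}  {doc_id}")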