Spaces:
Build error
Build error
import numpy as np | |
import logging, torch | |
from sklearn.preprocessing import MinMaxScaler | |
from sentence_transformers import CrossEncoder | |
# from FlagEmbedding import FlagReranker | |
class Hybrid_search: | |
def __init__(self, bm25_search, faiss_search, reranker_model_name="BAAI/bge-reranker-v2-gemma", initial_bm25_weight=0.5): | |
self.bm25_search = bm25_search | |
self.faiss_search = faiss_search | |
self.bm25_weight = initial_bm25_weight | |
# self.reranker = FlagReranker(reranker_model_name, use_fp16=True) | |
self.logger = logging.getLogger(__name__) | |
async def advanced_search(self, query, keywords, top_n=5, threshold=0.53, prefixes=None): | |
# Dynamic BM25 weighting | |
self._dynamic_weighting(len(query.split())) | |
keywords = f"{' '.join(keywords)}" | |
self.logger.info(f"Query: {query}") | |
self.logger.info(f"Keywords: {keywords}") | |
# Get BM25 scores and doc_ids | |
bm25_scores, bm25_doc_ids = self._get_bm25_results(keywords, top_n = top_n) | |
# self.logger.info(f"BM25 Scores: {bm25_scores}, BM25 doc_ids: {bm25_doc_ids}") | |
# Get FAISS distances, indices, and doc_ids | |
faiss_distances, faiss_indices, faiss_doc_ids = self._get_faiss_results(query) | |
try: | |
faiss_distances, indices, faiss_doc_ids = self._get_faiss_results(query, top_n = top_n) | |
# for dist, idx, doc_id in zip(faiss_distances, indices, faiss_doc_ids): | |
# print(f"Distance: {dist:.4f}, Index: {idx}, Doc ID: {doc_id}") | |
except Exception as e: | |
self.logger.error(f"Search failed: {str(e)}") | |
# Map doc_ids to scores | |
bm25_scores_dict, faiss_scores_dict = self._map_scores_to_doc_ids( | |
bm25_doc_ids, bm25_scores, faiss_doc_ids, faiss_distances | |
) | |
# Create a unified set of doc IDs | |
all_doc_ids = sorted(set(bm25_doc_ids).union(faiss_doc_ids)) | |
# print(f"All doc_ids: {all_doc_ids}, BM25 doc_ids: {bm25_doc_ids}, FAISS doc_ids: {faiss_doc_ids}") | |
# Filter doc_ids based on prefixes | |
filtered_doc_ids = self._filter_doc_ids_by_prefixes(all_doc_ids, prefixes) | |
# self.logger.info(f"Filtered doc_ids: {filtered_doc_ids}") | |
if not filtered_doc_ids: | |
self.logger.info("No documents match the prefixes.") | |
return [] | |
# Prepare score lists | |
filtered_bm25_scores, filtered_faiss_scores = self._get_filtered_scores( | |
filtered_doc_ids, bm25_scores_dict, faiss_scores_dict | |
) | |
# self.logger.info(f"Filtered BM25 scores: {filtered_bm25_scores}") | |
# self.logger.info(f"Filtered FAISS scores: {filtered_faiss_scores}") | |
# Normalize scores | |
bm25_scores_normalized, faiss_scores_normalized = self._normalize_scores( | |
filtered_bm25_scores, filtered_faiss_scores | |
) | |
# Calculate hybrid scores | |
hybrid_scores = self._calculate_hybrid_scores(bm25_scores_normalized, faiss_scores_normalized) | |
# Display hybrid scores | |
for idx, doc_id in enumerate(filtered_doc_ids): | |
print(f"Hybrid Score: {hybrid_scores[idx]:.4f}, Doc ID: {doc_id}") | |
# Apply threshold and get top_n results | |
results = self._get_top_n_results(filtered_doc_ids, hybrid_scores, top_n, threshold) | |
self.logger.info(f"Results before reranking: {results}") | |
# If results exist, apply re-ranking | |
# if results: | |
# re_ranked_results = self._rerank_results(query, results) | |
# self.logger.info(f"Results after reranking: {re_ranked_results}") | |
# return re_ranked_results | |
return results | |
def _dynamic_weighting(self, query_length): | |
if query_length <= 5: | |
self.bm25_weight = 0.7 | |
else: | |
self.bm25_weight = 0.5 | |
self.logger.info(f"Dynamic BM25 weight set to: {self.bm25_weight}") | |
def _get_bm25_results(self, keywords, top_n:int = None): | |
# Get BM25 scores | |
bm25_scores = np.array(self.bm25_search.get_scores(keywords)) | |
bm25_doc_ids = np.array(self.bm25_search.doc_ids) # Assuming doc_ids is a list of document IDs | |
# Log the scores and IDs before filtering | |
# self.logger.info(f"BM25 scores: {bm25_scores}") | |
# self.logger.info(f"BM25 doc_ids: {bm25_doc_ids}") | |
# Get the top k indices based on BM25 scores | |
top_k_indices = np.argsort(bm25_scores)[-top_n:][::-1] | |
# Retrieve top k scores and corresponding document IDs | |
top_k_scores = bm25_scores[top_k_indices] | |
top_k_doc_ids = bm25_doc_ids[top_k_indices] | |
# Return top k scores and document IDs | |
return top_k_scores, top_k_doc_ids | |
def _get_faiss_results(self, query, top_n: int = None) -> tuple[np.ndarray, np.ndarray, list[str]]: | |
try: | |
# If top_k is not specified, use all documents | |
if top_n is None: | |
top_n = len(self.faiss_search.doc_ids) | |
# Use the search's search method which handles the embedding | |
distances, indices = self.faiss_search.search(query, k=top_n) | |
if len(distances) == 0 or len(indices) == 0: | |
# Handle case where FAISS returns empty results | |
self.logger.info("FAISS search returned no results.") | |
return np.array([]), np.array([]), [] | |
# Filter out invalid indices (-1) | |
valid_mask = indices != -1 | |
filtered_distances = distances[valid_mask] | |
filtered_indices = indices[valid_mask] | |
# Map indices to doc_ids | |
doc_ids = [self.faiss_search.doc_ids[idx] for idx in filtered_indices | |
if 0 <= idx < len(self.faiss_search.doc_ids)] | |
# self.logger.info(f"FAISS distances: {filtered_distances}") | |
# self.logger.info(f"FAISS indices: {filtered_indices}") | |
# self.logger.info(f"FAISS doc_ids: {doc_ids}") | |
return filtered_distances, filtered_indices, doc_ids | |
except Exception as e: | |
self.logger.error(f"Error in FAISS search: {str(e)}") | |
raise | |
def _map_scores_to_doc_ids(self, bm25_doc_ids, bm25_scores, faiss_doc_ids, faiss_scores): | |
bm25_scores_dict = dict(zip(bm25_doc_ids, bm25_scores)) | |
faiss_scores_dict = dict(zip(faiss_doc_ids, faiss_scores)) | |
# self.logger.info(f"BM25 scores dict: {bm25_scores_dict}") | |
# self.logger.info(f"FAISS scores dict: {faiss_scores_dict}") | |
return bm25_scores_dict, faiss_scores_dict | |
def _filter_doc_ids_by_prefixes(self, all_doc_ids, prefixes): | |
if prefixes: | |
filtered_doc_ids = [ | |
doc_id | |
for doc_id in all_doc_ids | |
if any(doc_id.startswith(prefix) for prefix in prefixes) | |
] | |
else: | |
filtered_doc_ids = list(all_doc_ids) | |
return filtered_doc_ids | |
def _get_filtered_scores(self, filtered_doc_ids, bm25_scores_dict, faiss_scores_dict): | |
# Initialize lists to hold scores in the unified doc ID order | |
bm25_aligned_scores = [] | |
faiss_aligned_scores = [] | |
# Populate aligned score lists, filling missing scores with neutral values | |
for doc_id in filtered_doc_ids: | |
bm25_aligned_scores.append(bm25_scores_dict.get(doc_id, 0)) # Use 0 if not found in BM25 | |
faiss_aligned_scores.append(faiss_scores_dict.get(doc_id, max(faiss_scores_dict.values()) + 1)) # Use a high distance if not found in FAISS | |
# Invert the FAISS scores | |
faiss_aligned_scores = [1 / score if score != 0 else 0 for score in faiss_aligned_scores] | |
return bm25_aligned_scores, faiss_aligned_scores | |
def _normalize_scores(self, filtered_bm25_scores, filtered_faiss_scores): | |
scaler_bm25 = MinMaxScaler() | |
bm25_scores_normalized = self._normalize_array(filtered_bm25_scores, scaler_bm25) | |
scaler_faiss = MinMaxScaler() | |
faiss_scores_normalized = self._normalize_array(filtered_faiss_scores, scaler_faiss) | |
# self.logger.info(f"Normalized BM25 scores: {bm25_scores_normalized}") | |
# self.logger.info(f"Normalized FAISS scores: {faiss_scores_normalized}") | |
return bm25_scores_normalized, faiss_scores_normalized | |
def _normalize_array(self, scores, scaler): | |
scores_array = np.array(scores) | |
if np.ptp(scores_array) > 0: | |
normalized_scores = scaler.fit_transform(scores_array.reshape(-1, 1)).flatten() | |
else: | |
# Handle identical scores with a fallback to uniform 0.5 | |
normalized_scores = np.full_like(scores_array, 0.5, dtype=float) | |
return normalized_scores | |
def _calculate_hybrid_scores(self, bm25_scores_normalized, faiss_scores_normalized): | |
hybrid_scores = self.bm25_weight * bm25_scores_normalized + (1 - self.bm25_weight) * faiss_scores_normalized | |
# self.logger.info(f"Hybrid scores: {hybrid_scores}") | |
return hybrid_scores | |
def _get_top_n_results(self, filtered_doc_ids, hybrid_scores, top_n, threshold): | |
hybrid_scores = np.array(hybrid_scores) | |
threshold_indices = np.where(hybrid_scores >= threshold)[0] | |
if len(threshold_indices) == 0: | |
self.logger.info("No documents meet the threshold.") | |
return [] | |
sorted_indices = threshold_indices[np.argsort(hybrid_scores[threshold_indices])[::-1]] | |
top_indices = sorted_indices[:top_n] | |
results = [(filtered_doc_ids[idx], hybrid_scores[idx]) for idx in top_indices] | |
self.logger.info(f"Top {top_n} results: {results}") | |
return results | |
def _rerank_results(self, query, results): | |
""" | |
Re-rank the retrieved documents using FlagReranker with normalized scores. | |
Parameters: | |
- query (str): The search query. | |
- results (List[Tuple[str, float]]): A list of (doc_id, score) tuples. | |
Returns: | |
- List[Tuple[str, float]]: Re-ranked list of (doc_id, score) tuples with normalized scores. | |
""" | |
# Prepare input for the re-ranker | |
document_texts = [self.bm25_search.get_document(doc_id) for doc_id, _ in results] | |
doc_ids = [doc_id for doc_id, _ in results] | |
# Generate pairwise scores using the FlagReranker | |
rerank_inputs = [[query, doc] for doc in document_texts] | |
with torch.no_grad(): | |
rerank_scores = self.reranker.compute_score(rerank_inputs, normalize=True) | |
# rerank_scores = self.reranker.compute_score(rerank_inputs, normalize=True) | |
# Combine doc_ids with normalized re-rank scores and sort by scores | |
reranked_results = sorted( | |
zip(doc_ids, rerank_scores), | |
key=lambda x: x[1], | |
reverse=True | |
) | |
# Log and return results | |
# self.logger.info(f"Re-ranked results with normalized scores: {reranked_results}") | |
return reranked_results | |