# chat/app/search/hybrid_search.py
# ariansyahdedy's picture
# Test Rag
# e0c1af0
import numpy as np
import logging, torch
from sklearn.preprocessing import MinMaxScaler
from sentence_transformers import CrossEncoder
# from FlagEmbedding import FlagReranker
class Hybrid_search:
    """Rank documents by blending BM25 keyword scores with FAISS vector distances.

    The two collaborators are duck-typed; this class only relies on:

    * ``bm25_search.get_scores(keywords)`` returning one score per document,
      ``bm25_search.doc_ids`` (same order as the scores) and, for re-ranking
      only, ``bm25_search.get_document(doc_id)``.
    * ``faiss_search.search(query, k=...)`` returning ``(distances, indices)``
      and ``faiss_search.doc_ids`` indexable by those indices.
    """

    def __init__(
        self,
        bm25_search,
        faiss_search,
        reranker_model_name="BAAI/bge-reranker-v2-gemma",
        initial_bm25_weight=0.5,
    ):
        """Store the two search back-ends and the initial BM25 weight.

        Parameters:
        - bm25_search: keyword search back-end (see class docstring).
        - faiss_search: vector search back-end (see class docstring).
        - reranker_model_name (str): kept for a future re-ranker; no model is
          loaded here.
        - initial_bm25_weight (float): weight of the BM25 component until the
          first call to ``advanced_search``, which overwrites it through
          ``_dynamic_weighting``.
        """
        self.bm25_search = bm25_search
        self.faiss_search = faiss_search
        self.bm25_weight = initial_bm25_weight
        self.reranker_model_name = reranker_model_name
        # Re-ranking is disabled: no re-ranker is constructed. Assign an object
        # exposing ``compute_score(pairs, normalize=True)`` (e.g. a
        # FlagEmbedding ``FlagReranker``) to enable ``_rerank_results``.
        self.reranker = None
        self.logger = logging.getLogger(__name__)

    async def advanced_search(self, query, keywords, top_n=5, threshold=0.53, prefixes=None):
        """Run the hybrid search and return the best-scoring documents.

        Parameters:
        - query (str): free-text query, passed to the FAISS back-end.
        - keywords (Iterable[str]): keywords, space-joined and passed to BM25.
        - top_n (int): candidates fetched from each back-end and the maximum
          number of results returned.
        - threshold (float): minimum hybrid score a document must reach.
        - prefixes (Iterable[str] | str | None): keep only doc IDs starting
          with one of these prefixes; ``None``/empty keeps everything.

        Returns:
        - List[Tuple[str, float]]: ``(doc_id, hybrid_score)`` pairs sorted by
          descending score; ``[]`` when nothing passes the prefix filter or
          the threshold.

        Note: the body contains no ``await``, so the coroutine runs
        synchronously once awaited.
        """
        # Short queries lean on keyword matching; see _dynamic_weighting.
        self._dynamic_weighting(len(query.split()))
        keywords = " ".join(keywords)
        self.logger.info(f"Query: {query}")
        self.logger.info(f"Keywords: {keywords}")

        bm25_scores, bm25_doc_ids = self._get_bm25_results(keywords, top_n=top_n)

        # Query FAISS exactly once. If it fails, continue with BM25 only
        # instead of silently reusing results from an earlier, differently
        # parameterised call.
        try:
            faiss_distances, _, faiss_doc_ids = self._get_faiss_results(query, top_n=top_n)
        except Exception as e:
            self.logger.error(f"Search failed: {str(e)}")
            faiss_distances, faiss_doc_ids = np.array([]), []

        bm25_scores_dict, faiss_scores_dict = self._map_scores_to_doc_ids(
            bm25_doc_ids, bm25_scores, faiss_doc_ids, faiss_distances
        )

        # Unified, deterministically ordered candidate set from both back-ends.
        all_doc_ids = sorted(set(bm25_doc_ids).union(faiss_doc_ids))

        filtered_doc_ids = self._filter_doc_ids_by_prefixes(all_doc_ids, prefixes)
        if not filtered_doc_ids:
            self.logger.info("No documents match the prefixes.")
            return []

        filtered_bm25_scores, filtered_faiss_scores = self._get_filtered_scores(
            filtered_doc_ids, bm25_scores_dict, faiss_scores_dict
        )
        bm25_scores_normalized, faiss_scores_normalized = self._normalize_scores(
            filtered_bm25_scores, filtered_faiss_scores
        )
        hybrid_scores = self._calculate_hybrid_scores(bm25_scores_normalized, faiss_scores_normalized)

        # Diagnostics go through the logger rather than stdout.
        for doc_id, score in zip(filtered_doc_ids, hybrid_scores):
            self.logger.debug(f"Hybrid Score: {score:.4f}, Doc ID: {doc_id}")

        results = self._get_top_n_results(filtered_doc_ids, hybrid_scores, top_n, threshold)
        self.logger.info(f"Results before reranking: {results}")
        # Re-ranking (_rerank_results) is intentionally not applied here.
        return results

    def _dynamic_weighting(self, query_length):
        """Set ``self.bm25_weight`` from the query's word count.

        Queries of at most 5 words use 0.7, longer ones 0.5. This mutates
        instance state, so the weight passed to ``__init__`` is replaced on
        every search.
        """
        if query_length <= 5:
            self.bm25_weight = 0.7
        else:
            self.bm25_weight = 0.5
        self.logger.info(f"Dynamic BM25 weight set to: {self.bm25_weight}")

    def _get_bm25_results(self, keywords, top_n: int = None):
        """Return the ``top_n`` best BM25 scores and their doc IDs.

        Parameters:
        - keywords: forwarded unchanged to ``bm25_search.get_scores``.
        - top_n (int | None): number of hits to keep; ``None`` keeps every
          document, values <= 0 yield empty arrays.

        Returns:
        - Tuple[np.ndarray, np.ndarray]: scores and doc IDs, best first.
        """
        bm25_scores = np.array(self.bm25_search.get_scores(keywords))
        bm25_doc_ids = np.array(self.bm25_search.doc_ids)

        if top_n is None:
            top_n = len(bm25_scores)
        if top_n <= 0:
            # ``[-0:]`` would select the whole array, so handle this up front.
            return bm25_scores[:0], bm25_doc_ids[:0]

        # argsort is ascending: take the tail and reverse it for best-first.
        top_k_indices = np.argsort(bm25_scores)[-top_n:][::-1]
        return bm25_scores[top_k_indices], bm25_doc_ids[top_k_indices]

    def _get_faiss_results(self, query, top_n: int = None) -> tuple[np.ndarray, np.ndarray, list[str]]:
        """Query the FAISS back-end and map hit indices to doc IDs.

        Parameters:
        - query: forwarded unchanged to ``faiss_search.search``.
        - top_n (int | None): number of neighbours; ``None`` requests as many
          as there are entries in ``faiss_search.doc_ids``.

        Returns:
        - Tuple[np.ndarray, np.ndarray, List[str]]: distances, indices and doc
          IDs, position-aligned with one another. Hits whose index is negative
          or beyond ``faiss_search.doc_ids`` are dropped from all three.

        Raises:
        - Any exception from the back-end, after logging it.
        """
        try:
            known_doc_ids = self.faiss_search.doc_ids
            if top_n is None:
                top_n = len(known_doc_ids)

            distances, indices = self.faiss_search.search(query, k=top_n)
            distances = np.asarray(distances)
            indices = np.asarray(indices)

            if distances.size == 0 or indices.size == 0:
                self.logger.info("FAISS search returned no results.")
                return np.array([]), np.array([]), []

            # One mask for all three outputs keeps them aligned; filtering
            # doc IDs separately would shift them against the distances.
            valid_mask = (indices >= 0) & (indices < len(known_doc_ids))
            filtered_distances = distances[valid_mask]
            filtered_indices = indices[valid_mask]
            doc_ids = [known_doc_ids[idx] for idx in filtered_indices]

            return filtered_distances, filtered_indices, doc_ids
        except Exception as e:
            self.logger.error(f"Error in FAISS search: {str(e)}")
            raise

    def _map_scores_to_doc_ids(self, bm25_doc_ids, bm25_scores, faiss_doc_ids, faiss_scores):
        """Build ``{doc_id: score}`` dicts for BM25 and FAISS.

        Pairs are formed positionally with ``zip``; if a doc ID repeats, the
        last occurrence wins.
        """
        bm25_scores_dict = dict(zip(bm25_doc_ids, bm25_scores))
        faiss_scores_dict = dict(zip(faiss_doc_ids, faiss_scores))
        return bm25_scores_dict, faiss_scores_dict

    def _filter_doc_ids_by_prefixes(self, all_doc_ids, prefixes):
        """Return the doc IDs that start with any of ``prefixes``.

        A falsy ``prefixes`` disables filtering. A bare string is treated as a
        single prefix rather than being iterated character by character.
        """
        if not prefixes:
            return list(all_doc_ids)
        if isinstance(prefixes, str):
            prefixes = (prefixes,)
        allowed = tuple(prefixes)  # str.startswith accepts a tuple of options
        return [doc_id for doc_id in all_doc_ids if doc_id.startswith(allowed)]

    def _get_filtered_scores(self, filtered_doc_ids, bm25_scores_dict, faiss_scores_dict):
        """Align both score sets to ``filtered_doc_ids`` and invert FAISS distances.

        Documents missing from BM25 get score 0. Documents missing from FAISS
        get a distance one larger than the largest known distance (1.0 when
        FAISS returned nothing at all), so they rank below every real hit.

        Returns:
        - Tuple[List[float], List[float]]: BM25 scores and inverted FAISS
          distances (``1 / distance``), in ``filtered_doc_ids`` order.
        """
        # Computed once: the fallback is the same for every document, and
        # max() of an empty dict would raise ValueError.
        if faiss_scores_dict:
            missing_distance = max(faiss_scores_dict.values()) + 1
        else:
            missing_distance = 1.0

        bm25_aligned_scores = [bm25_scores_dict.get(doc_id, 0) for doc_id in filtered_doc_ids]
        faiss_distances = [faiss_scores_dict.get(doc_id, missing_distance) for doc_id in filtered_doc_ids]

        # Smaller distance -> larger similarity.
        # NOTE(review): a distance of exactly 0 maps to similarity 0, i.e. the
        # worst score, which looks inverted for an exact match. The right
        # replacement depends on the FAISS metric in use - confirm and fix.
        faiss_aligned_scores = [1 / score if score != 0 else 0 for score in faiss_distances]
        return bm25_aligned_scores, faiss_aligned_scores

    def _normalize_scores(self, filtered_bm25_scores, filtered_faiss_scores):
        """Min-max normalize both score lists independently.

        Returns:
        - Tuple[np.ndarray, np.ndarray]: normalized BM25 and FAISS scores.
        """
        bm25_scores_normalized = self._normalize_array(filtered_bm25_scores, MinMaxScaler())
        faiss_scores_normalized = self._normalize_array(filtered_faiss_scores, MinMaxScaler())
        return bm25_scores_normalized, faiss_scores_normalized

    def _normalize_array(self, scores, scaler):
        """Scale ``scores`` with ``scaler``, or return a uniform 0.5 array.

        ``scaler.fit_transform`` is only used when the scores actually vary.
        Identical scores carry no ranking information and become 0.5 each; an
        empty input yields an empty array.
        """
        scores_array = np.array(scores)
        # np.ptp raises on empty input, so check the size first.
        if scores_array.size and np.ptp(scores_array) > 0:
            return scaler.fit_transform(scores_array.reshape(-1, 1)).flatten()
        return np.full_like(scores_array, 0.5, dtype=float)

    def _calculate_hybrid_scores(self, bm25_scores_normalized, faiss_scores_normalized):
        """Blend both score arrays using ``self.bm25_weight``.

        Returns:
        - np.ndarray: ``w * bm25 + (1 - w) * faiss`` element-wise.
        """
        # asarray lets callers pass plain lists as well as arrays.
        bm25_part = self.bm25_weight * np.asarray(bm25_scores_normalized)
        faiss_part = (1 - self.bm25_weight) * np.asarray(faiss_scores_normalized)
        return bm25_part + faiss_part

    def _get_top_n_results(self, filtered_doc_ids, hybrid_scores, top_n, threshold):
        """Keep documents scoring at least ``threshold`` and return the best ``top_n``.

        Returns:
        - List[Tuple[str, float]]: ``(doc_id, score)`` pairs, highest score
          first; ``[]`` when no document reaches the threshold.
        """
        hybrid_scores = np.array(hybrid_scores)
        threshold_indices = np.where(hybrid_scores >= threshold)[0]
        if len(threshold_indices) == 0:
            self.logger.info("No documents meet the threshold.")
            return []

        # Sort only the surviving candidates, descending by score.
        descending = np.argsort(hybrid_scores[threshold_indices])[::-1]
        top_indices = threshold_indices[descending][:top_n]
        results = [(filtered_doc_ids[idx], hybrid_scores[idx]) for idx in top_indices]
        self.logger.info(f"Top {top_n} results: {results}")
        return results

    def _rerank_results(self, query, results):
        """
        Re-rank the retrieved documents using the configured re-ranker.

        Parameters:
        - query (str): The search query.
        - results (List[Tuple[str, float]]): A list of (doc_id, score) tuples.

        Returns:
        - List[Tuple[str, float]]: Re-ranked list of (doc_id, score) tuples
          using the re-ranker's normalized scores. When no re-ranker is
          configured the input is returned unchanged (as a list), and an
          empty input yields an empty list.
        """
        if not results:
            return []

        reranker = getattr(self, "reranker", None)
        if reranker is None:
            # No model is loaded in __init__; degrade gracefully rather than
            # failing with AttributeError.
            self.logger.warning("No reranker configured; returning results without re-ranking.")
            return list(results)

        doc_ids = [doc_id for doc_id, _ in results]
        document_texts = [self.bm25_search.get_document(doc_id) for doc_id in doc_ids]

        # One [query, document] pair per candidate.
        rerank_inputs = [[query, doc] for doc in document_texts]
        with torch.no_grad():
            rerank_scores = reranker.compute_score(rerank_inputs, normalize=True)

        return sorted(zip(doc_ids, rerank_scores), key=lambda pair: pair[1], reverse=True)