import os
from typing import Any, Dict, List, Optional

import hnswlib
import numpy as np
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer

from src.config import RetrieverConfig
from src.utils import logger


class Retriever:
    """
    Hybrid retriever combining BM25 sparse and dense retrieval (no Redis).

    Construction is best-effort: if any component fails to build, the error
    is logged, ``bm25``/``embedder``/``ann`` are left as ``None`` and every
    retrieval method returns an empty list instead of raising.
    """

    def __init__(self, chunks: List[Dict[str, Any]], config: "RetrieverConfig"):
        """
        Initialize the retriever with chunks and configuration.

        Args:
            chunks (List[Dict[str, Any]]): List of chunks, where each chunk is
                a dictionary. The text indexed for a chunk is its
                ``'narration'`` value (missing or ``None`` counts as empty).
            config (RetrieverConfig): Configuration for the retriever. Reads
                ``DENSE_MODEL`` and ``ANN_TOP`` here; ``TOP_K`` is read later
                by ``retrieve``.
        """
        self.chunks = chunks
        # Kept so retrieve() can honour this instance's TOP_K.
        self.config = config
        self.bm25 = None
        self.embedder = None
        self.ann = None

        try:
            if not isinstance(chunks, list) or not all(isinstance(c, dict) for c in chunks):
                raise ValueError("Chunks must be a list of dicts.")

            texts = [self._chunk_text(c) for c in chunks]
            bm25 = BM25Okapi([t.split() for t in texts])

            embedder = SentenceTransformer(config.DENSE_MODEL)
            embeddings = np.asarray(embedder.encode(texts))

            # The index dimension comes from the corpus embeddings themselves,
            # so no separate probe encode is needed.
            ann = hnswlib.Index(space='cosine', dim=embeddings.shape[1])
            ann.init_index(max_elements=len(chunks))
            # Labels are positions in self.chunks, so query results map back
            # to chunks by plain indexing.
            ann.add_items(embeddings, ids=list(range(len(chunks))))
            ann.set_ef(config.ANN_TOP)
        except Exception as e:
            # Best-effort construction: log and leave every component None.
            logger.error(f"Retriever init failed: {e}")
            return

        # Publish only after everything succeeded, so the instance is never
        # left half-initialized.
        self.bm25 = bm25
        self.embedder = embedder
        self.ann = ann

    @staticmethod
    def _chunk_text(chunk: Dict[str, Any]) -> str:
        """Return the chunk's narration as a string ('' if missing or None)."""
        text = chunk.get('narration')
        return '' if text is None else str(text)

    def retrieve_sparse(self, query: str, top_k: int) -> List[Dict[str, Any]]:
        """
        Retrieve chunks using BM25 sparse retrieval.

        Args:
            query (str): Query string; tokenized by whitespace.
            top_k (int): Number of top chunks to return.

        Returns:
            List[Dict[str, Any]]: Up to ``top_k`` chunks in descending BM25
            score order. Empty if BM25 is not initialized, ``top_k`` is not
            positive, or scoring fails.
        """
        if self.bm25 is None:
            logger.error("BM25 not initialized.")
            return []
        if top_k <= 0:
            # A negative top_k would otherwise slice "all but the last |k|".
            return []

        tokens = query.split()
        try:
            scores = self.bm25.get_scores(tokens)
            ranked = np.argsort(scores)[::-1][:top_k]
            return [self.chunks[int(i)] for i in ranked]
        except Exception as e:
            logger.error(f"Sparse retrieval failed: {e}")
            return []

    def retrieve_dense(self, query: str, top_k: int) -> List[Dict[str, Any]]:
        """
        Retrieve chunks using dense retrieval.

        Args:
            query (str): Query string.
            top_k (int): Number of top chunks to return.

        Returns:
            List[Dict[str, Any]]: Up to ``top_k`` chunks in the order returned
            by the ANN index. Empty if the dense components are not
            initialized, ``top_k`` is not positive, or the query fails.
        """
        if self.ann is None or self.embedder is None:
            logger.error("Dense retriever not initialized.")
            return []

        # hnswlib raises when asked for more neighbours than it holds, so cap
        # k at the number of indexed chunks.
        k = min(top_k, len(self.chunks))
        if k <= 0:
            return []

        try:
            q_emb = self.embedder.encode([query])[0]
            labels, _distances = self.ann.knn_query(q_emb, k=k)
            return [self.chunks[int(i)] for i in labels[0]]
        except Exception as e:
            logger.error(f"Dense retrieval failed: {e}")
            return []

    def retrieve(self, query: str, top_k: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Retrieve chunks using hybrid retrieval.

        Args:
            query (str): Query string.
            top_k (int, optional): Number of chunks requested from EACH of the
                sparse and dense retrievers. Defaults to the ``TOP_K`` of the
                config given at construction, falling back to
                ``RetrieverConfig.TOP_K``.

        Returns:
            List[Dict[str, Any]]: Sparse results followed by dense results,
            de-duplicated by object identity, so the list can hold up to
            ``2 * top_k`` chunks.
        """
        if top_k is None:
            cfg = getattr(self, "config", None)
            # Prefer the instance's config over the class-level default.
            top_k = cfg.TOP_K if hasattr(cfg, "TOP_K") else RetrieverConfig.TOP_K

        sparse = self.retrieve_sparse(query, top_k)
        dense = self.retrieve_dense(query, top_k)

        # Both lists hold references into self.chunks, so identity tells us
        # whether the same chunk was returned by both retrievers.
        seen = set()
        combined = []
        for chunk in sparse + dense:
            key = id(chunk)
            if key not in seen:
                seen.add(key)
                combined.append(chunk)
        return combined