File size: 3,931 Bytes
3301b3c
 
 
 
 
 
 
 
33f4e34
04db7e0
33f4e34
3301b3c
 
 
33f4e34
3301b3c
 
33f4e34
 
3301b3c
33f4e34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3301b3c
 
33f4e34
 
 
 
 
 
 
 
 
 
 
 
 
3301b3c
33f4e34
 
 
 
 
 
 
3301b3c
 
33f4e34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3301b3c
33f4e34
 
 
 
 
3301b3c
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import numpy as np
import hnswlib
from typing import List, Dict, Any

from sentence_transformers import SentenceTransformer
from rank_bm25 import BM25Okapi

from src.config import RetrieverConfig
from src.utils import logger


class Retriever:
    """
    Hybrid retriever combining BM25 sparse and dense retrieval (no Redis).

    Sparse scores come from ``BM25Okapi`` over whitespace-tokenized
    ``'narration'`` text; dense neighbours come from a SentenceTransformer
    embedding of the same text, indexed in an hnswlib cosine index.

    Initialization is best-effort: if anything fails, ``bm25``, ``embedder``
    and ``ann`` are all left as ``None`` and the retrieval methods log an
    error and return ``[]``.
    """
    def __init__(self, chunks: List[Dict[str, Any]], config: "RetrieverConfig"):
        """
        Initialize the retriever with chunks and configuration.

        Args:
        chunks (List[Dict[str, Any]]): List of chunks, where each chunk is a dictionary.
            Both retrievers index the chunk's ``'narration'`` value; a missing
            or ``None`` narration is indexed as empty text.
        config (RetrieverConfig): Configuration for the retriever. Reads
            ``DENSE_MODEL`` (SentenceTransformer model name), ``ANN_TOP``
            (initial hnswlib ``ef``) and, in ``retrieve``, ``TOP_K``.
        """
        self.chunks = chunks
        self.config = config
        self.bm25 = None
        self.embedder = None
        self.ann = None

        if not isinstance(chunks, list) or not all(isinstance(c, dict) for c in chunks):
            logger.error("Chunks must be a list of dicts.")
            return
        if not chunks:
            # BM25Okapi divides by the corpus size, so an empty corpus cannot be indexed.
            logger.error("Retriever init failed: no chunks to index.")
            return

        try:
            texts = [c.get('narration') or '' for c in chunks]
            bm25 = BM25Okapi([text.split() for text in texts])
            embedder = SentenceTransformer(config.DENSE_MODEL)
            embeddings = np.asarray(embedder.encode(texts))
            # Take the dimension from the real embeddings rather than a throwaway encode.
            ann = hnswlib.Index(space='cosine', dim=embeddings.shape[1])
            ann.init_index(max_elements=len(chunks))
            # ids are positions in self.chunks, so query labels index back into it.
            ann.add_items(embeddings, ids=np.arange(len(chunks)))
            ann.set_ef(config.ANN_TOP)
        except Exception as e:
            logger.error(f"Retriever init failed: {e}")
            return

        # Publish components only once everything succeeded (all-or-nothing).
        self.bm25 = bm25
        self.embedder = embedder
        self.ann = ann

    def retrieve_sparse(self, query: str, top_k: int) -> List[Dict[str, Any]]:
        """
        Retrieve chunks using BM25 sparse retrieval.

        Args:
        query (str): Query string (whitespace-tokenized, same as the corpus).
        top_k (int): Number of top chunks to return; ``top_k <= 0`` yields ``[]``.

        Returns:
        List[Dict[str, Any]]: Up to ``top_k`` chunks, highest BM25 score first.
            Returns ``[]`` if BM25 is not initialized or scoring fails.
        """
        if self.bm25 is None:
            logger.error("BM25 not initialized.")
            return []
        if top_k <= 0:
            # A negative slice bound would otherwise return almost every chunk.
            return []
        try:
            scores = self.bm25.get_scores(query.split())
            # argsort is ascending; reverse so the best scores come first.
            top_indices = np.argsort(scores)[::-1][:top_k]
            return [self.chunks[int(i)] for i in top_indices]
        except Exception as e:
            logger.error(f"Sparse retrieval failed: {e}")
            return []

    def retrieve_dense(self, query: str, top_k: int) -> List[Dict[str, Any]]:
        """
        Retrieve chunks using dense retrieval.

        Args:
        query (str): Query string.
        top_k (int): Number of top chunks to return. Values larger than the
            number of indexed chunks are clamped; ``top_k <= 0`` yields ``[]``.

        Returns:
        List[Dict[str, Any]]: Up to ``top_k`` nearest chunks by cosine
            distance, nearest first. Returns ``[]`` if the dense components are
            not initialized or the query fails.
        """
        if self.ann is None or self.embedder is None:
            logger.error("Dense retriever not initialized.")
            return []
        try:
            # hnswlib raises if k exceeds the number of stored elements.
            k = min(top_k, self.ann.get_current_count())
            if k <= 0:
                return []
            # hnswlib also needs ef >= k. Raising ef persists for later queries.
            if self.ann.ef < k:
                self.ann.set_ef(k)
            q_emb = self.embedder.encode([query])[0]
            labels, _distances = self.ann.knn_query(q_emb, k=k)
            return [self.chunks[int(i)] for i in labels[0]]
        except Exception as e:
            logger.error(f"Dense retrieval failed: {e}")
            return []

    def retrieve(self, query: str, top_k: int = None) -> List[Dict[str, Any]]:
        """
        Retrieve chunks using hybrid retrieval.

        Runs sparse and dense retrieval with the same ``top_k``. It returns their
        union with duplicates removed: sparse hits first, then dense hits not
        already present.

        Args:
        query (str): Query string.
        top_k (int, optional): Number of chunks requested from EACH retriever.
            Defaults to ``TOP_K`` of the config passed at construction.

        Returns:
        List[Dict[str, Any]]: Up to ``2 * top_k`` unique chunks.
        """
        if top_k is None:
            config = getattr(self, "config", None)
            if config is None:
                config = RetrieverConfig
            top_k = config.TOP_K
        sparse = self.retrieve_sparse(query, top_k)
        dense = self.retrieve_dense(query, top_k)
        seen = set()
        combined = []
        for chunk in sparse + dense:
            # Both result lists reference objects in self.chunks, so identity is a
            # valid (and hashable) dedup key for otherwise unhashable dicts.
            key = id(chunk)
            if key not in seen:
                seen.add(key)
                combined.append(chunk)
        return combined