File size: 9,247 Bytes
5ddcfe5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
import asyncio
import os
import time
import traceback
from typing import List, Optional

import logfire
import tiktoken
from cohere import AsyncClient
from dotenv import load_dotenv
from llama_index.core import Document, QueryBundle
from llama_index.core.async_utils import run_async_tasks
from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.retrievers import (
    BaseRetriever,
    KeywordTableSimpleRetriever,
    VectorIndexRetriever,
)
from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle, TextNode
from llama_index.core.vector_stores import (
    FilterCondition,
    FilterOperator,
    MetadataFilter,
    MetadataFilters,
)
from llama_index.postprocessor.cohere_rerank import CohereRerank
from llama_index.postprocessor.cohere_rerank.base import CohereRerank

load_dotenv()


class AsyncCohereRerank(CohereRerank):
    """Cohere reranker with an async node-postprocessing entry point.

    Extends ``CohereRerank`` with ``apostprocess_nodes`` so reranking can be
    awaited from async retrieval pipelines instead of blocking on the parent's
    synchronous client.
    """

    def __init__(
        self,
        top_n: int = 5,
        model: str = "rerank-english-v3.0",
        api_key: Optional[str] = None,
    ) -> None:
        super().__init__(top_n=top_n, model=model, api_key=api_key)
        # Keep private copies: the async path below builds its own AsyncClient
        # and does not go through the parent's sync client.
        self._api_key = api_key
        self._model = model
        self._top_n = top_n

    async def apostprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Rerank ``nodes`` against the query via Cohere's async rerank API.

        Args:
            nodes: Candidate nodes to rerank.
            query_bundle: Query to rank against; required.

        Returns:
            At most ``top_n`` nodes, rescored with Cohere relevance scores.

        Raises:
            ValueError: If ``query_bundle`` is None.
        """
        if query_bundle is None:
            raise ValueError("Query bundle must be provided.")
        if not nodes:
            return []

        async_client = AsyncClient(api_key=self._api_key)

        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self._model,
                EventPayload.QUERY_STR: query_bundle.query_str,
                EventPayload.TOP_K: self._top_n,
            },
        ) as event:
            # Serialize each node the same way it was embedded so the
            # reranker scores the text the index actually stored.
            documents = []
            for candidate in nodes:
                documents.append(
                    candidate.node.get_content(metadata_mode=MetadataMode.EMBED)
                )

            response = await async_client.rerank(
                model=self._model,
                top_n=self._top_n,
                query=query_bundle.query_str,
                documents=documents,
            )

            # Map Cohere's (index, relevance_score) hits back onto the
            # original nodes, replacing their retrieval scores.
            reranked = [
                NodeWithScore(
                    node=nodes[hit.index].node, score=hit.relevance_score
                )
                for hit in response.results
            ]
            event.on_end(payload={EventPayload.NODES: reranked})

        return reranked


class CustomRetriever(BaseRetriever):
    """Custom retriever combining semantic (vector) and keyword search.

    Vector and (optional) keyword results are merged according to ``mode``
    ("AND" intersects node IDs, "OR" unions them), deduplicated per source
    document, reranked with Cohere, then filtered by relevance score and a
    total token budget before being returned.
    """

    # Maximum number of nodes handed to the LLM.
    _TOP_N = 5
    # Reranked nodes scoring below this threshold are discarded.
    _MIN_SCORE = 0.10
    # Combined token budget across all returned node texts.
    _MAX_TOKENS = 100_000

    def __init__(
        self,
        vector_retriever: VectorIndexRetriever,
        document_dict: dict,
        keyword_retriever=None,
        mode: str = "AND",
    ) -> None:
        """Init params.

        Args:
            vector_retriever: Semantic retriever over the vector index.
            document_dict: Maps source-document IDs to full ``Document``
                objects; used to expand a chunk to its whole document when the
                node's ``retrieve_doc`` metadata flag is set.
            keyword_retriever: Optional keyword retriever; when None, only
                vector results are used.
            mode: "AND" to intersect vector/keyword hits, "OR" to union them.

        Raises:
            ValueError: If ``mode`` is neither "AND" nor "OR".
        """
        self._vector_retriever = vector_retriever
        self._document_dict = document_dict
        self._keyword_retriever = keyword_retriever
        if mode not in ("AND", "OR"):
            raise ValueError("Invalid mode.")
        self._mode = mode
        super().__init__()

    async def _process_retrieval(
        self, query_bundle: QueryBundle, is_async: bool = True
    ) -> List[NodeWithScore]:
        """Common processing logic for both sync and async retrieval."""
        # Strip a prompt artifact some callers append before searching.
        query_bundle.query_str = query_bundle.query_str.replace(
            "\ninput is ", ""
        ).rstrip()
        # BUGFIX: log the query string itself, not the QueryBundle repr.
        logfire.info(f"Retrieving nodes with string: '{query_bundle.query_str}'")

        start = time.time()

        # Get nodes from both retrievers.
        if is_async:
            nodes = await self._vector_retriever.aretrieve(query_bundle)
        else:
            nodes = self._vector_retriever.retrieve(query_bundle)

        keyword_nodes = []
        if self._keyword_retriever:
            if is_async:
                keyword_nodes = await self._keyword_retriever.aretrieve(query_bundle)
            else:
                keyword_nodes = self._keyword_retriever.retrieve(query_bundle)

        logfire.info(f"Number of vector nodes: {len(nodes)}")
        logfire.info(f"Number of keyword nodes: {len(keyword_nodes)}")

        # Combine results based on mode.
        vector_ids = {n.node.node_id for n in nodes}
        keyword_ids = {n.node.node_id for n in keyword_nodes}
        combined_dict = {n.node.node_id: n for n in nodes}
        combined_dict.update({n.node.node_id: n for n in keyword_nodes})

        if not self._keyword_retriever or not keyword_nodes:
            # No keyword results to combine with: fall back to vector hits
            # (an "AND" intersection here would wrongly discard everything).
            retrieve_ids = vector_ids
        else:
            retrieve_ids = (
                vector_ids.intersection(keyword_ids)
                if self._mode == "AND"
                else vector_ids.union(keyword_ids)
            )

        nodes = [combined_dict[rid] for rid in retrieve_ids]
        logfire.info(f"Number of combined nodes: {len(nodes)}")

        # Keep at most one node per source document.
        nodes = self._filter_nodes_by_unique_doc_id(nodes)
        logfire.info(f"Number of nodes without duplicate doc IDs: {len(nodes)}")

        # Re-point each node at its source document; when the index flagged the
        # node with ``retrieve_doc``, swap the chunk text for the full document.
        for node in nodes:
            doc_id = node.node.source_node.node_id
            # BUGFIX: .get() so a node missing the flag is left as-is instead
            # of raising KeyError.
            if node.metadata.get("retrieve_doc"):
                doc = self._document_dict[doc_id]
                node.node.text = doc.text
            node.node.node_id = doc_id

        # Rerank; on failure, log and fall back to the unreranked nodes.
        try:
            reranker = (
                AsyncCohereRerank(top_n=self._TOP_N, model="rerank-english-v3.0")
                if is_async
                else CohereRerank(top_n=self._TOP_N, model="rerank-english-v3.0")
            )
            nodes = (
                await reranker.apostprocess_nodes(nodes, query_bundle)
                if is_async
                else reranker.postprocess_nodes(nodes, query_bundle)
            )
        except Exception as e:
            logfire.error(
                f"Error during reranking: {type(e).__name__}: {str(e)}\n"
                f"Traceback:\n{traceback.format_exc()}"
            )

        # Filter by score and token count.
        nodes_filtered = self._filter_by_score_and_tokens(nodes)

        duration = time.time() - start
        logfire.info(f"Retrieving nodes took {duration:.2f}s")
        logfire.info(f"Nodes sent to LLM: {nodes_filtered[:self._TOP_N]}")

        return nodes_filtered[:self._TOP_N]

    def _filter_nodes_by_unique_doc_id(
        self, nodes: List[NodeWithScore]
    ) -> List[NodeWithScore]:
        """Keep only the first node seen for each source-document ID."""
        unique_nodes: dict = {}
        for node in nodes:
            doc_id = node.node.source_node.node_id
            if doc_id is not None and doc_id not in unique_nodes:
                unique_nodes[doc_id] = node
        return list(unique_nodes.values())

    def _filter_by_score_and_tokens(
        self, nodes: List[NodeWithScore]
    ) -> List[NodeWithScore]:
        """Filter nodes by score and token count.

        Drops nodes scored below ``_MIN_SCORE`` and stops once adding the next
        node's text would push the running total past ``_MAX_TOKENS``.
        """
        nodes_filtered: List[NodeWithScore] = []
        total_tokens = 0
        enc = tiktoken.encoding_for_model("gpt-4o-mini")

        for node in nodes:
            # BUGFIX: a node can reach here without a score (e.g. keyword hits
            # when reranking failed); treat a missing score as below threshold
            # instead of raising TypeError on the comparison.
            if node.score is None or node.score < self._MIN_SCORE:
                logfire.info(f"Skipping node with score {node.score}")
                continue

            node_tokens = len(enc.encode(node.node.text))
            if total_tokens + node_tokens > self._MAX_TOKENS:
                logfire.info("Skipping node due to token count exceeding 100k")
                break

            total_tokens += node_tokens
            nodes_filtered.append(node)

        return nodes_filtered

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Async retrieve nodes given query."""
        return await self._process_retrieval(query_bundle, is_async=True)

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Sync retrieve nodes given query."""
        return asyncio.run(self._process_retrieval(query_bundle, is_async=False))