# ai-tutor-chatbot / scripts / custom_retriever.py
import asyncio
import os
import time
import traceback
from typing import List, Optional
import logfire
import tiktoken
from cohere import AsyncClient
from dotenv import load_dotenv
from llama_index.core import Document, QueryBundle
from llama_index.core.async_utils import run_async_tasks
from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.retrievers import (
BaseRetriever,
KeywordTableSimpleRetriever,
VectorIndexRetriever,
)
from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle, TextNode
from llama_index.core.vector_stores import (
FilterCondition,
FilterOperator,
MetadataFilter,
MetadataFilters,
)
from llama_index.postprocessor.cohere_rerank import CohereRerank
from llama_index.postprocessor.cohere_rerank.base import CohereRerank
load_dotenv()
class AsyncCohereRerank(CohereRerank):
    """Cohere reranker with an awaitable ``apostprocess_nodes`` entry point.

    Mirrors ``CohereRerank.postprocess_nodes`` but issues the rerank request
    through ``cohere.AsyncClient`` so callers inside an event loop can await it.
    """

    def __init__(
        self,
        top_n: int = 5,
        model: str = "rerank-english-v3.0",
        api_key: Optional[str] = None,
    ) -> None:
        super().__init__(top_n=top_n, model=model, api_key=api_key)
        # Keep private copies so the async path does not depend on how the
        # parent class stores its configuration internally.
        self._api_key = api_key
        self._model = model
        self._top_n = top_n

    async def apostprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Asynchronously rerank ``nodes`` against the bundled query string.

        Args:
            nodes: Candidate nodes to rerank.
            query_bundle: Query to rank against; required.

        Returns:
            Up to ``top_n`` nodes, rescored with Cohere relevance scores.

        Raises:
            ValueError: If ``query_bundle`` is not provided.
        """
        if query_bundle is None:
            raise ValueError("Query bundle must be provided.")
        if not nodes:
            return []

        client = AsyncClient(api_key=self._api_key)

        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self._model,
                EventPayload.QUERY_STR: query_bundle.query_str,
                EventPayload.TOP_K: self._top_n,
            },
        ) as event:
            # Rerank the same text representation the embedder saw.
            documents = [
                candidate.node.get_content(metadata_mode=MetadataMode.EMBED)
                for candidate in nodes
            ]
            response = await client.rerank(
                model=self._model,
                top_n=self._top_n,
                query=query_bundle.query_str,
                documents=documents,
            )
            # Map Cohere's result indices back onto the original nodes,
            # carrying over the new relevance scores.
            reranked = [
                NodeWithScore(
                    node=nodes[hit.index].node, score=hit.relevance_score
                )
                for hit in response.results
            ]
            event.on_end(payload={EventPayload.NODES: reranked})

        return reranked
class CustomRetriever(BaseRetriever):
    """Custom retriever that performs both semantic search and hybrid search.

    Combines a vector retriever with an optional keyword retriever
    ("AND" = intersection of hits, "OR" = union), deduplicates by source
    document, optionally swaps chunk text for the full document, reranks
    with Cohere, and finally filters by score and total token budget.
    """

    def __init__(
        self,
        vector_retriever: VectorIndexRetriever,
        document_dict: dict,
        keyword_retriever=None,
        mode: str = "AND",
    ) -> None:
        """Init params.

        Args:
            vector_retriever: Retriever used for semantic (embedding) search.
            document_dict: Maps document id -> full document; used to replace
                a chunk's text with the whole document when its
                ``retrieve_doc`` metadata flag is set.
            keyword_retriever: Optional keyword retriever for hybrid search.
            mode: "AND" intersects vector/keyword hits, "OR" unions them.

        Raises:
            ValueError: If ``mode`` is neither "AND" nor "OR".
        """
        self._vector_retriever = vector_retriever
        self._document_dict = document_dict
        self._keyword_retriever = keyword_retriever
        if mode not in ("AND", "OR"):
            raise ValueError("Invalid mode.")
        self._mode = mode
        super().__init__()

    async def _process_retrieval(
        self, query_bundle: QueryBundle, is_async: bool = True
    ) -> List[NodeWithScore]:
        """Common processing logic for both sync and async retrieval.

        Args:
            query_bundle: Query to retrieve for; its ``query_str`` is cleaned
                in place.
            is_async: Use the retrievers' async API when True.

        Returns:
            At most 5 reranked, score- and token-filtered nodes.
        """
        # Strip the prompt-template artifact some callers append to the query.
        query_bundle.query_str = query_bundle.query_str.replace(
            "\ninput is ", ""
        ).rstrip()
        # Bug fix: log the query string itself, not the QueryBundle repr.
        logfire.info(f"Retrieving nodes with string: '{query_bundle.query_str}'")

        start = time.time()

        # Get nodes from both retrievers.
        if is_async:
            nodes = await self._vector_retriever.aretrieve(query_bundle)
        else:
            nodes = self._vector_retriever.retrieve(query_bundle)

        keyword_nodes = []
        if self._keyword_retriever:
            if is_async:
                keyword_nodes = await self._keyword_retriever.aretrieve(query_bundle)
            else:
                keyword_nodes = self._keyword_retriever.retrieve(query_bundle)

        logfire.info(f"Number of vector nodes: {len(nodes)}")
        logfire.info(f"Number of keyword nodes: {len(keyword_nodes)}")

        nodes = self._merge_retrieved_nodes(nodes, keyword_nodes)
        logfire.info(f"Number of combined nodes: {len(nodes)}")

        # Keep a single node per source document.
        nodes = self._filter_nodes_by_unique_doc_id(nodes)
        logfire.info(f"Number of nodes without duplicate doc IDs: {len(nodes)}")

        self._expand_full_documents(nodes)

        # Rerank results; on any failure fall back to the un-reranked nodes.
        try:
            reranker = (
                AsyncCohereRerank(top_n=5, model="rerank-english-v3.0")
                if is_async
                else CohereRerank(top_n=5, model="rerank-english-v3.0")
            )
            nodes = (
                await reranker.apostprocess_nodes(nodes, query_bundle)
                if is_async
                else reranker.postprocess_nodes(nodes, query_bundle)
            )
        except Exception as e:
            error_msg = f"Error during reranking: {type(e).__name__}: {str(e)}\n"
            error_msg += "Traceback:\n"
            error_msg += traceback.format_exc()
            logfire.error(error_msg)

        # Filter by score and token count.
        nodes_filtered = self._filter_by_score_and_tokens(nodes)

        duration = time.time() - start
        logfire.info(f"Retrieving nodes took {duration:.2f}s")
        logfire.info(f"Nodes sent to LLM: {nodes_filtered[:5]}")
        return nodes_filtered[:5]

    def _merge_retrieved_nodes(
        self,
        vector_nodes: List[NodeWithScore],
        keyword_nodes: List[NodeWithScore],
    ) -> List[NodeWithScore]:
        """Combine vector and keyword hits according to ``self._mode``."""
        vector_ids = {n.node.node_id for n in vector_nodes}
        keyword_ids = {n.node.node_id for n in keyword_nodes}
        combined_dict = {n.node.node_id: n for n in vector_nodes}
        combined_dict.update({n.node.node_id: n for n in keyword_nodes})

        # With no keyword retriever (or no keyword hits), "AND" would wrongly
        # empty the result set — fall back to pure vector retrieval.
        if not self._keyword_retriever or not keyword_nodes:
            retrieve_ids = vector_ids
        elif self._mode == "AND":
            retrieve_ids = vector_ids.intersection(keyword_ids)
        else:
            retrieve_ids = vector_ids.union(keyword_ids)
        return [combined_dict[rid] for rid in retrieve_ids]

    def _expand_full_documents(self, nodes: List[NodeWithScore]) -> None:
        """Swap a chunk's text for its full source document when requested.

        Mutates nodes in place when their ``retrieve_doc`` metadata flag is
        truthy.
        """
        for node in nodes:
            source = node.node.source_node
            if source is None:
                # Robustness fix: chunks without a source node used to crash
                # with AttributeError; leave them untouched instead.
                continue
            doc_id = source.node_id
            # Bug fix: a missing "retrieve_doc" key used to raise KeyError;
            # treat absence as "do not expand".
            if node.node.metadata.get("retrieve_doc"):
                doc = self._document_dict[doc_id]
                node.node.text = doc.text
                node.node.node_id = doc_id

    def _filter_nodes_by_unique_doc_id(
        self, nodes: List[NodeWithScore]
    ) -> List[NodeWithScore]:
        """Filter nodes to keep only unique doc IDs.

        Nodes without a resolvable source-document id are dropped.
        """
        unique_nodes = {}
        for node in nodes:
            source = node.node.source_node
            # Robustness fix: a missing source node used to raise
            # AttributeError here.
            doc_id = source.node_id if source is not None else None
            if doc_id is not None and doc_id not in unique_nodes:
                unique_nodes[doc_id] = node
        return list(unique_nodes.values())

    def _filter_by_score_and_tokens(
        self, nodes: List[NodeWithScore]
    ) -> List[NodeWithScore]:
        """Filter nodes by score and token count.

        Drops nodes scoring below 0.10 and stops accumulating once the total
        would exceed 100k tokens (gpt-4o-mini encoding).
        """
        nodes_filtered = []
        total_tokens = 0
        enc = tiktoken.encoding_for_model("gpt-4o-mini")
        for node in nodes:
            # Bug fix: a None score used to raise TypeError on comparison;
            # treat it as below threshold.
            if node.score is None or node.score < 0.10:
                logfire.info(f"Skipping node with score {node.score}")
                continue
            node_tokens = len(enc.encode(node.node.text))
            if total_tokens + node_tokens > 100_000:
                logfire.info("Skipping node due to token count exceeding 100k")
                break
            total_tokens += node_tokens
            nodes_filtered.append(node)
        return nodes_filtered

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Async retrieve nodes given query."""
        return await self._process_retrieval(query_bundle, is_async=True)

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Sync retrieve nodes given query.

        Note: uses ``asyncio.run`` on the shared coroutine, so it must not be
        called from inside a running event loop.
        """
        return asyncio.run(self._process_retrieval(query_bundle, is_async=False))