""" | |
This module provides custom implementation of a document retriever, designed for multi-stage retrieval. | |
The system uses ensemble methods combining BM25 and Chroma Embeddings to retrieve relevant documents for a given query. | |
It also utilizes various optimizations like rank fusion and weighted reciprocal rank by Langchain. | |
Classes: | |
-------- | |
- MyEnsembleRetriever: Custom retriever for BM25 and Chroma Embeddings. | |
- MyRetriever: Handles multi-stage retrieval. | |
""" | |
import re
import ast
import copy
import math
import logging
from typing import Dict, List, Optional

from langchain.chains import LLMChain
from langchain.schema import BaseRetriever, Document
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
)

from toolkit.utils import Config, clean_text, DocIndexer, IndexerOperator
from toolkit.prompts import PromptTemplates

prompt_templates = PromptTemplates()
configs = Config("configparser.ini")
logger = logging.getLogger(__name__)


class MyEnsembleRetriever(EnsembleRetriever):
    """
    Custom ensemble retriever over BM25 and Chroma embeddings.
    """

    retrievers: Dict[str, BaseRetriever]

    def rank_fusion(
        self, query: str, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """
        Retrieve results from all retrievers and merge them with
        weighted_reciprocal_rank to produce the final ranking.

        Args:
            query: The query to search for.
            run_manager: Callback manager for the retriever run.

        Returns:
            A list of reranked documents.
        """
        # Get the results of all retrievers.
        retriever_docs = []
        for key, retriever in self.retrievers.items():
            if key == "bm25":
                # BM25 matches lexically, so it is queried with the cleaned text.
                res = retriever.get_relevant_documents(
                    clean_text(query),
                    callbacks=run_manager.get_child(tag=f"retriever_{key}"),
                )
                retriever_docs.append(res)
            else:
                res = retriever.get_relevant_documents(
                    query, callbacks=run_manager.get_child(tag=f"retriever_{key}")
                )
                retriever_docs.append(res)

        # Apply rank fusion.
        fused_documents = self.weighted_reciprocal_rank(retriever_docs)
        return fused_documents

    async def arank_fusion(
        self, query: str, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """
        Asynchronously retrieve results from all retrievers and merge them
        with weighted_reciprocal_rank to produce the final ranking.

        Args:
            query: The query to search for.
            run_manager: Async callback manager for the retriever run.

        Returns:
            A list of reranked documents.
        """
        # Get the results of all retrievers.
        retriever_docs = []
        for key, retriever in self.retrievers.items():
            if key == "bm25":
                # The BM25 retriever is queried synchronously, with the cleaned text.
                res = retriever.get_relevant_documents(
                    clean_text(query),
                    callbacks=run_manager.get_child(tag=f"retriever_{key}"),
                )
                retriever_docs.append(res)
            else:
                res = await retriever.aget_relevant_documents(
                    query, callbacks=run_manager.get_child(tag=f"retriever_{key}")
                )
                retriever_docs.append(res)

        # Apply rank fusion.
        fused_documents = self.weighted_reciprocal_rank(retriever_docs)
        return fused_documents

    def weighted_reciprocal_rank(
        self, doc_lists: List[List[Document]]
    ) -> List[Document]:
        """
        Perform weighted Reciprocal Rank Fusion on multiple rank lists.
        You can find more details about RRF here:
        https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf

        Args:
            doc_lists: A list of rank lists, where each rank list contains unique items.

        Returns:
            list: The final aggregated list of items sorted by their weighted RRF
                scores in descending order.
        """
        if len(doc_lists) != len(self.weights):
            raise ValueError(
                "Number of rank lists must be equal to the number of weights."
            )
        # Replace the cleaned page_content with the original, uncleaned
        # page_content kept in metadata. Note: copy.copy is shallow, so the
        # underlying Document objects are still mutated in place.
        doc_lists_ = copy.copy(doc_lists)
        for doc_list in doc_lists_:
            for doc in doc_list:
                doc.page_content = doc.metadata["page_content"]
        # Create a union of all unique documents in the input doc_lists.
        all_documents = set()
        for doc_list in doc_lists_:
            for doc in doc_list:
                all_documents.add(doc.page_content)

        # Initialize the RRF score dictionary for each document.
        rrf_score_dic = {doc: 0.0 for doc in all_documents}

        # Calculate RRF scores for each document.
        for doc_list, weight in zip(doc_lists_, self.weights):
            for rank, doc in enumerate(doc_list, start=1):
                rrf_score = weight * (1 / (rank + self.c))
                rrf_score_dic[doc.page_content] += rrf_score

        # Sort documents by their RRF scores in descending order.
        sorted_documents = sorted(
            rrf_score_dic.keys(), key=lambda x: rrf_score_dic[x], reverse=True
        )

        # Map the sorted page_content back to the original document objects.
        page_content_to_doc_map = {
            doc.page_content: doc for doc_list in doc_lists_ for doc in doc_list
        }
        sorted_docs = [
            page_content_to_doc_map[page_content] for page_content in sorted_documents
        ]
        return sorted_docs


class MyRetriever:
    """
    Retriever class to handle multi-stage retrieval.
    """

    def __init__(
        self,
        llm,
        embedding_chunks_small,
        embedding_chunks_medium,
        docs_chunks_small: List[Document],
        docs_chunks_medium: List[Document],
        first_retrieval_k: int,
        second_retrieval_k: int,
        num_windows: int,
        retriever_weights: List[float],
    ):
        """
        Initialize the MyRetriever class.

        Args:
            llm: Language model used during retrieval.
            embedding_chunks_small: Chroma vector store built over small chunks.
            embedding_chunks_medium: Chroma vector store built over medium chunks.
            docs_chunks_small (List[Document]): Small document chunks, used to build a DocIndexer.
            docs_chunks_medium (List[Document]): Medium document chunks, used to build a DocIndexer.
            first_retrieval_k (int): Number of top documents to retrieve in the first retrieval.
            second_retrieval_k (int): Number of top documents to retrieve in the second retrieval.
            num_windows (int): Number of overlapping windows to consider.
            retriever_weights (List[float]): Weights for ensemble retrieval.
        """
        self.llm = llm
        self.embedding_chunks_small = embedding_chunks_small
        self.embedding_chunks_medium = embedding_chunks_medium
        self.docs_index_small = DocIndexer(docs_chunks_small)
        self.docs_index_medium = DocIndexer(docs_chunks_medium)
        self.first_retrieval_k = first_retrieval_k
        self.second_retrieval_k = second_retrieval_k
        self.num_windows = num_windows
        self.retriever_weights = retriever_weights

    def get_retriever(
        self,
        docs_chunks,
        emb_chunks,
        emb_filter=None,
        k=2,
        weights=(0.5, 0.5),
    ):
        """
        Initialize and return an ensemble retriever with the specified parameters.

        Args:
            docs_chunks: The document chunks for the BM25 retriever.
            emb_chunks: The Chroma vector store for the embedding retriever.
            emb_filter: An optional metadata filter for the embedding retriever.
            k (int): The number of top documents to return.
            weights (tuple): Weights for ensemble retrieval.

        Returns:
            MyEnsembleRetriever: An instance of MyEnsembleRetriever.
        """
        bm25_retriever = BM25Retriever.from_documents(docs_chunks)
        bm25_retriever.k = k
        # search_type belongs at the as_retriever level, not inside search_kwargs.
        emb_retriever = emb_chunks.as_retriever(
            search_type="mmr",
            search_kwargs={
                "filter": emb_filter,
                "k": k,
            },
        )
        return MyEnsembleRetriever(
            retrievers={"bm25": bm25_retriever, "chroma": emb_retriever},
            weights=weights,
        )
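
    # Illustrative usage sketch (assumes `chroma_store` is an existing Chroma
    # vector store and `chunks` a list of Documents):
    #
    #     retriever = self.get_retriever(
    #         docs_chunks=chunks,
    #         emb_chunks=chroma_store,
    #         emb_filter={"source_md5": "<md5>"},
    #         k=5,
    #         weights=(0.5, 0.5),
    #     )
    #     docs = retriever.get_relevant_documents("what is rank fusion?")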

    def find_overlaps(self, doc: List[Document]):
        """
        Group the documents' large-chunk window intervals by overlap and
        return a centroid index for each group.

        Args:
            doc (List[Document]): Documents whose window intervals to group.

        Returns:
            list: Centroid chunk indices, one or more per group of
                overlapping intervals.
        """
        intervals = []
        for item in doc:
            intervals.append(
                (
                    item.metadata["large_chunks_idx_lower_bound"],
                    item.metadata["large_chunks_idx_upper_bound"],
                )
            )
        remaining_intervals, grouped_intervals, centroids = intervals.copy(), [], []
        while remaining_intervals:
            curr_interval = remaining_intervals.pop(0)
            curr_group = [curr_interval]
            subset_interval = None
            # Compare each remaining interval against the growing group; on the
            # first overlap, absorb it into the group and move to the next one.
            for start, end in remaining_intervals.copy():
                for s, e in curr_group:
                    overlap = set(range(s, e + 1)) & set(range(start, end + 1))
                    if overlap:
                        curr_group.append((start, end))
                        remaining_intervals.remove((start, end))
                        if set(range(start, end + 1)).issubset(
                            set(range(s, e + 1))
                        ):
                            subset_interval = (start, end)
                        break
            if subset_interval:
                centroid = [math.ceil((subset_interval[0] + subset_interval[1]) / 2)]
            elif len(curr_group) > 2:
                first_overlap = max(
                    set(range(curr_group[0][0], curr_group[0][1] + 1))
                    & set(range(curr_group[1][0], curr_group[1][1] + 1))
                )
                last_overlap_set = set(
                    range(curr_group[-1][0], curr_group[-1][1] + 1)
                ) & set(range(curr_group[-2][0], curr_group[-2][1] + 1))
                if not last_overlap_set:
                    last_overlap = first_overlap  # Fallback if no overlap.
                else:
                    last_overlap = min(last_overlap_set)
                step = 1 if first_overlap <= last_overlap else -1
                centroid = list(range(first_overlap, last_overlap + step, step))
            else:
                centroid = [
                    round(
                        sum(math.ceil((s + e) / 2) for s, e in curr_group)
                        / len(curr_group)
                    )
                ]
            grouped_intervals.append(
                curr_group if len(curr_group) > 1 else curr_group[0]
            )
            centroids.extend(centroid)
        return centroids
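
    # Illustrative example for find_overlaps: two documents with window
    # intervals (0, 2) and (1, 3) overlap and form one group; neither is a
    # subset of the other, so the centroid is the rounded mean of the interval
    # midpoints: round((ceil(2 / 2) + ceil(4 / 2)) / 2) = round(1.5) = 2.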

    def get_filter(self, top_k: int, file_md5: str, doc: List[Document]):
        """
        Create filters for the retrievers based on overlapping intervals.

        Args:
            top_k (int): Number of top intervals to consider.
            file_md5 (str): MD5 hash of the file to filter on.
            doc (List[Document]): List of document objects.

        Returns:
            tuple: A tuple containing the dictionary filters for the DocIndexer
                and Chroma retrievers.
        """
        overlaps = self.find_overlaps(doc)
        if len(overlaps) < 1:
            raise ValueError("No overlapping intervals found.")
        overlaps_k = overlaps[:top_k]
        logger.info("windows_at_2nd_retrieval: %s", overlaps_k)

        search_dict_docindexer = {"OR": []}
        search_dict_chroma = {"$or": []}
        for chunk_idx in overlaps_k:
            search_dict_docindexer["OR"].append(
                {
                    "large_chunks_idx_lower_bound": (
                        IndexerOperator.LTE,
                        chunk_idx,
                    ),
                    "large_chunks_idx_upper_bound": (
                        IndexerOperator.GTE,
                        chunk_idx,
                    ),
                    "source_md5": (IndexerOperator.EQ, file_md5),
                }
            )
            # When only one window survives, use a plain $and filter
            # (an $or with a single clause is not a valid Chroma filter).
            if len(overlaps_k) == 1:
                search_dict_chroma = {
                    "$and": [
                        {"large_chunks_idx_lower_bound": {"$lte": overlaps_k[0]}},
                        {"large_chunks_idx_upper_bound": {"$gte": overlaps_k[0]}},
                        {"source_md5": {"$eq": file_md5}},
                    ]
                }
            else:
                search_dict_chroma["$or"].append(
                    {
                        "$and": [
                            {"large_chunks_idx_lower_bound": {"$lte": chunk_idx}},
                            {"large_chunks_idx_upper_bound": {"$gte": chunk_idx}},
                            {"source_md5": {"$eq": file_md5}},
                        ]
                    }
                )
        return search_dict_docindexer, search_dict_chroma
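
    # Illustrative example for get_filter: with overlaps_k == [3, 7] and
    # file_md5 == "abc", the Chroma filter becomes
    #
    #     {"$or": [
    #         {"$and": [{"large_chunks_idx_lower_bound": {"$lte": 3}},
    #                   {"large_chunks_idx_upper_bound": {"$gte": 3}},
    #                   {"source_md5": {"$eq": "abc"}}]},
    #         {"$and": [{"large_chunks_idx_lower_bound": {"$lte": 7}},
    #                   {"large_chunks_idx_upper_bound": {"$gte": 7}},
    #                   {"source_md5": {"$eq": "abc"}}]},
    #     ]}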

    def get_relevant_doc_ids(self, docs: List[Document], query: str):
        """
        Get relevant document IDs for a given query using an LLM.

        Args:
            docs (List[Document]): Candidate documents to select from.
            query (str): The query string.

        Returns:
            list: A list of relevant document IDs (indices into `docs`).
        """
        snippets = "\n\n\n".join(
            [
                f"Context {idx}:\n{{{doc.page_content}}}. {{source: {doc.metadata['source']}}}"
                for idx, doc in enumerate(docs)
            ]
        )
        id_chain = LLMChain(
            llm=self.llm,
            prompt=prompt_templates.get_docs_selection_template(configs.model_name),
            output_key="IDs",
        )
        ids = id_chain.run({"query": query, "snippets": snippets})
        logger.info("relevant doc ids: %s", ids)
        # Extract the first bracketed list of integers from the LLM output.
        pattern = r"\[\s*\d+\s*(?:,\s*\d+\s*)*\]"
        match = re.search(pattern, ids)
        if match:
            return ast.literal_eval(match.group(0))
        return []
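
    # Illustrative example: an LLM reply such as "The relevant contexts are
    # [0, 2]." yields [0, 2]; a reply with no bracketed integer list yields [].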

    def get_relevant_documents(
        self,
        query: str,
        num_query: int,
        *,
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, List[Document]]:
        """
        Perform multi-stage retrieval to get relevant documents.

        Args:
            query (str): The query string.
            num_query (int): Number of queries issued per user request; used to
                budget the number of chunks returned by the third retrieval.
            run_manager (Optional[CallbackManagerForChainRun]): Callback manager
                for the chain run.

        Returns:
            Dict[str, List[Document]]: A mapping from file name to the relevant
                documents retrieved from that file.
        """
        # ! First retrieval
        first_retriever = self.get_retriever(
            docs_chunks=self.docs_index_small.documents,
            emb_chunks=self.embedding_chunks_small,
            emb_filter=None,
            k=self.first_retrieval_k,
            weights=self.retriever_weights,
        )
        first = first_retriever.get_relevant_documents(
            query, callbacks=run_manager.get_child() if run_manager else None
        )
        for doc in first:
            logger.info("----1st retrieval----: %s", doc)
        ids_clean = self.get_relevant_doc_ids(first, query)
        logger.info("relevant cleaned doc ids: %s", ids_clean)
        qa_chunks = {}  # key is file name, value is a list of relevant documents
        if ids_clean and isinstance(ids_clean, list):
            source_md5_dict = {}
            for ids_c in ids_clean:
                if ids_c < len(first):
                    # Group the selected documents by their source file's MD5.
                    md5 = first[ids_c].metadata["source_md5"]
                    if md5 not in source_md5_dict:
                        source_md5_dict[md5] = [first[ids_c]]
                    else:
                        source_md5_dict[md5].append(first[ids_c])
            if len(source_md5_dict) == 0:
                source_md5_dict[first[0].metadata["source_md5"]] = [first[0]]
            num_docs = len(source_md5_dict.keys())
            # Budget the third-stage k so that the total returned context fits
            # within the LLM context window.
            third_num_k = max(
                1,
                int(
                    (
                        configs.max_llm_context
                        / (configs.base_chunk_size * configs.chunk_scale)
                    )
                    // (num_docs * num_query)
                ),
            )
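            # Worked example (illustrative config values): with
            # max_llm_context = 4096, base_chunk_size = 256, chunk_scale = 4,
            # num_docs = 2, and num_query = 1, this gives
            # max(1, int((4096 / 1024) // 2)) = 2 medium chunks per document.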
            for source_md5, docs in source_md5_dict.items():
                logger.info(
                    "selected_docs_at_1st_retrieval: %s", docs[0].metadata["source"]
                )
                second_docs_chunks = self.docs_index_small.retrieve_metadata(
                    {
                        "source_md5": (IndexerOperator.EQ, source_md5),
                    }
                )
                second_retriever = self.get_retriever(
                    docs_chunks=second_docs_chunks,
                    emb_chunks=self.embedding_chunks_small,
                    emb_filter={"source_md5": source_md5},
                    k=self.second_retrieval_k,
                    weights=self.retriever_weights,
                )
                # ! Second retrieval
                second = second_retriever.get_relevant_documents(
                    query, callbacks=run_manager.get_child() if run_manager else None
                )
                for doc in second:
                    logger.info("----2nd retrieval----: %s", doc)
                docs.extend(second)
                docindexer_filter, chroma_filter = self.get_filter(
                    self.num_windows, source_md5, docs
                )
                third_docs_chunks = self.docs_index_medium.retrieve_metadata(
                    docindexer_filter
                )
                third_retriever = self.get_retriever(
                    docs_chunks=third_docs_chunks,
                    emb_chunks=self.embedding_chunks_medium,
                    emb_filter=chroma_filter,
                    k=third_num_k,
                    weights=self.retriever_weights,
                )
                # ! Third retrieval
                third_temp = third_retriever.get_relevant_documents(
                    query, callbacks=run_manager.get_child() if run_manager else None
                )
                third = third_temp[:third_num_k]
                if not third:
                    continue
                for doc in third:
                    logger.info(
                        "----3rd retrieval----page_content: %s", [doc.page_content]
                    )
                    # Drop the duplicated page_content from metadata before returning.
                    mtdata = doc.metadata
                    mtdata["page_content"] = None
                    logger.info("----3rd retrieval----metadata: %s", mtdata)
                file_name = third[0].metadata["source"].split("/")[-1]
                if file_name not in qa_chunks:
                    qa_chunks[file_name] = third
                else:
                    qa_chunks[file_name].extend(third)
        return qa_chunks

    async def aget_relevant_documents(
        self,
        query: str,
        num_query: int,
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> Dict[str, List[Document]]:
        """
        Asynchronous version of get_relevant_documents.

        Args:
            query (str): The query string.
            num_query (int): Number of queries issued per user request; used to
                budget the number of chunks returned by the third retrieval.
            run_manager (AsyncCallbackManagerForChainRun): Callback manager for
                the asynchronous chain run.

        Returns:
            Dict[str, List[Document]]: A mapping from file name to the relevant
                documents retrieved from that file.
        """
        # ! First retrieval
        first_retriever = self.get_retriever(
            docs_chunks=self.docs_index_small.documents,
            emb_chunks=self.embedding_chunks_small,
            emb_filter=None,
            k=self.first_retrieval_k,
            weights=self.retriever_weights,
        )
        first = await first_retriever.aget_relevant_documents(
            query, callbacks=run_manager.get_child()
        )
        for doc in first:
            logger.info("----1st retrieval----: %s", doc)
        ids_clean = self.get_relevant_doc_ids(first, query)
        logger.info("relevant doc ids: %s", ids_clean)

        qa_chunks = {}  # key is file name, value is a list of relevant documents
        if ids_clean and isinstance(ids_clean, list):
            source_md5_dict = {}
            for ids_c in ids_clean:
                if ids_c < len(first):
                    # Group the selected documents by their source file's MD5.
                    md5 = first[ids_c].metadata["source_md5"]
                    if md5 not in source_md5_dict:
                        source_md5_dict[md5] = [first[ids_c]]
                    else:
                        source_md5_dict[md5].append(first[ids_c])
            if len(source_md5_dict) == 0:
                source_md5_dict[first[0].metadata["source_md5"]] = [first[0]]
            num_docs = len(source_md5_dict.keys())
            # Budget the third-stage k so that the total returned context fits
            # within the LLM context window.
            third_num_k = max(
                1,
                int(
                    (
                        configs.max_llm_context
                        / (configs.base_chunk_size * configs.chunk_scale)
                    )
                    // (num_docs * num_query)
                ),
            )
            for source_md5, docs in source_md5_dict.items():
                logger.info(
                    "selected_docs_at_1st_retrieval: %s", docs[0].metadata["source"]
                )
                second_docs_chunks = self.docs_index_small.retrieve_metadata(
                    {
                        "source_md5": (IndexerOperator.EQ, source_md5),
                    }
                )
                second_retriever = self.get_retriever(
                    docs_chunks=second_docs_chunks,
                    emb_chunks=self.embedding_chunks_small,
                    emb_filter={"source_md5": source_md5},
                    k=self.second_retrieval_k,
                    weights=self.retriever_weights,
                )
                # ! Second retrieval
                second = await second_retriever.aget_relevant_documents(
                    query, callbacks=run_manager.get_child()
                )
                for doc in second:
                    logger.info("----2nd retrieval----: %s", doc)
                docs.extend(second)
                docindexer_filter, chroma_filter = self.get_filter(
                    self.num_windows, source_md5, docs
                )
                third_docs_chunks = self.docs_index_medium.retrieve_metadata(
                    docindexer_filter
                )
                third_retriever = self.get_retriever(
                    docs_chunks=third_docs_chunks,
                    emb_chunks=self.embedding_chunks_medium,
                    emb_filter=chroma_filter,
                    k=third_num_k,
                    weights=self.retriever_weights,
                )
                # ! Third retrieval
                third_temp = await third_retriever.aget_relevant_documents(
                    query, callbacks=run_manager.get_child()
                )
                third = third_temp[:third_num_k]
                if not third:
                    continue
                for doc in third:
                    logger.info(
                        "----3rd retrieval----page_content: %s", [doc.page_content]
                    )
                    # Drop the duplicated page_content from metadata before returning.
                    mtdata = doc.metadata
                    mtdata["page_content"] = None
                    logger.info("----3rd retrieval----metadata: %s", mtdata)
                file_name = third[0].metadata["source"].split("/")[-1]
                if file_name not in qa_chunks:
                    qa_chunks[file_name] = third
                else:
                    qa_chunks[file_name].extend(third)
        return qa_chunks
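

# Illustrative usage sketch (the `llm`, vector stores, and chunk lists are
# assumed to be constructed elsewhere in the toolkit; names and k values here
# are hypothetical):
#
#     retriever = MyRetriever(
#         llm=llm,
#         embedding_chunks_small=chroma_small,
#         embedding_chunks_medium=chroma_medium,
#         docs_chunks_small=small_chunks,
#         docs_chunks_medium=medium_chunks,
#         first_retrieval_k=10,
#         second_retrieval_k=5,
#         num_windows=3,
#         retriever_weights=[0.5, 0.5],
#     )
#     qa_chunks = retriever.get_relevant_documents("What is rank fusion?", num_query=1)
#     # qa_chunks maps each file name to its retrieved medium chunks.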