import sys import os from contextlib import contextmanager from langchain_core.tools import tool from langchain_core.runnables import chain from langchain_core.runnables import RunnableParallel, RunnablePassthrough from langchain_core.runnables import RunnableLambda from ..reranker import rerank_docs, rerank_and_sort_docs # from ...knowledge.retriever import ClimateQARetriever from ...knowledge.openalex import OpenAlexRetriever from .keywords_extraction import make_keywords_extraction_chain from ..utils import log_event from langchain_core.vectorstores import VectorStore from typing import List from langchain_core.documents.base import Document from ..llm import get_llm from .prompts import retrieve_chapter_prompt_template from langchain_core.prompts import ChatPromptTemplate from langchain_core.output_parsers import StrOutputParser from ..vectorstore import get_pinecone_vectorstore from ..embeddings import get_embeddings_function import asyncio from typing import Any, Dict, List, Tuple def divide_into_parts(target, parts): # Base value for each part base = target // parts # Remainder to distribute remainder = target % parts # List to hold the result result = [] for i in range(parts): if i < remainder: # These parts get base value + 1 result.append(base + 1) else: # The rest get the base value result.append(base) return result @contextmanager def suppress_output(): # Open a null device with open(os.devnull, 'w') as devnull: # Store the original stdout and stderr old_stdout = sys.stdout old_stderr = sys.stderr # Redirect stdout and stderr to the null device sys.stdout = devnull sys.stderr = devnull try: yield finally: # Restore stdout and stderr sys.stdout = old_stdout sys.stderr = old_stderr @tool def query_retriever(question): """Just a dummy tool to simulate the retriever query""" return question def _add_sources_used_in_metadata(docs,sources,question,index): for doc in docs: doc.metadata["sources_used"] = sources doc.metadata["question_used"] = question doc.metadata["index_used"] = index return docs def _get_k_summary_by_question(n_questions): if n_questions == 0: return 0 elif n_questions == 1: return 5 elif n_questions == 2: return 3 elif n_questions == 3: return 2 else: return 1 def _get_k_images_by_question(n_questions): if n_questions == 0: return 0 elif n_questions == 1: return 7 elif n_questions == 2: return 5 elif n_questions == 3: return 3 else: return 1 def _add_metadata_and_score(docs: List) -> Document: # Add score to metadata docs_with_metadata = [] for i,(doc,score) in enumerate(docs): doc.page_content = doc.page_content.replace("\r\n"," ") doc.metadata["similarity_score"] = score doc.metadata["content"] = doc.page_content if doc.metadata["page_number"] != "N/A": doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1 else: doc.metadata["page_number"] = 1 # doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}""" docs_with_metadata.append(doc) return docs_with_metadata def remove_duplicates_chunks(docs): # Remove duplicates or almost duplicates docs = sorted(docs,key=lambda x: x[1],reverse=True) seen = set() result = [] for doc in docs: if doc[0].page_content not in seen: seen.add(doc[0].page_content) result.append(doc) return result def get_ToCs(version: str) : filters_text = { "chunk_type":"toc", "version": version } embeddings_function = get_embeddings_function() vectorstore = get_pinecone_vectorstore(embeddings_function, index_name="climateqa-v2") tocs = vectorstore.similarity_search_with_score(query="",filter = filters_text) # remove duplicates or almost duplicates tocs = remove_duplicates_chunks(tocs) return tocs async def get_POC_relevant_documents( query: str, vectorstore:VectorStore, sources:list = ["Acclimaterra","PCAET","Plan Biodiversite"], search_figures:bool = False, search_only:bool = False, k_documents:int = 10, threshold:float = 0.6, k_images: int = 5, reports:list = [], min_size:int = 200, ) : # Prepare base search kwargs filters = {} docs_question = [] docs_images = [] # TODO add source selection # if len(reports) > 0: # filters["short_name"] = {"$in":reports} # else: # filters["source"] = { "$in": sources} filters_text = { **filters, "chunk_type":"text", # "report_type": {}, # TODO to be completed to choose the right documents / chapters according to the analysis of the question } docs_question = vectorstore.similarity_search_with_score(query=query,filter = filters_text,k = k_documents) # remove duplicates or almost duplicates docs_question = remove_duplicates_chunks(docs_question) docs_question = [x for x in docs_question if x[1] > threshold] if search_figures: # Images filters_image = { **filters, "chunk_type":"image" } docs_images = vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_images) docs_question, docs_images = _add_metadata_and_score(docs_question), _add_metadata_and_score(docs_images) docs_question = [x for x in docs_question if len(x.page_content) > min_size] return { "docs_question" : docs_question, "docs_images" : docs_images } async def get_POC_documents_by_ToC_relevant_documents( query: str, tocs: list, vectorstore:VectorStore, version: str, sources:list = ["Acclimaterra","PCAET","Plan Biodiversite"], search_figures:bool = False, search_only:bool = False, k_documents:int = 10, threshold:float = 0.6, k_images: int = 5, reports:list = [], min_size:int = 200, proportion: float = 0.5, ) : """ Args: - tocs : list with the table of contents of each document - version : version of the parsed documents (e.g. "v4") - proportion : share of documents retrieved using ToCs """ # Prepare base search kwargs filters = {} docs_question = [] docs_images = [] # TODO add source selection # if len(reports) > 0: # filters["short_name"] = {"$in":reports} # else: # filters["source"] = { "$in": sources} k_documents_toc = round(k_documents * proportion) relevant_tocs = await get_relevant_toc_level_for_query(query, tocs) print(f"Relevant ToCs : {relevant_tocs}") # Transform the ToC dict {"document": str, "chapter": str} into a list of string toc_filters = [toc['chapter'] for toc in relevant_tocs] filters_text_toc = { **filters, "chunk_type":"text", "toc_level0": {"$in": toc_filters}, "version": version # "report_type": {}, # TODO to be completed to choose the right documents / chapters according to the analysis of the question } docs_question = vectorstore.similarity_search_with_score(query=query,filter = filters_text_toc,k = k_documents_toc) filters_text = { **filters, "chunk_type":"text", "version": version # "report_type": {}, # TODO to be completed to choose the right documents / chapters according to the analysis of the question } docs_question += vectorstore.similarity_search_with_score(query=query,filter = filters_text,k = k_documents - k_documents_toc) # remove duplicates or almost duplicates docs_question = remove_duplicates_chunks(docs_question) docs_question = [x for x in docs_question if x[1] > threshold] if search_figures: # Images filters_image = { **filters, "chunk_type":"image" } docs_images = vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_images) docs_question, docs_images = _add_metadata_and_score(docs_question), _add_metadata_and_score(docs_images) docs_question = [x for x in docs_question if len(x.page_content) > min_size] return { "docs_question" : docs_question, "docs_images" : docs_images } async def get_IPCC_relevant_documents( query: str, vectorstore:VectorStore, sources:list = ["IPCC","IPBES","IPOS"], search_figures:bool = False, reports:list = [], threshold:float = 0.6, k_summary:int = 3, k_total:int = 10, k_images: int = 5, namespace:str = "vectors", min_size:int = 200, search_only:bool = False, ) : # Check if all elements in the list are either IPCC or IPBES assert isinstance(sources,list) assert sources assert all([x in ["IPCC","IPBES","IPOS"] for x in sources]) assert k_total > k_summary, "k_total should be greater than k_summary" # Prepare base search kwargs filters = {} if len(reports) > 0: filters["short_name"] = {"$in":reports} else: filters["source"] = { "$in": sources} # INIT docs_summaries = [] docs_full = [] docs_images = [] if search_only: # Only search for images if search_only is True if search_figures: filters_image = { **filters, "chunk_type":"image" } docs_images = vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_images) docs_images = _add_metadata_and_score(docs_images) else: # Regular search flow for text and optionally images # Search for k_summary documents in the summaries dataset filters_summaries = { **filters, "chunk_type":"text", "report_type": { "$in":["SPM"]}, } docs_summaries = vectorstore.similarity_search_with_score(query=query,filter = filters_summaries,k = k_summary) docs_summaries = [x for x in docs_summaries if x[1] > threshold] # Search for k_total - k_summary documents in the full reports dataset filters_full = { **filters, "chunk_type":"text", "report_type": { "$nin":["SPM"]}, } docs_full = vectorstore.similarity_search_with_score(query=query,filter = filters_full,k = k_total) if search_figures: # Images filters_image = { **filters, "chunk_type":"image" } docs_images = vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_images) docs_summaries, docs_full, docs_images = _add_metadata_and_score(docs_summaries), _add_metadata_and_score(docs_full), _add_metadata_and_score(docs_images) # Filter if length are below threshold docs_summaries = [x for x in docs_summaries if len(x.page_content) > min_size] docs_full = [x for x in docs_full if len(x.page_content) > min_size] return { "docs_summaries" : docs_summaries, "docs_full" : docs_full, "docs_images" : docs_images, } def concatenate_documents(index, source_type, docs_question_dict, k_by_question, k_summary_by_question, k_images_by_question): # Keep the right number of documents - The k_summary documents from SPM are placed in front if source_type == "IPx": docs_question = docs_question_dict["docs_summaries"][:k_summary_by_question] + docs_question_dict["docs_full"][:(k_by_question - k_summary_by_question)] elif source_type == "POC" : docs_question = docs_question_dict["docs_question"][:k_by_question] else : raise ValueError("source_type should be either Vector or POC") # docs_question = [doc for key in docs_question_dict.keys() for doc in docs_question_dict[key]][:(k_by_question)] images_question = docs_question_dict["docs_images"][:k_images_by_question] return docs_question, images_question # The chain callback is not necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results # @chain async def retrieve_documents( current_question: Dict[str, Any], config: Dict[str, Any], source_type: str, vectorstore: VectorStore, reranker: Any, version: str = "", search_figures: bool = False, search_only: bool = False, reports: list = [], rerank_by_question: bool = True, k_images_by_question: int = 5, k_before_reranking: int = 100, k_by_question: int = 5, k_summary_by_question: int = 3, tocs: list = [], by_toc=False ) -> Tuple[List[Document], List[Document]]: """ Unpack the first question of the remaining questions, and retrieve and rerank corresponding documents, based on the question and selected_sources Args: state (dict): The current state containing documents, related content, relevant content sources, remaining questions and n_questions. current_question (dict): The current question being processed. config (dict): Configuration settings for logging and other purposes. vectorstore (object): The vector store used to retrieve relevant documents. reranker (object): The reranker used to rerank the retrieved documents. llm (object): The language model used for processing. rerank_by_question (bool, optional): Whether to rerank documents by question. Defaults to True. k_final (int, optional): The final number of documents to retrieve. Defaults to 15. k_before_reranking (int, optional): The number of documents to retrieve before reranking. Defaults to 100. k_summary (int, optional): The number of summary documents to retrieve. Defaults to 5. k_images (int, optional): The number of image documents to retrieve. Defaults to 5. Returns: dict: The updated state containing the retrieved and reranked documents, related content, and remaining questions. """ sources = current_question["sources"] question = current_question["question"] index = current_question["index"] source_type = current_question["source_type"] print(f"Retrieve documents for question: {question}") await log_event({"question":question,"sources":sources,"index":index},"log_retriever",config) print(f"""---- Retrieve documents from {current_question["source_type"]}----""") if source_type == "IPx": docs_question_dict = await get_IPCC_relevant_documents( query = question, vectorstore=vectorstore, search_figures = search_figures, sources = sources, min_size = 200, k_summary = k_before_reranking-1, k_total = k_before_reranking, k_images = k_images_by_question, threshold = 0.5, search_only = search_only, reports = reports, ) if source_type == 'POC': if by_toc == True: print("---- Retrieve documents by ToC----") docs_question_dict = await get_POC_documents_by_ToC_relevant_documents( query=question, tocs = tocs, vectorstore=vectorstore, version=version, search_figures = search_figures, sources = sources, threshold = 0.5, search_only = search_only, reports = reports, min_size= 200, k_documents= k_before_reranking, k_images= k_by_question ) else : docs_question_dict = await get_POC_relevant_documents( query = question, vectorstore=vectorstore, search_figures = search_figures, sources = sources, threshold = 0.5, search_only = search_only, reports = reports, min_size= 200, k_documents= k_before_reranking, k_images= k_by_question ) # Rerank if reranker is not None and rerank_by_question: with suppress_output(): for key in docs_question_dict.keys(): docs_question_dict[key] = rerank_and_sort_docs(reranker,docs_question_dict[key],question) else: # Add a default reranking score for doc in docs_question: doc.metadata["reranking_score"] = doc.metadata["similarity_score"] # Keep the right number of documents docs_question, images_question = concatenate_documents(index, source_type, docs_question_dict, k_by_question, k_summary_by_question, k_images_by_question) # Rerank the documents to put the most relevant in front if reranker is not None and rerank_by_question: docs_question = rerank_and_sort_docs(reranker, docs_question, question) # Add sources used in the metadata docs_question = _add_sources_used_in_metadata(docs_question,sources,question,index) images_question = _add_sources_used_in_metadata(images_question,sources,question,index) return docs_question, images_question async def retrieve_documents_for_all_questions( search_figures, search_only, reports, questions_list, n_questions, config, source_type, to_handle_questions_index, vectorstore, reranker, rerank_by_question=True, k_final=15, k_before_reranking=100, version: str = "", tocs: list[dict] = [], by_toc: bool = False ): """ Retrieve documents in parallel for all questions. """ # to_handle_questions_index = [x for x in state["questions_list"] if x["source_type"] == "IPx"] # TODO split les questions selon le type de sources dans le state question + conditions sur le nombre de questions traités par type de source # search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"] # search_only = state["search_only"] # reports = state["reports"] # questions_list = state["questions_list"] # k_by_question = k_final // state["n_questions"]["total"] # k_summary_by_question = _get_k_summary_by_question(state["n_questions"]["total"]) # k_images_by_question = _get_k_images_by_question(state["n_questions"]["total"]) k_by_question = k_final // n_questions k_summary_by_question = _get_k_summary_by_question(n_questions) k_images_by_question = _get_k_images_by_question(n_questions) k_before_reranking=100 print(f"Source type here is {source_type}") tasks = [ retrieve_documents( current_question=question, config=config, source_type=source_type, vectorstore=vectorstore, reranker=reranker, search_figures=search_figures, search_only=search_only, reports=reports, rerank_by_question=rerank_by_question, k_images_by_question=k_images_by_question, k_before_reranking=k_before_reranking, k_by_question=k_by_question, k_summary_by_question=k_summary_by_question, tocs=tocs, version=version, by_toc=by_toc ) for i, question in enumerate(questions_list) if i in to_handle_questions_index ] results = await asyncio.gather(*tasks) # Combine results new_state = {"documents": [], "related_contents": [], "handled_questions_index": to_handle_questions_index} for docs_question, images_question in results: new_state["documents"].extend(docs_question) new_state["related_contents"].extend(images_question) return new_state # ToC Retriever async def get_relevant_toc_level_for_query( query: str, tocs: list[Document], ) -> list[dict] : doc_list = [] for doc in tocs: doc_name = doc[0].metadata['name'] toc = doc[0].page_content doc_list.append({'document': doc_name, 'toc': toc}) llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0) prompt = ChatPromptTemplate.from_template(retrieve_chapter_prompt_template) chain = prompt | llm | StrOutputParser() response = chain.invoke({"query": query, "doc_list": doc_list}) try: relevant_tocs = eval(response) except Exception as e: print(f" Failed to parse the result because of : {e}") return relevant_tocs def make_IPx_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5): async def retrieve_IPx_docs(state, config): source_type = "IPx" IPx_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "IPx"] search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"] search_only = state["search_only"] reports = state["reports"] questions_list = state["questions_list"] n_questions=state["n_questions"]["total"] state = await retrieve_documents_for_all_questions( search_figures=search_figures, search_only=search_only, reports=reports, questions_list=questions_list, n_questions=n_questions, config=config, source_type=source_type, to_handle_questions_index=IPx_questions_index, vectorstore=vectorstore, reranker=reranker, rerank_by_question=rerank_by_question, k_final=k_final, k_before_reranking=k_before_reranking, ) return state return retrieve_IPx_docs def make_POC_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5): async def retrieve_POC_docs_node(state, config): if "POC region" not in state["relevant_content_sources_selection"] : return {} source_type = "POC" POC_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"] search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"] search_only = state["search_only"] reports = state["reports"] questions_list = state["questions_list"] n_questions=state["n_questions"]["total"] state = await retrieve_documents_for_all_questions( search_figures=search_figures, search_only=search_only, reports=reports, questions_list=questions_list, n_questions=n_questions, config=config, source_type=source_type, to_handle_questions_index=POC_questions_index, vectorstore=vectorstore, reranker=reranker, rerank_by_question=rerank_by_question, k_final=k_final, k_before_reranking=k_before_reranking, ) return state return retrieve_POC_docs_node def make_POC_by_ToC_retriever_node( vectorstore: VectorStore, reranker, llm, version: str = "", rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5, ): async def retrieve_POC_docs_node(state, config): if "POC region" not in state["relevant_content_sources_selection"] : return {} search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"] search_only = state["search_only"] search_only = state["search_only"] reports = state["reports"] questions_list = state["questions_list"] n_questions=state["n_questions"]["total"] tocs = get_ToCs(version=version) source_type = "POC" POC_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"] state = await retrieve_documents_for_all_questions( search_figures=search_figures, search_only=search_only, config=config, reports=reports, questions_list=questions_list, n_questions=n_questions, source_type=source_type, to_handle_questions_index=POC_questions_index, vectorstore=vectorstore, reranker=reranker, rerank_by_question=rerank_by_question, k_final=k_final, k_before_reranking=k_before_reranking, tocs=tocs, version=version, by_toc=True ) return state return retrieve_POC_docs_node