import ast
import asyncio
import os
import sys
from contextlib import contextmanager
from typing import Any, Dict, List, Tuple

from langchain_core.documents.base import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import (
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
    chain,
)
from langchain_core.tools import tool
from langchain_core.vectorstores import VectorStore

from ...knowledge.openalex import OpenAlexRetriever
from ..embeddings import get_embeddings_function
from ..llm import get_llm
from ..reranker import rerank_docs, rerank_and_sort_docs
from ..utils import log_event
from ..vectorstore import get_pinecone_vectorstore
from .keywords_extraction import make_keywords_extraction_chain
from .prompts import retrieve_chapter_prompt_template


def divide_into_parts(target: int, parts: int) -> List[int]:
    """Split `target` into `parts` integers that sum to `target`, as evenly as possible."""
    base = target // parts
    remainder = target % parts
    result = []
    for i in range(parts):
        if i < remainder:
            # The first `remainder` parts each take one extra unit.
            result.append(base + 1)
        else:
            result.append(base)
    return result
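
# Illustrative example (not executed): splitting 10 slots across 3 questions
# gives divide_into_parts(10, 3) == [4, 3, 3].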


@contextmanager
def suppress_output():
    """Temporarily redirect stdout and stderr to os.devnull."""
    with open(os.devnull, "w") as devnull:
        old_stdout = sys.stdout
        old_stderr = sys.stderr
        sys.stdout = devnull
        sys.stderr = devnull
        try:
            yield
        finally:
            # Always restore the original streams, even if the body raises.
            sys.stdout = old_stdout
            sys.stderr = old_stderr
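
# Minimal usage sketch (illustrative; `noisy_call` is a hypothetical chatty function):
# with suppress_output():
#     noisy_call()  # anything printed to stdout/stderr is discarded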


@tool
def query_retriever(question):
    """Dummy tool that simulates a retriever query."""
    return question


def _add_sources_used_in_metadata(docs, sources, question, index):
    for doc in docs:
        doc.metadata["sources_used"] = sources
        doc.metadata["question_used"] = question
        doc.metadata["index_used"] = index
    return docs


def _get_k_summary_by_question(n_questions):
    # Fewer summary chunks per question as the number of questions grows.
    if n_questions == 0:
        return 0
    elif n_questions == 1:
        return 5
    elif n_questions == 2:
        return 3
    elif n_questions == 3:
        return 2
    else:
        return 1


def _get_k_images_by_question(n_questions):
    # Fewer image chunks per question as the number of questions grows.
    if n_questions == 0:
        return 0
    elif n_questions == 1:
        return 7
    elif n_questions == 2:
        return 5
    elif n_questions == 3:
        return 3
    else:
        return 1
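
# Illustrative allocations from the two mappings above: with 2 questions each one
# gets 3 summary chunks and 5 images; with 4 or more questions, 1 of each.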


def _add_metadata_and_score(docs: List) -> List[Document]:
    docs_with_metadata = []
    for doc, score in docs:
        doc.page_content = doc.page_content.replace("\r\n", " ")
        doc.metadata["similarity_score"] = score
        doc.metadata["content"] = doc.page_content
        if doc.metadata["page_number"] != "N/A":
            # Shift the stored page number by one (presumably zero-based in the index).
            doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
        else:
            doc.metadata["page_number"] = 1
        docs_with_metadata.append(doc)
    return docs_with_metadata


def remove_duplicates_chunks(docs):
    # Sort by score (descending) so the highest-scoring duplicate is kept.
    docs = sorted(docs, key=lambda x: x[1], reverse=True)
    seen = set()
    result = []
    for doc in docs:
        if doc[0].page_content not in seen:
            seen.add(doc[0].page_content)
            result.append(doc)
    return result
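
# Illustrative behaviour (not executed): given (Document, score) tuples where two
# documents share the same page_content, only the higher-scoring one survives.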


def get_ToCs(version: str) -> list:
    """Fetch the table-of-contents chunks for the given parsed-document version."""
    filters_text = {
        "chunk_type": "toc",
        "version": version,
    }
    embeddings_function = get_embeddings_function()
    vectorstore = get_pinecone_vectorstore(embeddings_function, index_name="climateqa-v2")
    tocs = vectorstore.similarity_search_with_score(query="", filter=filters_text)

    # Remove duplicated ToC chunks, keeping the highest-scoring copy.
    tocs = remove_duplicates_chunks(tocs)

    return tocs


async def get_POC_relevant_documents(
    query: str,
    vectorstore: VectorStore,
    sources: list = ["Acclimaterra", "PCAET", "Plan Biodiversite"],
    search_figures: bool = False,
    search_only: bool = False,
    k_documents: int = 10,
    threshold: float = 0.6,
    k_images: int = 5,
    reports: list = [],
    min_size: int = 200,
):
    # `sources`, `search_only` and `reports` are accepted for interface parity
    # with the other retrievers but are currently unused in the filters below.
    filters = {}
    docs_question = []
    docs_images = []

    filters_text = {
        **filters,
        "chunk_type": "text",
    }

    docs_question = vectorstore.similarity_search_with_score(query=query, filter=filters_text, k=k_documents)

    docs_question = remove_duplicates_chunks(docs_question)
    docs_question = [x for x in docs_question if x[1] > threshold]

    if search_figures:
        filters_image = {
            **filters,
            "chunk_type": "image",
        }
        docs_images = vectorstore.similarity_search_with_score(query=query, filter=filters_image, k=k_images)

    docs_question, docs_images = _add_metadata_and_score(docs_question), _add_metadata_and_score(docs_images)

    # Drop chunks that are too short to be informative.
    docs_question = [x for x in docs_question if len(x.page_content) > min_size]

    return {
        "docs_question": docs_question,
        "docs_images": docs_images,
    }


async def get_POC_documents_by_ToC_relevant_documents(
    query: str,
    tocs: list,
    vectorstore: VectorStore,
    version: str,
    sources: list = ["Acclimaterra", "PCAET", "Plan Biodiversite"],
    search_figures: bool = False,
    search_only: bool = False,
    k_documents: int = 10,
    threshold: float = 0.6,
    k_images: int = 5,
    reports: list = [],
    min_size: int = 200,
    proportion: float = 0.5,
):
    """
    Retrieve POC documents, using the table of contents to target relevant chapters.

    Args:
        tocs: list with the table of contents of each document
        version: version of the parsed documents (e.g. "v4")
        proportion: share of documents retrieved via the ToC filter
    """
    filters = {}
    docs_question = []
    docs_images = []

    k_documents_toc = round(k_documents * proportion)

    relevant_tocs = await get_relevant_toc_level_for_query(query, tocs)

    print(f"Relevant ToCs : {relevant_tocs}")

    toc_filters = [toc["chapter"] for toc in relevant_tocs]

    filters_text_toc = {
        **filters,
        "chunk_type": "text",
        "toc_level0": {"$in": toc_filters},
        "version": version,
    }

    docs_question = vectorstore.similarity_search_with_score(query=query, filter=filters_text_toc, k=k_documents_toc)

    filters_text = {
        **filters,
        "chunk_type": "text",
        "version": version,
    }

    # Fill the remaining budget with a ToC-agnostic search.
    docs_question += vectorstore.similarity_search_with_score(query=query, filter=filters_text, k=k_documents - k_documents_toc)

    docs_question = remove_duplicates_chunks(docs_question)
    docs_question = [x for x in docs_question if x[1] > threshold]

    if search_figures:
        filters_image = {
            **filters,
            "chunk_type": "image",
        }
        docs_images = vectorstore.similarity_search_with_score(query=query, filter=filters_image, k=k_images)

    docs_question, docs_images = _add_metadata_and_score(docs_question), _add_metadata_and_score(docs_images)

    docs_question = [x for x in docs_question if len(x.page_content) > min_size]

    return {
        "docs_question": docs_question,
        "docs_images": docs_images,
    }
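
# Illustrative split (not executed): with k_documents=10 and proportion=0.5,
# round(10 * 0.5) = 5 chunks come from the ToC-filtered search and the remaining
# 5 from the unfiltered search, before deduplication and thresholding.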


async def get_IPCC_relevant_documents(
    query: str,
    vectorstore: VectorStore,
    sources: list = ["IPCC", "IPBES", "IPOS"],
    search_figures: bool = False,
    reports: list = [],
    threshold: float = 0.6,
    k_summary: int = 3,
    k_total: int = 10,
    k_images: int = 5,
    namespace: str = "vectors",
    min_size: int = 200,
    search_only: bool = False,
):
    assert isinstance(sources, list)
    assert sources
    assert all([x in ["IPCC", "IPBES", "IPOS"] for x in sources])
    assert k_total > k_summary, "k_total should be greater than k_summary"

    # Prefer explicit report filters; otherwise filter by source.
    filters = {}
    if len(reports) > 0:
        filters["short_name"] = {"$in": reports}
    else:
        filters["source"] = {"$in": sources}

    docs_summaries = []
    docs_full = []
    docs_images = []

    if search_only:
        # In search-only mode, only image chunks are retrieved.
        if search_figures:
            filters_image = {
                **filters,
                "chunk_type": "image",
            }
            docs_images = vectorstore.similarity_search_with_score(query=query, filter=filters_image, k=k_images)
            docs_images = _add_metadata_and_score(docs_images)
    else:
        # Search in the summaries (Summaries for Policymakers) first.
        filters_summaries = {
            **filters,
            "chunk_type": "text",
            "report_type": {"$in": ["SPM"]},
        }

        docs_summaries = vectorstore.similarity_search_with_score(query=query, filter=filters_summaries, k=k_summary)
        docs_summaries = [x for x in docs_summaries if x[1] > threshold]

        # Then search the full reports (everything that is not an SPM).
        filters_full = {
            **filters,
            "chunk_type": "text",
            "report_type": {"$nin": ["SPM"]},
        }
        docs_full = vectorstore.similarity_search_with_score(query=query, filter=filters_full, k=k_total)

        if search_figures:
            filters_image = {
                **filters,
                "chunk_type": "image",
            }
            docs_images = vectorstore.similarity_search_with_score(query=query, filter=filters_image, k=k_images)

        docs_summaries, docs_full, docs_images = _add_metadata_and_score(docs_summaries), _add_metadata_and_score(docs_full), _add_metadata_and_score(docs_images)

        # Drop chunks that are too short to be informative.
        docs_summaries = [x for x in docs_summaries if len(x.page_content) > min_size]
        docs_full = [x for x in docs_full if len(x.page_content) > min_size]

    return {
        "docs_summaries": docs_summaries,
        "docs_full": docs_full,
        "docs_images": docs_images,
    }
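
# Minimal usage sketch (illustrative; `vs` is assumed to be an initialized Pinecone vectorstore):
# results = await get_IPCC_relevant_documents(query="What drives sea level rise?", vectorstore=vs)
# results["docs_summaries"], results["docs_full"], results["docs_images"]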


def concatenate_documents(index, source_type, docs_question_dict, k_by_question, k_summary_by_question, k_images_by_question):
    if source_type == "IPx":
        # Mix summary chunks and full-report chunks within the per-question budget.
        docs_question = docs_question_dict["docs_summaries"][:k_summary_by_question] + docs_question_dict["docs_full"][:(k_by_question - k_summary_by_question)]
    elif source_type == "POC":
        docs_question = docs_question_dict["docs_question"][:k_by_question]
    else:
        raise ValueError("source_type should be either IPx or POC")

    images_question = docs_question_dict["docs_images"][:k_images_by_question]

    return docs_question, images_question


async def retrieve_documents(
    current_question: Dict[str, Any],
    config: Dict[str, Any],
    source_type: str,
    vectorstore: VectorStore,
    reranker: Any,
    version: str = "",
    search_figures: bool = False,
    search_only: bool = False,
    reports: list = [],
    rerank_by_question: bool = True,
    k_images_by_question: int = 5,
    k_before_reranking: int = 100,
    k_by_question: int = 5,
    k_summary_by_question: int = 3,
    tocs: list = [],
    by_toc=False,
) -> Tuple[List[Document], List[Document]]:
    """
    Retrieve and rerank the documents for one question, based on the question and its selected sources.

    Args:
        current_question (dict): The question being processed, with "question", "sources", "index" and "source_type" keys.
        config (dict): Configuration settings for logging and other purposes.
        source_type (str): Either "IPx" or "POC".
        vectorstore (object): The vector store used to retrieve relevant documents.
        reranker (object): The reranker used to rerank the retrieved documents.
        version (str, optional): Version of the parsed documents (used for POC-by-ToC retrieval).
        search_figures (bool, optional): Whether to also retrieve image chunks. Defaults to False.
        search_only (bool, optional): Whether to retrieve figures only. Defaults to False.
        reports (list, optional): Explicit reports to filter on. Defaults to [].
        rerank_by_question (bool, optional): Whether to rerank documents by question. Defaults to True.
        k_images_by_question (int, optional): The number of image documents to retrieve. Defaults to 5.
        k_before_reranking (int, optional): The number of documents to retrieve before reranking. Defaults to 100.
        k_by_question (int, optional): The number of documents to keep per question. Defaults to 5.
        k_summary_by_question (int, optional): The number of summary documents to keep per question. Defaults to 3.
        tocs (list, optional): Table-of-contents chunks, used when by_toc is True.
        by_toc (bool, optional): Whether to retrieve POC documents through the table of contents.

    Returns:
        tuple: The retrieved and reranked documents, and the related image documents.
    """
    sources = current_question["sources"]
    question = current_question["question"]
    index = current_question["index"]
    source_type = current_question["source_type"]

    print(f"Retrieve documents for question: {question}")
    await log_event({"question": question, "sources": sources, "index": index}, "log_retriever", config)

    print(f"""---- Retrieve documents from {current_question["source_type"]}----""")

    if source_type == "IPx":
        docs_question_dict = await get_IPCC_relevant_documents(
            query=question,
            vectorstore=vectorstore,
            search_figures=search_figures,
            sources=sources,
            min_size=200,
            k_summary=k_before_reranking - 1,
            k_total=k_before_reranking,
            k_images=k_images_by_question,
            threshold=0.5,
            search_only=search_only,
            reports=reports,
        )

    elif source_type == "POC":
        if by_toc:
            print("---- Retrieve documents by ToC----")
            docs_question_dict = await get_POC_documents_by_ToC_relevant_documents(
                query=question,
                tocs=tocs,
                vectorstore=vectorstore,
                version=version,
                search_figures=search_figures,
                sources=sources,
                threshold=0.5,
                search_only=search_only,
                reports=reports,
                min_size=200,
                k_documents=k_before_reranking,
                k_images=k_by_question,
            )
        else:
            docs_question_dict = await get_POC_relevant_documents(
                query=question,
                vectorstore=vectorstore,
                search_figures=search_figures,
                sources=sources,
                threshold=0.5,
                search_only=search_only,
                reports=reports,
                min_size=200,
                k_documents=k_before_reranking,
                k_images=k_by_question,
            )

    # Rerank every retrieved list, or fall back to the similarity score.
    if reranker is not None and rerank_by_question:
        with suppress_output():
            for key in docs_question_dict.keys():
                docs_question_dict[key] = rerank_and_sort_docs(reranker, docs_question_dict[key], question)
    else:
        for docs in docs_question_dict.values():
            for doc in docs:
                doc.metadata["reranking_score"] = doc.metadata["similarity_score"]

    # Keep the right number of documents and images for this question.
    docs_question, images_question = concatenate_documents(index, source_type, docs_question_dict, k_by_question, k_summary_by_question, k_images_by_question)

    # Rerank the concatenated selection as a whole.
    if reranker is not None and rerank_by_question:
        docs_question = rerank_and_sort_docs(reranker, docs_question, question)

    docs_question = _add_sources_used_in_metadata(docs_question, sources, question, index)
    images_question = _add_sources_used_in_metadata(images_question, sources, question, index)

    return docs_question, images_question
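
# Illustrative call (not executed; `vs`, `reranker` and `config` are assumed to exist):
# docs, images = await retrieve_documents(
#     current_question={"question": "...", "sources": ["IPCC"], "index": 0, "source_type": "IPx"},
#     config=config, source_type="IPx", vectorstore=vs, reranker=reranker,
# )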


async def retrieve_documents_for_all_questions(
    search_figures,
    search_only,
    reports,
    questions_list,
    n_questions,
    config,
    source_type,
    to_handle_questions_index,
    vectorstore,
    reranker,
    rerank_by_question=True,
    k_final=15,
    k_before_reranking=100,
    version: str = "",
    tocs: list[dict] = [],
    by_toc: bool = False,
):
    """
    Retrieve documents in parallel for all questions.
    """
    # Spread the final budget across questions.
    k_by_question = k_final // n_questions if n_questions else 0
    k_summary_by_question = _get_k_summary_by_question(n_questions)
    k_images_by_question = _get_k_images_by_question(n_questions)

    print(f"Source type here is {source_type}")
    tasks = [
        retrieve_documents(
            current_question=question,
            config=config,
            source_type=source_type,
            vectorstore=vectorstore,
            reranker=reranker,
            search_figures=search_figures,
            search_only=search_only,
            reports=reports,
            rerank_by_question=rerank_by_question,
            k_images_by_question=k_images_by_question,
            k_before_reranking=k_before_reranking,
            k_by_question=k_by_question,
            k_summary_by_question=k_summary_by_question,
            tocs=tocs,
            version=version,
            by_toc=by_toc,
        )
        for i, question in enumerate(questions_list)
        if i in to_handle_questions_index
    ]
    results = await asyncio.gather(*tasks)

    new_state = {"documents": [], "related_contents": [], "handled_questions_index": to_handle_questions_index}
    for docs_question, images_question in results:
        new_state["documents"].extend(docs_question)
        new_state["related_contents"].extend(images_question)
    return new_state


async def get_relevant_toc_level_for_query(
    query: str,
    tocs: list[Document],
) -> list[dict]:
    # Build a compact {document, toc} list to show the LLM.
    doc_list = []
    for doc in tocs:
        doc_name = doc[0].metadata["name"]
        toc = doc[0].page_content
        doc_list.append({"document": doc_name, "toc": toc})

    llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)

    prompt = ChatPromptTemplate.from_template(retrieve_chapter_prompt_template)
    chain = prompt | llm | StrOutputParser()
    response = chain.invoke({"query": query, "doc_list": doc_list})

    try:
        # Parse the model's answer as a Python literal (safer than eval).
        relevant_tocs = ast.literal_eval(response)
    except Exception as e:
        print(f"Failed to parse the result because of : {e}")
        relevant_tocs = []

    return relevant_tocs
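
# The model is expected to answer with a Python-literal list of dicts; downstream
# code reads the "chapter" key of each entry, e.g. (illustrative):
# [{"document": "Acclimaterra", "chapter": "..."}]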


def make_IPx_retriever_node(vectorstore, reranker, llm, rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
    async def retrieve_IPx_docs(state, config):
        source_type = "IPx"
        IPx_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "IPx"]

        search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
        search_only = state["search_only"]
        reports = state["reports"]
        questions_list = state["questions_list"]
        n_questions = state["n_questions"]["total"]

        state = await retrieve_documents_for_all_questions(
            search_figures=search_figures,
            search_only=search_only,
            reports=reports,
            questions_list=questions_list,
            n_questions=n_questions,
            config=config,
            source_type=source_type,
            to_handle_questions_index=IPx_questions_index,
            vectorstore=vectorstore,
            reranker=reranker,
            rerank_by_question=rerank_by_question,
            k_final=k_final,
            k_before_reranking=k_before_reranking,
        )
        return state

    return retrieve_IPx_docs
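
# Wiring sketch (illustrative; assumes a LangGraph-style graph object `workflow`):
# workflow.add_node("retrieve_IPx", make_IPx_retriever_node(vectorstore, reranker, llm))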


def make_POC_retriever_node(vectorstore, reranker, llm, rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
    async def retrieve_POC_docs_node(state, config):
        if "POC region" not in state["relevant_content_sources_selection"]:
            return {}

        source_type = "POC"
        POC_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]

        search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
        search_only = state["search_only"]
        reports = state["reports"]
        questions_list = state["questions_list"]
        n_questions = state["n_questions"]["total"]

        state = await retrieve_documents_for_all_questions(
            search_figures=search_figures,
            search_only=search_only,
            reports=reports,
            questions_list=questions_list,
            n_questions=n_questions,
            config=config,
            source_type=source_type,
            to_handle_questions_index=POC_questions_index,
            vectorstore=vectorstore,
            reranker=reranker,
            rerank_by_question=rerank_by_question,
            k_final=k_final,
            k_before_reranking=k_before_reranking,
        )
        return state

    return retrieve_POC_docs_node


def make_POC_by_ToC_retriever_node(
    vectorstore: VectorStore,
    reranker,
    llm,
    version: str = "",
    rerank_by_question=True,
    k_final=15,
    k_before_reranking=100,
    k_summary=5,
):
    async def retrieve_POC_docs_node(state, config):
        if "POC region" not in state["relevant_content_sources_selection"]:
            return {}

        search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
        search_only = state["search_only"]
        reports = state["reports"]
        questions_list = state["questions_list"]
        n_questions = state["n_questions"]["total"]

        tocs = get_ToCs(version=version)

        source_type = "POC"
        POC_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]

        state = await retrieve_documents_for_all_questions(
            search_figures=search_figures,
            search_only=search_only,
            config=config,
            reports=reports,
            questions_list=questions_list,
            n_questions=n_questions,
            source_type=source_type,
            to_handle_questions_index=POC_questions_index,
            vectorstore=vectorstore,
            reranker=reranker,
            rerank_by_question=rerank_by_question,
            k_final=k_final,
            k_before_reranking=k_before_reranking,
            tocs=tocs,
            version=version,
            by_toc=True,
        )
        return state

    return retrieve_POC_docs_node