import sys
import os
import asyncio
import ast
from contextlib import contextmanager
from typing import Any, Dict, List, Tuple

from langchain_core.tools import tool
from langchain_core.runnables import chain
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_core.runnables import RunnableLambda
from langchain_core.vectorstores import VectorStore
from langchain_core.documents.base import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

from ..reranker import rerank_docs, rerank_and_sort_docs
# from ...knowledge.retriever import ClimateQARetriever
from ...knowledge.openalex import OpenAlexRetriever
from .keywords_extraction import make_keywords_extraction_chain
from ..utils import log_event
from ..llm import get_llm
from ..vectorstore import get_pinecone_vectorstore
from ..embeddings import get_embeddings_function
from .prompts import retrieve_chapter_prompt_template
def divide_into_parts(target, parts):
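    """Split an integer `target` into `parts` near-equal integers that sum to `target`.

    Example:
        >>> divide_into_parts(10, 3)
        [4, 3, 3]
    """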
# Base value for each part
base = target // parts
# Remainder to distribute
remainder = target % parts
# List to hold the result
result = []
for i in range(parts):
if i < remainder:
# These parts get base value + 1
result.append(base + 1)
else:
# The rest get the base value
result.append(base)
return result
@contextmanager
def suppress_output():
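    """Context manager that temporarily redirects stdout and stderr to the null device.

    Used to silence noisy third-party libraries (e.g. the reranker) during a call.
    """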
# Open a null device
with open(os.devnull, 'w') as devnull:
# Store the original stdout and stderr
old_stdout = sys.stdout
old_stderr = sys.stderr
# Redirect stdout and stderr to the null device
sys.stdout = devnull
sys.stderr = devnull
try:
yield
finally:
# Restore stdout and stderr
sys.stdout = old_stdout
sys.stderr = old_stderr
@tool
def query_retriever(question):
"""Just a dummy tool to simulate the retriever query"""
return question
def _add_sources_used_in_metadata(docs,sources,question,index):
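    """Tag each document with the sources, question and question index used to retrieve it."""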
for doc in docs:
doc.metadata["sources_used"] = sources
doc.metadata["question_used"] = question
doc.metadata["index_used"] = index
return docs
def _get_k_summary_by_question(n_questions):
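    """Heuristic: the more questions there are, the fewer summary documents are kept per question."""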
if n_questions == 0:
return 0
elif n_questions == 1:
return 5
elif n_questions == 2:
return 3
elif n_questions == 3:
return 2
else:
return 1
def _get_k_images_by_question(n_questions):
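    """Heuristic: the more questions there are, the fewer image documents are kept per question."""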
if n_questions == 0:
return 0
elif n_questions == 1:
return 7
elif n_questions == 2:
return 5
elif n_questions == 3:
return 3
else:
return 1
def _add_metadata_and_score(docs: List) -> List[Document]:
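    """Flatten (document, score) tuples into documents with the similarity score stored in metadata."""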
# Add score to metadata
docs_with_metadata = []
for i,(doc,score) in enumerate(docs):
doc.page_content = doc.page_content.replace("\r\n"," ")
doc.metadata["similarity_score"] = score
doc.metadata["content"] = doc.page_content
        if doc.metadata["page_number"] != "N/A":
            # Stored page numbers are 0-indexed; convert to 1-indexed for display
            doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
        else:
            doc.metadata["page_number"] = 1
# doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
docs_with_metadata.append(doc)
return docs_with_metadata
def remove_duplicates_chunks(docs):
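    """Sort (document, score) tuples by descending score and drop documents with identical page content."""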
# Remove duplicates or almost duplicates
docs = sorted(docs,key=lambda x: x[1],reverse=True)
seen = set()
result = []
for doc in docs:
if doc[0].page_content not in seen:
seen.add(doc[0].page_content)
result.append(doc)
return result
def get_ToCs(version: str) -> list:
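    """Fetch the table-of-contents chunks for the given parsing version from the Pinecone vectorstore."""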
filters_text = {
"chunk_type":"toc",
"version": version
}
embeddings_function = get_embeddings_function()
vectorstore = get_pinecone_vectorstore(embeddings_function, index_name="climateqa-v2")
    # An empty query returns the documents that match the metadata filter only
    tocs = vectorstore.similarity_search_with_score(query="", filter=filters_text)
# remove duplicates or almost duplicates
tocs = remove_duplicates_chunks(tocs)
return tocs
async def get_POC_relevant_documents(
query: str,
vectorstore:VectorStore,
sources:list = ["Acclimaterra","PCAET","Plan Biodiversite"],
search_figures:bool = False,
search_only:bool = False,
k_documents:int = 10,
threshold:float = 0.6,
k_images: int = 5,
reports:list = [],
min_size:int = 200,
):
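    """
    Retrieve text (and optionally image) documents for the POC sources with a plain similarity search.

    Note: source/report selection is not applied yet (see the TODO below); all POC documents are searched.

    Args:
        query (str): question used for the similarity search
        vectorstore (VectorStore): vectorstore to search in
        search_figures (bool): whether to also retrieve image documents
        k_documents (int): number of text documents to retrieve
        threshold (float): minimal similarity score to keep a document
        k_images (int): number of image documents to retrieve
        min_size (int): minimal content length to keep a document
    """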
# Prepare base search kwargs
filters = {}
docs_question = []
docs_images = []
# TODO add source selection
# if len(reports) > 0:
# filters["short_name"] = {"$in":reports}
# else:
# filters["source"] = { "$in": sources}
filters_text = {
**filters,
"chunk_type":"text",
# "report_type": {}, # TODO to be completed to choose the right documents / chapters according to the analysis of the question
}
    docs_question = vectorstore.similarity_search_with_score(query=query, filter=filters_text, k=k_documents)
    # Remove duplicates or near-duplicates
    docs_question = remove_duplicates_chunks(docs_question)
    docs_question = [x for x in docs_question if x[1] > threshold]
if search_figures:
# Images
filters_image = {
**filters,
"chunk_type":"image"
}
        docs_images = vectorstore.similarity_search_with_score(query=query, filter=filters_image, k=k_images)
docs_question, docs_images = _add_metadata_and_score(docs_question), _add_metadata_and_score(docs_images)
docs_question = [x for x in docs_question if len(x.page_content) > min_size]
return {
"docs_question" : docs_question,
"docs_images" : docs_images
}
async def get_POC_documents_by_ToC_relevant_documents(
query: str,
tocs: list,
vectorstore:VectorStore,
version: str,
sources:list = ["Acclimaterra","PCAET","Plan Biodiversite"],
search_figures:bool = False,
search_only:bool = False,
k_documents:int = 10,
threshold:float = 0.6,
k_images: int = 5,
reports:list = [],
min_size:int = 200,
proportion: float = 0.5,
):
"""
Args:
- tocs : list with the table of contents of each document
- version : version of the parsed documents (e.g. "v4")
- proportion : share of documents retrieved using ToCs
"""
# Prepare base search kwargs
filters = {}
docs_question = []
docs_images = []
# TODO add source selection
# if len(reports) > 0:
# filters["short_name"] = {"$in":reports}
# else:
# filters["source"] = { "$in": sources}
k_documents_toc = round(k_documents * proportion)
relevant_tocs = await get_relevant_toc_level_for_query(query, tocs)
print(f"Relevant ToCs : {relevant_tocs}")
# Transform the ToC dict {"document": str, "chapter": str} into a list of string
toc_filters = [toc['chapter'] for toc in relevant_tocs]
filters_text_toc = {
**filters,
"chunk_type":"text",
"toc_level0": {"$in": toc_filters},
"version": version
# "report_type": {}, # TODO to be completed to choose the right documents / chapters according to the analysis of the question
}
    docs_question = vectorstore.similarity_search_with_score(query=query, filter=filters_text_toc, k=k_documents_toc)
filters_text = {
**filters,
"chunk_type":"text",
"version": version
# "report_type": {}, # TODO to be completed to choose the right documents / chapters according to the analysis of the question
}
    docs_question += vectorstore.similarity_search_with_score(query=query, filter=filters_text, k=k_documents - k_documents_toc)
    # Remove duplicates or near-duplicates
    docs_question = remove_duplicates_chunks(docs_question)
    docs_question = [x for x in docs_question if x[1] > threshold]
if search_figures:
# Images
filters_image = {
**filters,
"chunk_type":"image"
}
        docs_images = vectorstore.similarity_search_with_score(query=query, filter=filters_image, k=k_images)
docs_question, docs_images = _add_metadata_and_score(docs_question), _add_metadata_and_score(docs_images)
docs_question = [x for x in docs_question if len(x.page_content) > min_size]
return {
"docs_question" : docs_question,
"docs_images" : docs_images
}
async def get_IPCC_relevant_documents(
query: str,
vectorstore:VectorStore,
sources:list = ["IPCC","IPBES","IPOS"],
search_figures:bool = False,
reports:list = [],
threshold:float = 0.6,
k_summary:int = 3,
k_total:int = 10,
k_images: int = 5,
namespace:str = "vectors",
min_size:int = 200,
search_only:bool = False,
):
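    """
    Retrieve text (and optionally image) documents for the IPCC/IPBES/IPOS sources.

    Summaries (SPM report type) and full reports are searched separately, so that summaries
    can be placed in front of the results downstream.

    Args:
        query (str): question used for the similarity search
        vectorstore (VectorStore): vectorstore to search in
        sources (list): sources to search in, among "IPCC", "IPBES" and "IPOS"
        search_figures (bool): whether to also retrieve image documents
        reports (list): specific reports to restrict the search to
        threshold (float): minimal similarity score to keep a summary document
        k_summary (int): number of summary documents to retrieve
        k_total (int): number of full-report documents to retrieve
        k_images (int): number of image documents to retrieve
        min_size (int): minimal content length to keep a document
        search_only (bool): if True, only search for images
    """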
    # Check that the requested sources are supported
    assert isinstance(sources, list)
    assert sources
    assert all([x in ["IPCC", "IPBES", "IPOS"] for x in sources])
    assert k_total > k_summary, "k_total should be greater than k_summary"
# Prepare base search kwargs
filters = {}
if len(reports) > 0:
filters["short_name"] = {"$in":reports}
else:
filters["source"] = { "$in": sources}
# INIT
docs_summaries = []
docs_full = []
docs_images = []
if search_only:
# Only search for images if search_only is True
if search_figures:
filters_image = {
**filters,
"chunk_type":"image"
}
            docs_images = vectorstore.similarity_search_with_score(query=query, filter=filters_image, k=k_images)
            docs_images = _add_metadata_and_score(docs_images)
else:
# Regular search flow for text and optionally images
# Search for k_summary documents in the summaries dataset
filters_summaries = {
**filters,
"chunk_type":"text",
"report_type": { "$in":["SPM"]},
}
        docs_summaries = vectorstore.similarity_search_with_score(query=query, filter=filters_summaries, k=k_summary)
        docs_summaries = [x for x in docs_summaries if x[1] > threshold]
        # Search the full reports dataset (SPM excluded)
filters_full = {
**filters,
"chunk_type":"text",
"report_type": { "$nin":["SPM"]},
}
        docs_full = vectorstore.similarity_search_with_score(query=query, filter=filters_full, k=k_total)
if search_figures:
# Images
filters_image = {
**filters,
"chunk_type":"image"
}
            docs_images = vectorstore.similarity_search_with_score(query=query, filter=filters_image, k=k_images)
        docs_summaries, docs_full, docs_images = _add_metadata_and_score(docs_summaries), _add_metadata_and_score(docs_full), _add_metadata_and_score(docs_images)
        # Filter out documents whose content is shorter than min_size
docs_summaries = [x for x in docs_summaries if len(x.page_content) > min_size]
docs_full = [x for x in docs_full if len(x.page_content) > min_size]
return {
"docs_summaries" : docs_summaries,
"docs_full" : docs_full,
"docs_images" : docs_images,
}
def concatenate_documents(index, source_type, docs_question_dict, k_by_question, k_summary_by_question, k_images_by_question):
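    """Truncate the retrieved documents and images to the per-question budgets; for "IPx", SPM summaries are placed in front of full-report chunks."""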
# Keep the right number of documents - The k_summary documents from SPM are placed in front
if source_type == "IPx":
docs_question = docs_question_dict["docs_summaries"][:k_summary_by_question] + docs_question_dict["docs_full"][:(k_by_question - k_summary_by_question)]
elif source_type == "POC" :
docs_question = docs_question_dict["docs_question"][:k_by_question]
    else:
        raise ValueError("source_type should be either IPx or POC")
# docs_question = [doc for key in docs_question_dict.keys() for doc in docs_question_dict[key]][:(k_by_question)]
images_question = docs_question_dict["docs_images"][:k_images_by_question]
return docs_question, images_question
# The @chain decorator is not strictly necessary, but it propagates the langchain callbacks to the astream_events logger, which displays intermediate results
# @chain
async def retrieve_documents(
current_question: Dict[str, Any],
config: Dict[str, Any],
source_type: str,
vectorstore: VectorStore,
reranker: Any,
version: str = "",
search_figures: bool = False,
search_only: bool = False,
reports: list = [],
rerank_by_question: bool = True,
k_images_by_question: int = 5,
k_before_reranking: int = 100,
k_by_question: int = 5,
k_summary_by_question: int = 3,
tocs: list = [],
by_toc=False
) -> Tuple[List[Document], List[Document]]:
"""
Unpack the first question of the remaining questions, and retrieve and rerank corresponding documents, based on the question and selected_sources
Args:
state (dict): The current state containing documents, related content, relevant content sources, remaining questions and n_questions.
current_question (dict): The current question being processed.
config (dict): Configuration settings for logging and other purposes.
vectorstore (object): The vector store used to retrieve relevant documents.
reranker (object): The reranker used to rerank the retrieved documents.
llm (object): The language model used for processing.
rerank_by_question (bool, optional): Whether to rerank documents by question. Defaults to True.
k_final (int, optional): The final number of documents to retrieve. Defaults to 15.
k_before_reranking (int, optional): The number of documents to retrieve before reranking. Defaults to 100.
k_summary (int, optional): The number of summary documents to retrieve. Defaults to 5.
k_images (int, optional): The number of image documents to retrieve. Defaults to 5.
Returns:
dict: The updated state containing the retrieved and reranked documents, related content, and remaining questions.
"""
sources = current_question["sources"]
question = current_question["question"]
index = current_question["index"]
    source_type = current_question["source_type"]  # the per-question source type overrides the source_type argument
print(f"Retrieve documents for question: {question}")
await log_event({"question":question,"sources":sources,"index":index},"log_retriever",config)
print(f"""---- Retrieve documents from {current_question["source_type"]}----""")
if source_type == "IPx":
        docs_question_dict = await get_IPCC_relevant_documents(
            query=question,
            vectorstore=vectorstore,
            search_figures=search_figures,
            sources=sources,
            min_size=200,
            k_summary=k_before_reranking - 1,
            k_total=k_before_reranking,
            k_images=k_images_by_question,
            threshold=0.5,
            search_only=search_only,
            reports=reports,
        )
    elif source_type == 'POC':
        if by_toc:
print("---- Retrieve documents by ToC----")
            docs_question_dict = await get_POC_documents_by_ToC_relevant_documents(
                query=question,
                tocs=tocs,
                vectorstore=vectorstore,
                version=version,
                search_figures=search_figures,
                sources=sources,
                threshold=0.5,
                search_only=search_only,
                reports=reports,
                min_size=200,
                k_documents=k_before_reranking,
                k_images=k_by_question
            )
        else:
            docs_question_dict = await get_POC_relevant_documents(
                query=question,
                vectorstore=vectorstore,
                search_figures=search_figures,
                sources=sources,
                threshold=0.5,
                search_only=search_only,
                reports=reports,
                min_size=200,
                k_documents=k_before_reranking,
                k_images=k_by_question
            )
# Rerank
if reranker is not None and rerank_by_question:
with suppress_output():
for key in docs_question_dict.keys():
docs_question_dict[key] = rerank_and_sort_docs(reranker,docs_question_dict[key],question)
    else:
        # Add a default reranking score equal to the similarity score
        for key in docs_question_dict.keys():
            for doc in docs_question_dict[key]:
                doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
# Keep the right number of documents
docs_question, images_question = concatenate_documents(index, source_type, docs_question_dict, k_by_question, k_summary_by_question, k_images_by_question)
# Rerank the documents to put the most relevant in front
if reranker is not None and rerank_by_question:
docs_question = rerank_and_sort_docs(reranker, docs_question, question)
# Add sources used in the metadata
    docs_question = _add_sources_used_in_metadata(docs_question, sources, question, index)
    images_question = _add_sources_used_in_metadata(images_question, sources, question, index)
return docs_question, images_question
async def retrieve_documents_for_all_questions(
search_figures,
search_only,
reports,
questions_list,
n_questions,
config,
source_type,
to_handle_questions_index,
vectorstore,
reranker,
rerank_by_question=True,
k_final=15,
k_before_reranking=100,
version: str = "",
tocs: list[dict] = [],
by_toc: bool = False
):
"""
Retrieve documents in parallel for all questions.
"""
# to_handle_questions_index = [x for x in state["questions_list"] if x["source_type"] == "IPx"]
# TODO split les questions selon le type de sources dans le state question + conditions sur le nombre de questions traités par type de source
# search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
# search_only = state["search_only"]
# reports = state["reports"]
# questions_list = state["questions_list"]
# k_by_question = k_final // state["n_questions"]["total"]
# k_summary_by_question = _get_k_summary_by_question(state["n_questions"]["total"])
# k_images_by_question = _get_k_images_by_question(state["n_questions"]["total"])
k_by_question = k_final // n_questions
k_summary_by_question = _get_k_summary_by_question(n_questions)
k_images_by_question = _get_k_images_by_question(n_questions)
print(f"Source type here is {source_type}")
tasks = [
retrieve_documents(
current_question=question,
config=config,
source_type=source_type,
vectorstore=vectorstore,
reranker=reranker,
search_figures=search_figures,
search_only=search_only,
reports=reports,
rerank_by_question=rerank_by_question,
k_images_by_question=k_images_by_question,
k_before_reranking=k_before_reranking,
k_by_question=k_by_question,
k_summary_by_question=k_summary_by_question,
tocs=tocs,
version=version,
by_toc=by_toc
)
for i, question in enumerate(questions_list) if i in to_handle_questions_index
]
results = await asyncio.gather(*tasks)
# Combine results
new_state = {"documents": [], "related_contents": [], "handled_questions_index": to_handle_questions_index}
for docs_question, images_question in results:
new_state["documents"].extend(docs_question)
new_state["related_contents"].extend(images_question)
return new_state
# ToC Retriever
async def get_relevant_toc_level_for_query(
query: str,
tocs: list[Document],
) -> list[dict]:
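    """
    Ask the LLM which table-of-contents chapters are relevant to the query.

    Returns a list of dicts of the form {"document": str, "chapter": str}, or an empty list if
    the LLM answer cannot be parsed.
    """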
doc_list = []
for doc in tocs:
doc_name = doc[0].metadata['name']
toc = doc[0].page_content
doc_list.append({'document': doc_name, 'toc': toc})
    llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)
prompt = ChatPromptTemplate.from_template(retrieve_chapter_prompt_template)
chain = prompt | llm | StrOutputParser()
response = chain.invoke({"query": query, "doc_list": doc_list})
    try:
        # The LLM is expected to answer with a python-style list of dicts; parse it safely
        relevant_tocs = ast.literal_eval(response)
    except Exception as e:
        print(f"Failed to parse the result because of: {e}")
        relevant_tocs = []
    return relevant_tocs
def make_IPx_retriever_node(vectorstore, reranker, llm, rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
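    """Build a graph node that retrieves documents for the questions routed to the IPCC/IPBES/IPOS ("IPx") sources."""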
async def retrieve_IPx_docs(state, config):
source_type = "IPx"
IPx_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "IPx"]
search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
search_only = state["search_only"]
reports = state["reports"]
questions_list = state["questions_list"]
n_questions=state["n_questions"]["total"]
state = await retrieve_documents_for_all_questions(
search_figures=search_figures,
search_only=search_only,
reports=reports,
questions_list=questions_list,
n_questions=n_questions,
config=config,
source_type=source_type,
to_handle_questions_index=IPx_questions_index,
vectorstore=vectorstore,
reranker=reranker,
rerank_by_question=rerank_by_question,
k_final=k_final,
k_before_reranking=k_before_reranking,
)
return state
return retrieve_IPx_docs
def make_POC_retriever_node(vectorstore, reranker, llm, rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
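    """Build a graph node that retrieves documents for the questions routed to the POC (local/regional) sources."""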
async def retrieve_POC_docs_node(state, config):
if "POC region" not in state["relevant_content_sources_selection"] :
return {}
source_type = "POC"
POC_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
search_only = state["search_only"]
reports = state["reports"]
questions_list = state["questions_list"]
n_questions=state["n_questions"]["total"]
state = await retrieve_documents_for_all_questions(
search_figures=search_figures,
search_only=search_only,
reports=reports,
questions_list=questions_list,
n_questions=n_questions,
config=config,
source_type=source_type,
to_handle_questions_index=POC_questions_index,
vectorstore=vectorstore,
reranker=reranker,
rerank_by_question=rerank_by_question,
k_final=k_final,
k_before_reranking=k_before_reranking,
)
return state
return retrieve_POC_docs_node
def make_POC_by_ToC_retriever_node(
vectorstore: VectorStore,
reranker,
llm,
version: str = "",
rerank_by_question=True,
k_final=15,
k_before_reranking=100,
k_summary=5,
):
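    """Build a graph node that retrieves documents for the POC sources through the relevant table-of-contents chapters."""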
async def retrieve_POC_docs_node(state, config):
if "POC region" not in state["relevant_content_sources_selection"] :
return {}
search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
        search_only = state["search_only"]
reports = state["reports"]
questions_list = state["questions_list"]
n_questions=state["n_questions"]["total"]
tocs = get_ToCs(version=version)
source_type = "POC"
POC_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
state = await retrieve_documents_for_all_questions(
search_figures=search_figures,
search_only=search_only,
config=config,
reports=reports,
questions_list=questions_list,
n_questions=n_questions,
source_type=source_type,
to_handle_questions_index=POC_questions_index,
vectorstore=vectorstore,
reranker=reranker,
rerank_by_question=rerank_by_question,
k_final=k_final,
k_before_reranking=k_before_reranking,
tocs=tocs,
version=version,
by_toc=True
)
return state
return retrieve_POC_docs_node