import operator
from typing import Annotated, List

from langchain.schema import Document
from langchain_core.runnables.graph import MermaidDrawMethod
from langgraph.graph import END, StateGraph
from typing_extensions import TypedDict
from IPython.display import display, Image
from .chains.answer_chitchat import make_chitchat_node
from .chains.answer_ai_impact import make_ai_impact_node
from .chains.query_transformation import make_query_transform_node
from .chains.translation import make_translation_node
from .chains.intent_categorization import make_intent_categorization_node
from .chains.retrieve_documents import make_IPx_retriever_node, make_POC_retriever_node, make_POC_by_ToC_retriever_node
from .chains.answer_rag import make_rag_node
from .chains.graph_retriever import make_graph_retriever_node
from .chains.chitchat_categorization import make_chitchat_intent_categorization_node
# from .chains.set_defaults import set_defaults
class GraphState(TypedDict):
"""
Represents the state of our graph.
"""
user_input : str
language : str
intent : str
search_graphs_chitchat : bool
query: str
questions_list : List[dict]
handled_questions_index : Annotated[list[int], operator.add]
n_questions : int
answer: str
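    # NOTE: TypedDict does not apply the "default" values below at runtime; they only
    # document the expected defaults and must be supplied by the caller.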
audience: str = "experts"
    sources_input: List[str] = ["IPCC", "IPBES"]  # Deprecated -> only used for graphs, which can only come from OWID
relevant_content_sources_selection: List[str] = ["Figures (IPCC/IPBES)"]
sources_auto: bool = True
min_year: int = 1960
max_year: int = None
documents: Annotated[List[Document], operator.add]
related_contents : Annotated[List[Document], operator.add] # Images
recommended_content : List[Document] # OWID Graphs # TODO merge with related_contents
search_only : bool = False
reports : List[str] = []
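
# A minimal sketch of an initial state passed to the compiled graph; the query and the
# selections below are illustrative assumptions, not values the graph requires:
#
#   initial_state: GraphState = {
#       "user_input": "What does the IPCC say about sea level rise?",
#       "sources_input": ["IPCC", "IPBES"],
#       "relevant_content_sources_selection": ["Figures (IPCC/IPBES)"],
#       "audience": "experts",
#       "search_only": False,
#       "reports": [],
#   }

# `dummy`, `search` and `answer_search` below update no state; `dummy` and `answer_search`
# only serve as structural join points in the graph ("answer_climate" and "answer_search").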
def dummy(state):
return
def search(state):  # TODO
    return
def answer_search(state):  # TODO
    return
def route_intent(state):
intent = state["intent"]
if intent in ["chitchat","esg"]:
return "answer_chitchat"
# elif intent == "ai_impact":
# return "answer_ai_impact"
else:
# Search route
return "answer_climate"
def chitchat_route_intent(state):
    intent = state["search_graphs_chitchat"]
    if intent is True:
        return END  # TODO: route to "retrieve_graphs_chitchat" once graph retrieval after chitchat is enabled
    elif intent is False:
        return END
def route_translation(state):
if state["language"].lower() == "english":
return "transform_query"
else:
return "transform_query"
# return "translate_query" #TODO : add translation
def route_based_on_relevant_docs(state, threshold_docs=0.2):
    docs = [x for x in state["documents"] if x.metadata["reranking_score"] > threshold_docs]
    print("Route : ", "answer_rag" if len(docs) > 0 else "answer_rag_no_docs")
    if len(docs) > 0:
        return "answer_rag"
    else:
        return "answer_rag_no_docs"
def route_continue_retrieve_documents(state):
index_question_ipx = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "IPx"]
questions_ipx_finished = all(elem in state["handled_questions_index"] for elem in index_question_ipx)
if questions_ipx_finished:
return "end_retrieve_IPx_documents"
else:
return "retrieve_documents"
def route_continue_retrieve_local_documents(state):
index_question_poc = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
questions_poc_finished = all(elem in state["handled_questions_index"] for elem in index_question_poc)
# if questions_poc_finished and state["search_only"]:
# return END
if questions_poc_finished or ("POC region" not in state["relevant_content_sources_selection"]):
return "end_retrieve_local_documents"
else:
return "retrieve_local_data"
def route_retrieve_documents(state):
    sources_to_retrieve = []
    if "Graphs (OurWorldInData)" in state["relevant_content_sources_selection"]:
        sources_to_retrieve.append("retrieve_graphs")
    if not sources_to_retrieve:
        return END
    return sources_to_retrieve
def make_id_dict(values):
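    """Build an identity mapping {name: name}, used as the path map of add_conditional_edges."""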
return {k:k for k in values}
def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_region, reranker, threshold_docs=0.2):
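    """Build and compile the main agent graph: intent categorization, query transformation,
    IPx document and OWID graph retrieval, then RAG answering."""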
workflow = StateGraph(GraphState)
# Define the node functions
categorize_intent = make_intent_categorization_node(llm)
transform_query = make_query_transform_node(llm)
translate_query = make_translation_node(llm)
answer_chitchat = make_chitchat_node(llm)
answer_ai_impact = make_ai_impact_node(llm)
retrieve_documents = make_IPx_retriever_node(vectorstore_ipcc, reranker, llm)
retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)
# retrieve_local_data = make_POC_retriever_node(vectorstore_region, reranker, llm)
answer_rag = make_rag_node(llm, with_docs=True)
answer_rag_no_docs = make_rag_node(llm, with_docs=False)
chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
# Define the nodes
# workflow.add_node("set_defaults", set_defaults)
workflow.add_node("categorize_intent", categorize_intent)
workflow.add_node("answer_climate", dummy)
workflow.add_node("answer_search", answer_search)
workflow.add_node("transform_query", transform_query)
workflow.add_node("translate_query", translate_query)
workflow.add_node("answer_chitchat", answer_chitchat)
workflow.add_node("chitchat_categorize_intent", chitchat_categorize_intent)
workflow.add_node("retrieve_graphs", retrieve_graphs)
# workflow.add_node("retrieve_local_data", retrieve_local_data)
workflow.add_node("retrieve_graphs_chitchat", retrieve_graphs)
workflow.add_node("retrieve_documents", retrieve_documents)
workflow.add_node("answer_rag", answer_rag)
workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)
# Entry point
workflow.set_entry_point("categorize_intent")
# CONDITIONAL EDGES
workflow.add_conditional_edges(
"categorize_intent",
route_intent,
make_id_dict(["answer_chitchat","answer_climate"])
)
workflow.add_conditional_edges(
"chitchat_categorize_intent",
chitchat_route_intent,
make_id_dict(["retrieve_graphs_chitchat", END])
)
workflow.add_conditional_edges(
"answer_climate",
route_translation,
make_id_dict(["translate_query","transform_query"])
)
workflow.add_conditional_edges(
"answer_search",
        lambda x: route_based_on_relevant_docs(x, threshold_docs=threshold_docs),
make_id_dict(["answer_rag","answer_rag_no_docs"])
)
workflow.add_conditional_edges(
"transform_query",
route_retrieve_documents,
make_id_dict(["retrieve_graphs", END])
)
# Define the edges
workflow.add_edge("translate_query", "transform_query")
workflow.add_edge("transform_query", "retrieve_documents") #TODO put back
# workflow.add_edge("transform_query", "retrieve_local_data")
# workflow.add_edge("transform_query", END) # TODO remove
workflow.add_edge("retrieve_graphs", END)
workflow.add_edge("answer_rag", END)
workflow.add_edge("answer_rag_no_docs", END)
workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
workflow.add_edge("retrieve_graphs_chitchat", END)
# workflow.add_edge("retrieve_local_data", "answer_search")
workflow.add_edge("retrieve_documents", "answer_search")
# Compile
app = workflow.compile()
return app
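
# Usage sketch (assumes `llm`, the vectorstores and `reranker` are initialised elsewhere;
# none of them are defined in this module):
#
#   app = make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_region, reranker)
#   final_state = app.invoke({"user_input": "Will sea levels rise faster over the coming decades?"})
#   print(final_state["answer"])
#
# `make_graph_agent_poc` below builds the same graph plus a regional ("POC") retrieval
# branch built with make_POC_by_ToC_retriever_node.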
def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_region, reranker, version: str, threshold_docs=0.2):
    """Build the proof-of-concept (POC) agent graph, which also retrieves documents from the regional vectorstore.

    Args:
        llm: language model shared by the different nodes
        vectorstore_ipcc: vectorstore holding the IPCC/IPBES ("IPx") documents
        vectorstore_graphs: vectorstore holding the OWID graphs
        vectorstore_region: vectorstore holding the regional (POC) documents
        reranker: reranker used to score retrieved documents
        version (str): version of the parsed documents (e.g. "v4")
        threshold_docs (float, optional): minimum reranking score for a document to be considered relevant. Defaults to 0.2.

    Returns:
        The compiled LangGraph application.
    """
workflow = StateGraph(GraphState)
# Define the node functions
categorize_intent = make_intent_categorization_node(llm)
transform_query = make_query_transform_node(llm)
translate_query = make_translation_node(llm)
answer_chitchat = make_chitchat_node(llm)
answer_ai_impact = make_ai_impact_node(llm)
retrieve_documents = make_IPx_retriever_node(vectorstore_ipcc, reranker, llm)
retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)
# retrieve_local_data = make_POC_retriever_node(vectorstore_region, reranker, llm)
retrieve_local_data = make_POC_by_ToC_retriever_node(vectorstore_region, reranker, llm, version=version)
answer_rag = make_rag_node(llm, with_docs=True)
answer_rag_no_docs = make_rag_node(llm, with_docs=False)
chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
# Define the nodes
# workflow.add_node("set_defaults", set_defaults)
workflow.add_node("categorize_intent", categorize_intent)
workflow.add_node("answer_climate", dummy)
workflow.add_node("answer_search", answer_search)
# workflow.add_node("end_retrieve_local_documents", dummy)
# workflow.add_node("end_retrieve_IPx_documents", dummy)
workflow.add_node("transform_query", transform_query)
workflow.add_node("translate_query", translate_query)
workflow.add_node("answer_chitchat", answer_chitchat)
workflow.add_node("chitchat_categorize_intent", chitchat_categorize_intent)
workflow.add_node("retrieve_graphs", retrieve_graphs)
workflow.add_node("retrieve_local_data", retrieve_local_data)
workflow.add_node("retrieve_graphs_chitchat", retrieve_graphs)
workflow.add_node("retrieve_documents", retrieve_documents)
workflow.add_node("answer_rag", answer_rag)
workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)
# Entry point
workflow.set_entry_point("categorize_intent")
# CONDITIONAL EDGES
workflow.add_conditional_edges(
"categorize_intent",
route_intent,
make_id_dict(["answer_chitchat","answer_climate"])
)
workflow.add_conditional_edges(
"chitchat_categorize_intent",
chitchat_route_intent,
make_id_dict(["retrieve_graphs_chitchat", END])
)
workflow.add_conditional_edges(
"answer_climate",
route_translation,
make_id_dict(["translate_query","transform_query"])
)
workflow.add_conditional_edges(
"answer_search",
        lambda x: route_based_on_relevant_docs(x, threshold_docs=threshold_docs),
make_id_dict(["answer_rag","answer_rag_no_docs"])
)
workflow.add_conditional_edges(
"transform_query",
route_retrieve_documents,
make_id_dict(["retrieve_graphs", END])
)
# Define the edges
workflow.add_edge("translate_query", "transform_query")
workflow.add_edge("transform_query", "retrieve_documents") #TODO put back
workflow.add_edge("transform_query", "retrieve_local_data")
# workflow.add_edge("transform_query", END) # TODO remove
workflow.add_edge("retrieve_graphs", END)
workflow.add_edge("answer_rag", END)
workflow.add_edge("answer_rag_no_docs", END)
workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
workflow.add_edge("retrieve_graphs_chitchat", END)
workflow.add_edge("retrieve_local_data", "answer_search")
workflow.add_edge("retrieve_documents", "answer_search")
# workflow.add_edge("transform_query", "retrieve_drias_data")
# workflow.add_edge("retrieve_drias_data", END)
# Compile
app = workflow.compile()
return app
def display_graph(app):
    display(
        Image(
            app.get_graph(xray=True).draw_mermaid_png(
                draw_method=MermaidDrawMethod.API,
            )
        )
    )
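
# display_graph is intended for notebook use and renders the compiled graph through the
# mermaid.ink web service (MermaidDrawMethod.API), so it requires network access, e.g.:
#
#   display_graph(make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_region, reranker))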