|
import sys |
|
import os |
|
from contextlib import contextmanager |
|
|
|
from langchain.schema import Document |
|
from langgraph.graph import END, StateGraph |
|
from langchain_core.runnables.graph import CurveStyle, MermaidDrawMethod |
|
|
|
from typing_extensions import TypedDict |
|
from typing import List, Dict |
|
|
|
import operator |
|
from typing import Annotated |
|
import pandas as pd |
|
from IPython.display import display, HTML, Image |
|
|
|
from .chains.answer_chitchat import make_chitchat_node |
|
from .chains.answer_ai_impact import make_ai_impact_node |
|
from .chains.query_transformation import make_query_transform_node |
|
from .chains.translation import make_translation_node |
|
from .chains.intent_categorization import make_intent_categorization_node |
|
from .chains.retrieve_documents import make_IPx_retriever_node, make_POC_retriever_node, make_POC_by_ToC_retriever_node |
|
from .chains.answer_rag import make_rag_node |
|
from .chains.graph_retriever import make_graph_retriever_node |
|
from .chains.chitchat_categorization import make_chitchat_intent_categorization_node |
|
|
|
|
|
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Shared state schema for the workflows built by ``make_graph_agent`` and
    ``make_graph_agent_poc``. Fields annotated with
    ``Annotated[..., operator.add]`` are merged with ``operator.add`` (list
    concatenation) when nodes return updates; other fields take the latest
    value.

    NOTE(review): the ``= value`` assignments below are not per-state
    defaults. A ``TypedDict`` instance is a plain ``dict``, so these values
    only become class attributes of ``GraphState``. Confirm whether LangGraph
    reads them before relying on them; callers should presumably supply these
    keys in the input state.
    """
    # Text of the user's message (presumably the raw chat input; confirm with caller).
    user_input : str
    # Language name; route_translation lower-cases it and compares it with "english".
    language : str
    # Category set by categorize_intent; route_intent sends "chitchat"/"esg" to answer_chitchat.
    intent : str
    # Flag read by chitchat_route_intent (both True and False currently end the run).
    search_graphs_chitchat : bool
    query: str
    # Sub-questions; each dict carries at least a "source_type" key ("IPx" / "POC" are checked here).
    questions_list : List[dict]
    # Indices into questions_list whose retrieval is done; concatenated across node updates.
    handled_questions_index : Annotated[list[int], operator.add]
    n_questions : int
    answer: str
    audience: str = "experts"
    sources_input: List[str] = ["IPCC","IPBES"]
    # Routers check this for "POC region" and "Graphs (OurWorldInData)".
    relevant_content_sources_selection: List[str] = ["Figures (IPCC/IPBES)"]
    sources_auto: bool = True
    min_year: int = 1960
    # NOTE(review): annotated as int but assigned None; Optional[int] would be accurate.
    max_year: int = None
    # Retrieved documents; route_based_on_relevant_docs reads metadata["reranking_score"].
    # Concatenated across node updates.
    documents: Annotated[List[Document], operator.add]
    # Concatenated across node updates.
    related_contents : Annotated[List[Document], operator.add]
    recommended_content : List[Document]
    search_only : bool = False
    # NOTE(review): a mutable list literal stored as a single shared class attribute.
    reports : List[str] = []
|
|
|
def dummy(state):
    """No-op node callable.

    Registered as the ``answer_climate`` node, which exists only as a source
    for conditional edges. It ignores ``state`` and returns ``None``, so it
    contributes no fields to the state.
    """
    return None
|
|
|
def search(state):
    """Placeholder node callable.

    Currently does nothing: ``state`` is ignored and ``None`` is returned.
    """
    return None
|
|
|
def answer_search(state):
    """No-op node callable used as the ``answer_search`` join point.

    The retrieval nodes feed into this node. Its outgoing conditional edge
    (``route_based_on_relevant_docs``) picks the answering node. The function
    itself ignores ``state`` and returns ``None``.
    """
    return None
|
|
|
def route_intent(state):
    """Choose the branch that follows intent categorization.

    Args:
        state: graph state; must contain the ``"intent"`` key.

    Returns:
        ``"answer_chitchat"`` when the intent is ``"chitchat"`` or ``"esg"``,
        otherwise ``"answer_climate"``.
    """
    is_small_talk = state["intent"] in ("chitchat", "esg")
    return "answer_chitchat" if is_small_talk else "answer_climate"
|
|
|
def chitchat_route_intent(state):
    """Route after ``chitchat_categorize_intent``.

    Every outcome ends the run. The ``"search_graphs_chitchat"`` flag is
    therefore not consulted, and the ``"retrieve_graphs_chitchat"`` target
    listed in the edge map is never selected.

    Always returning ``END`` also covers a missing or non-boolean flag. A
    branch on ``is True`` / ``is False`` would fall through to an implicit
    ``None`` in that case, and ``None`` is not a valid edge target.

    Args:
        state: graph state (unused).

    Returns:
        ``END``.
    """
    # NOTE(review): to route chitchat to graph search, return
    # "retrieve_graphs_chitchat" when state.get("search_graphs_chitchat") is True.
    return END
|
|
|
def route_translation(state):
    """Choose the node that follows ``answer_climate``.

    Args:
        state: graph state; ``state["language"]`` must be a string.

    Returns:
        ``"transform_query"`` for every language. English and non-English
        inputs currently take the same path, so ``translate_query`` is not
        selected by this router.
    """
    # The lookup/lower() is kept so a missing or non-string language still fails here.
    language = state["language"].lower()
    if language != "english":
        # NOTE(review): non-English input also bypasses translate_query.
        return "transform_query"
    return "transform_query"
|
|
|
|
|
|
|
def route_based_on_relevant_docs(state, threshold_docs=0.2):
    """Choose the RAG answering node according to retrieved-document relevance.

    Args:
        state: graph state; ``state["documents"]`` is an iterable of objects
            exposing a ``metadata`` mapping with a ``"reranking_score"`` entry.
        threshold_docs: a document counts as relevant only when its reranking
            score is strictly greater than this value.

    Returns:
        ``"answer_rag"`` if at least one document is relevant, otherwise
        ``"answer_rag_no_docs"``. The chosen route is also printed.
    """
    # any() stops at the first relevant document instead of filtering them all.
    has_relevant_docs = any(
        doc.metadata["reranking_score"] > threshold_docs
        for doc in state["documents"]
    )
    route = "answer_rag" if has_relevant_docs else "answer_rag_no_docs"
    print("Route : ", [route])
    return route
|
|
|
def route_continue_retrieve_documents(state):
    """Decide whether IPx retrieval should continue.

    Args:
        state: graph state with ``"questions_list"`` (dicts carrying a
            ``"source_type"``) and ``"handled_questions_index"``.

    Returns:
        ``"end_retrieve_IPx_documents"`` once every question whose
        ``source_type`` is ``"IPx"`` has its index in
        ``handled_questions_index``, otherwise ``"retrieve_documents"``.
    """
    ipx_positions = [
        position
        for position, question in enumerate(state["questions_list"])
        if question["source_type"] == "IPx"
    ]
    for position in ipx_positions:
        if position not in state["handled_questions_index"]:
            return "retrieve_documents"
    return "end_retrieve_IPx_documents"
|
|
|
def route_continue_retrieve_local_documents(state):
    """Decide whether local (POC) retrieval should continue.

    Args:
        state: graph state with ``"questions_list"`` (dicts carrying a
            ``"source_type"``), ``"handled_questions_index"`` and
            ``"relevant_content_sources_selection"``.

    Returns:
        ``"retrieve_local_data"`` while some ``"POC"`` question is not yet in
        ``handled_questions_index`` and ``"POC region"`` is among the selected
        sources. Otherwise returns ``"end_retrieve_local_documents"``.
    """
    poc_positions = [
        position
        for position, question in enumerate(state["questions_list"])
        if question["source_type"] == "POC"
    ]
    all_poc_handled = True
    for position in poc_positions:
        if position not in state["handled_questions_index"]:
            all_poc_handled = False
            break

    # The source selection is only looked up when POC work remains.
    if not all_poc_handled and "POC region" in state["relevant_content_sources_selection"]:
        return "retrieve_local_data"
    return "end_retrieve_local_documents"
|
|
|
def route_retrieve_documents(state):
    """Choose the optional extra retrieval branches after ``transform_query``.

    Args:
        state: graph state with ``"relevant_content_sources_selection"``.

    Returns:
        ``["retrieve_graphs"]`` when ``"Graphs (OurWorldInData)"`` is selected,
        otherwise ``END``.
    """
    if "Graphs (OurWorldInData)" not in state["relevant_content_sources_selection"]:
        return END
    return ["retrieve_graphs"]
|
|
|
def make_id_dict(values):
    """Map every value in ``values`` to itself.

    Used as the ``path_map`` of ``add_conditional_edges`` so that a router's
    return value is taken directly as the destination node name.
    """
    identity = {}
    for value in values:
        identity[value] = value
    return identity
|
|
|
def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_region, reranker, threshold_docs=0.2):
    """Build and compile the main (IPCC/IPBES) question-answering workflow.

    Wiring, as declared below:
        categorize_intent
          -> answer_chitchat -> chitchat_categorize_intent -> END
          -> answer_climate -> transform_query
               -> retrieve_documents -> answer_search
                    -> answer_rag | answer_rag_no_docs -> END
               -> retrieve_graphs -> END
                    (only when "Graphs (OurWorldInData)" is selected)

    Args:
        llm: language model passed to every LLM-backed node factory.
        vectorstore_ipcc: vector store given to the IPx document retriever.
        vectorstore_graphs: vector store given to the graph retriever.
        vectorstore_region: not used by this builder (presumably kept so the
            signature matches ``make_graph_agent_poc``).
        reranker: reranker given to both retrievers.
        threshold_docs: reranking-score threshold forwarded to
            ``route_based_on_relevant_docs`` to choose between ``answer_rag``
            and ``answer_rag_no_docs``.

    Returns:
        The compiled graph from ``workflow.compile()``.
    """
    workflow = StateGraph(GraphState)

    # --- Node callables ---
    categorize_intent = make_intent_categorization_node(llm)
    transform_query = make_query_transform_node(llm)
    translate_query = make_translation_node(llm)
    answer_chitchat = make_chitchat_node(llm)
    # NOTE(review): built but never registered as a node in this graph.
    answer_ai_impact = make_ai_impact_node(llm)
    retrieve_documents = make_IPx_retriever_node(vectorstore_ipcc, reranker, llm)
    retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)

    answer_rag = make_rag_node(llm, with_docs=True)
    answer_rag_no_docs = make_rag_node(llm, with_docs=False)
    chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)

    # --- Nodes ---
    workflow.add_node("categorize_intent", categorize_intent)
    # "answer_climate" and "answer_search" are no-op nodes used only as branching points.
    workflow.add_node("answer_climate", dummy)
    workflow.add_node("answer_search", answer_search)
    workflow.add_node("transform_query", transform_query)
    workflow.add_node("translate_query", translate_query)
    workflow.add_node("answer_chitchat", answer_chitchat)
    workflow.add_node("chitchat_categorize_intent", chitchat_categorize_intent)
    workflow.add_node("retrieve_graphs", retrieve_graphs)

    # Same graph-retriever callable, registered under a second name for the chitchat branch.
    workflow.add_node("retrieve_graphs_chitchat", retrieve_graphs)
    workflow.add_node("retrieve_documents", retrieve_documents)
    workflow.add_node("answer_rag", answer_rag)
    workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)

    workflow.set_entry_point("categorize_intent")

    # --- Conditional edges (routers return a key of the identity map == node name) ---
    workflow.add_conditional_edges(
        "categorize_intent",
        route_intent,
        make_id_dict(["answer_chitchat","answer_climate"])
    )

    # chitchat_route_intent returns END for both flag values, so
    # "retrieve_graphs_chitchat" is not reached.
    workflow.add_conditional_edges(
        "chitchat_categorize_intent",
        chitchat_route_intent,
        make_id_dict(["retrieve_graphs_chitchat", END])
    )

    # route_translation returns "transform_query" for every language, so
    # "translate_query" is not reached from here.
    workflow.add_conditional_edges(
        "answer_climate",
        route_translation,
        make_id_dict(["translate_query","transform_query"])
    )

    # threshold_docs is bound through a closure so the router takes only the state.
    workflow.add_conditional_edges(
        "answer_search",
        lambda x : route_based_on_relevant_docs(x,threshold_docs=threshold_docs),
        make_id_dict(["answer_rag","answer_rag_no_docs"])
    )
    # Optionally adds retrieve_graphs; the static edge below still sends every
    # transformed query to retrieve_documents.
    workflow.add_conditional_edges(
        "transform_query",
        route_retrieve_documents,
        make_id_dict(["retrieve_graphs", END])
    )

    # --- Static edges ---
    workflow.add_edge("translate_query", "transform_query")
    workflow.add_edge("transform_query", "retrieve_documents")

    workflow.add_edge("retrieve_graphs", END)
    workflow.add_edge("answer_rag", END)
    workflow.add_edge("answer_rag_no_docs", END)
    workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
    workflow.add_edge("retrieve_graphs_chitchat", END)

    workflow.add_edge("retrieve_documents", "answer_search")

    app = workflow.compile()
    return app
|
|
|
def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_region, reranker, version:str, threshold_docs=0.2):
    """Build and compile the POC workflow: IPCC/IPBES retrieval plus local (regional) documents.

    Same wiring as ``make_graph_agent``, with one addition: after
    ``transform_query`` the ``retrieve_local_data`` node (table-of-contents
    based POC retriever over ``vectorstore_region``) also runs. It feeds
    ``answer_search`` alongside ``retrieve_documents``.

    Args:
        llm: language model passed to every LLM-backed node factory.
        vectorstore_ipcc: vector store given to the IPx document retriever.
        vectorstore_graphs: vector store given to the graph retriever.
        vectorstore_region: vector store given to the local (POC) retriever.
        reranker: reranker given to all retrievers.
        version (str): version of the parsed documents (e.g. "v4"), forwarded
            to ``make_POC_by_ToC_retriever_node``.
        threshold_docs (float, optional): reranking-score threshold forwarded
            to ``route_based_on_relevant_docs``. Defaults to 0.2.

    Returns:
        The compiled graph from ``workflow.compile()``.
    """
    workflow = StateGraph(GraphState)

    # --- Node callables ---
    categorize_intent = make_intent_categorization_node(llm)
    transform_query = make_query_transform_node(llm)
    translate_query = make_translation_node(llm)
    answer_chitchat = make_chitchat_node(llm)
    # NOTE(review): built but never registered as a node in this graph.
    answer_ai_impact = make_ai_impact_node(llm)
    retrieve_documents = make_IPx_retriever_node(vectorstore_ipcc, reranker, llm)
    retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)

    retrieve_local_data = make_POC_by_ToC_retriever_node(vectorstore_region, reranker, llm, version=version)
    answer_rag = make_rag_node(llm, with_docs=True)
    answer_rag_no_docs = make_rag_node(llm, with_docs=False)
    chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)

    # --- Nodes ---
    workflow.add_node("categorize_intent", categorize_intent)
    # "answer_climate" and "answer_search" are no-op nodes used only as branching points.
    workflow.add_node("answer_climate", dummy)
    workflow.add_node("answer_search", answer_search)

    workflow.add_node("transform_query", transform_query)
    workflow.add_node("translate_query", translate_query)
    workflow.add_node("answer_chitchat", answer_chitchat)
    workflow.add_node("chitchat_categorize_intent", chitchat_categorize_intent)
    workflow.add_node("retrieve_graphs", retrieve_graphs)
    workflow.add_node("retrieve_local_data", retrieve_local_data)
    # Same graph-retriever callable, registered under a second name for the chitchat branch.
    workflow.add_node("retrieve_graphs_chitchat", retrieve_graphs)
    workflow.add_node("retrieve_documents", retrieve_documents)
    workflow.add_node("answer_rag", answer_rag)
    workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)

    workflow.set_entry_point("categorize_intent")

    # --- Conditional edges (routers return a key of the identity map == node name) ---
    workflow.add_conditional_edges(
        "categorize_intent",
        route_intent,
        make_id_dict(["answer_chitchat","answer_climate"])
    )

    # chitchat_route_intent returns END for both flag values, so
    # "retrieve_graphs_chitchat" is not reached.
    workflow.add_conditional_edges(
        "chitchat_categorize_intent",
        chitchat_route_intent,
        make_id_dict(["retrieve_graphs_chitchat", END])
    )

    # route_translation returns "transform_query" for every language, so
    # "translate_query" is not reached from here.
    workflow.add_conditional_edges(
        "answer_climate",
        route_translation,
        make_id_dict(["translate_query","transform_query"])
    )

    # threshold_docs is bound through a closure so the router takes only the state.
    workflow.add_conditional_edges(
        "answer_search",
        lambda x : route_based_on_relevant_docs(x,threshold_docs=threshold_docs),
        make_id_dict(["answer_rag","answer_rag_no_docs"])
    )

    # Optionally adds retrieve_graphs; the static edges below still send every
    # transformed query to retrieve_documents and retrieve_local_data.
    workflow.add_conditional_edges(
        "transform_query",
        route_retrieve_documents,
        make_id_dict(["retrieve_graphs", END])
    )

    # --- Static edges ---
    workflow.add_edge("translate_query", "transform_query")
    workflow.add_edge("transform_query", "retrieve_documents")
    workflow.add_edge("transform_query", "retrieve_local_data")

    workflow.add_edge("retrieve_graphs", END)
    workflow.add_edge("answer_rag", END)
    workflow.add_edge("answer_rag_no_docs", END)
    workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
    workflow.add_edge("retrieve_graphs_chitchat", END)

    # NOTE(review): answer_search has two independent incoming edges rather than
    # one joint edge; confirm it is executed once per query after both retrievers.
    workflow.add_edge("retrieve_local_data", "answer_search")
    workflow.add_edge("retrieve_documents", "answer_search")

    app = workflow.compile()
    return app
|
|
|
|
|
|
|
|
|
def display_graph(app):
    """Render a compiled graph as an inline PNG in an IPython/Jupyter session.

    The graph is obtained with ``get_graph(xray=True)``. It is drawn as a
    Mermaid PNG using ``MermaidDrawMethod.API`` (presumably a remote rendering
    service, so network access is likely required) and shown via
    ``IPython.display``.

    Args:
        app: a compiled graph exposing ``get_graph``.
    """
    graph_view = app.get_graph(xray=True)
    png_bytes = graph_view.draw_mermaid_png(draw_method=MermaidDrawMethod.API)
    display(Image(png_bytes))
|
|