|
import re

from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage

from langgraph.graph import StateGraph

from ki_gen.data_retriever import build_data_retriever_graph
from ki_gen.data_processor import build_data_processor_graph
from ki_gen.utils import ConfigSchema, State, HumanValidationState, DocProcessorState, DocRetrieverState
|
|
def validate_node(state: State):
    """
    This node inserts the plan validation prompt.
    """
    prompt = """System : You only need to focus on Key Issues, no need to focus on solutions or stakeholders yet and your plan should be concise.
If needed, give me an updated plan to follow this instruction. If your plan already follows the instruction just say "My plan is correct"."""
    output = HumanMessage(content=prompt)
    return {"messages": [output]}
|
|
def chatbot_llama(state: State):
    llm_llama = ChatGroq(model="llama3-70b-8192")
    return {"messages": [llm_llama.invoke(state["messages"])]}


def chatbot_mixtral(state: State):
    llm_mixtral = ChatGroq(model="mixtral-8x7b-32768")
    return {"messages": [llm_mixtral.invoke(state["messages"])]}


def chatbot_openai(state: State):
    llm_openai = ChatOpenAI(model="gpt-4o", base_url="https://llm.synapse.thalescloud.io/")
    return {"messages": [llm_openai.invoke(state["messages"])]}


chatbots = {
    "gpt-4o": chatbot_openai,
    "mixtral-8x7b-32768": chatbot_mixtral,
    "llama3-70b-8192": chatbot_llama,
}
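
# Each chatbot node above follows the usual LangGraph convention: it takes the
# current state and returns a partial update, whose "messages" list is merged
# back into the state (presumably via an add_messages reducer on the State
# schema defined in ki_gen.utils).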
|
|
|
|
|
def parse_plan(state: State):
    """
    This node parses the generated plan and writes it to the 'store_plan' field of the state.
    """
    # The validated plan sits three messages back (plan, validation prompt, validation answer)
    plan = state["messages"][-3].content
    store_plan = re.split(r"\d+\.", plan.split("Plan:\n")[1])[1:]
    try:
        # Strip the end-of-plan marker from the last step
        store_plan[-1] = store_plan[-1].split("<END_OF_PLAN>")[0]
    except Exception as e:
        print(f"Error while removing <END_OF_PLAN>: {e}")

    return {"store_plan": store_plan}
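
# For reference, parse_plan assumes planner output shaped like the following
# (illustrative example, not taken from a real run):
#
#   Plan:
#   1. Identify the technical domain and gather context
#   2. Retrieve relevant documents from the knowledge graph
#   3. Synthesize the findings into Key Issues
#   <END_OF_PLAN>
#
# which would yield a store_plan list with one entry per numbered step.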
|
|
|
def detail_step(state: State, config: ConfigSchema):
    """
    This node updates the 'current_plan_step' field and defines the query to be used by the data_retriever.
    """
    current_plan_step = state.get("current_plan_step")
    current_plan_step = 0 if current_plan_step is None else current_plan_step + 1

    if config["configurable"].get("use_detailed_query"):
        prompt = HumanMessage(f"""Specify what additional information you need to proceed with the next step of your plan :
Step {current_plan_step + 1} : {state['store_plan'][current_plan_step]}""")
        query = get_detailed_query(context=state["messages"] + [prompt], model=config["configurable"].get("main_llm"))
        return {"messages": [prompt, query], "current_plan_step": current_plan_step, "query": query}

    return {"current_plan_step": current_plan_step, "query": state["store_plan"][current_plan_step], "valid_docs": []}
|
|
def get_detailed_query(context: list, model: str = "mixtral-8x7b-32768"):
    """
    Simple helper function for the detail_step node.
    """
    if model == "gpt-4o":
        llm = ChatOpenAI(model=model, base_url="https://llm.synapse.thalescloud.io/")
    else:
        llm = ChatGroq(model=model)
    return llm.invoke(context)
|
|
def concatenate_data(state: State):
    """
    This node concatenates all the data processed by the data_processor and inserts it into the state's messages.
    """
    prompt = f"""#########TECHNICAL INFORMATION ############
{str(state["valid_docs"])}

########END OF TECHNICAL INFORMATION#######

Using the information provided above, proceed with step {state['current_plan_step'] + 1} of your plan :
{state['store_plan'][state['current_plan_step']]}
"""

    return {"messages": [HumanMessage(content=prompt)]}
|
|
def human_validation(state: HumanValidationState) -> HumanValidationState:
    """
    Dummy node; the graph interrupts right before it so a human can validate the retrieved docs.
    """
    return {"process_steps": []}
|
|
def generate_ki(state: State):
    """
    This node inserts the prompt to begin Key Issues generation.
    """
    # The KI generation step is assumed to be the one right after the current plan step
    prompt = f"""Using the information provided above, proceed with step 4 of your plan to provide the user with NEW and INNOVATIVE Key Issues :
{state['store_plan'][state['current_plan_step'] + 1]}"""

    return {"messages": [HumanMessage(content=prompt)]}
|
|
def detail_ki(state: State):
    """
    This node inserts the last prompt, asking to detail the generated Key Issues.
    """
    prompt = f"""Using the information provided above, proceed with step 5 of your plan to provide the user with NEW and INNOVATIVE Key Issues :
{state['store_plan'][state['current_plan_step'] + 2]}"""

    return {"messages": [HumanMessage(content=prompt)]}
|
|
def validate_plan(state: State):
    """
    Routing function: whether to regenerate the plan or to parse it.
    """
    if "messages" in state and state["messages"][-1].content.strip() in ("My plan is correct.", "My plan is correct"):
        return "parse"
    return "validate"
|
|
def next_plan_step(state: State, config: ConfigSchema):
    """
    Routing function: proceed to the next plan step (either generate KI or retrieve more data).
    """
    if (state["current_plan_step"] == 2) and (config["configurable"].get("plan_method") == "modification"):
        return "generate_key_issues"
    if state["current_plan_step"] == len(state["store_plan"]) - 1:
        return "generate_key_issues"
    return "detail_step"
|
|
def detail_or_data_retriever(state: State, config: ConfigSchema):
    """
    Routing function: whether to have an LLM detail the query before running the data_retriever.
    """
    if config["configurable"].get("use_detailed_query"):
        return "chatbot_detail"
    return "data_retriever"
|
|
def retrieve_or_process(state: State):
    """
    Routing function: process the retrieved docs once a human has validated them, otherwise keep retrieving.
    """
    if state["human_validated"]:
        return "process"
    return "retrieve"
|
|
def build_planner_graph(memory, config):
    """
    Builds the planner graph.
    """
    graph_builder = StateGraph(State)

    # Subgraphs for document retrieval and processing
    graph_doc_retriever = build_data_retriever_graph(memory)
    graph_doc_processor = build_data_processor_graph(memory)

    graph_builder.add_node("chatbot_planner", chatbots[config["main_llm"]])
    graph_builder.add_node("validate", validate_node)
    graph_builder.add_node("chatbot_detail", chatbot_llama)
    graph_builder.add_node("parse", parse_plan)
    graph_builder.add_node("detail_step", detail_step)
    graph_builder.add_node("data_retriever", graph_doc_retriever, input=DocRetrieverState)
    graph_builder.add_node("human_validation", human_validation)
    graph_builder.add_node("data_processor", graph_doc_processor, input=DocProcessorState)
    graph_builder.add_node("concatenate_data", concatenate_data)
    graph_builder.add_node("chatbot_exec_step", chatbots[config["main_llm"]])
    graph_builder.add_node("generate_ki", generate_ki)
    graph_builder.add_node("chatbot_ki", chatbots[config["main_llm"]])
    graph_builder.add_node("detail_ki", detail_ki)
    graph_builder.add_node("chatbot_final", chatbots[config["main_llm"]])

    # Fixed edges
    graph_builder.add_edge("validate", "chatbot_planner")
    graph_builder.add_edge("parse", "detail_step")
    graph_builder.add_edge("chatbot_detail", "data_retriever")
    graph_builder.add_edge("data_retriever", "human_validation")
    graph_builder.add_edge("data_processor", "concatenate_data")
    graph_builder.add_edge("concatenate_data", "chatbot_exec_step")
    graph_builder.add_edge("generate_ki", "chatbot_ki")
    graph_builder.add_edge("chatbot_ki", "detail_ki")
    graph_builder.add_edge("detail_ki", "chatbot_final")
    graph_builder.add_edge("chatbot_final", "__end__")

    # Conditional routing
    graph_builder.add_conditional_edges(
        "detail_step",
        detail_or_data_retriever,
        {"chatbot_detail": "chatbot_detail", "data_retriever": "data_retriever"},
    )
    graph_builder.add_conditional_edges(
        "human_validation",
        retrieve_or_process,
        {"retrieve": "data_retriever", "process": "data_processor"},
    )
    graph_builder.add_conditional_edges(
        "chatbot_planner",
        validate_plan,
        {"parse": "parse", "validate": "validate"},
    )
    graph_builder.add_conditional_edges(
        "chatbot_exec_step",
        next_plan_step,
        {"generate_key_issues": "generate_ki", "detail_step": "detail_step"},
    )

    graph_builder.set_entry_point("chatbot_planner")
    graph = graph_builder.compile(
        checkpointer=memory,
        interrupt_after=["parse", "chatbot_exec_step", "chatbot_final", "data_retriever"],
    )
    return graph
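

# Minimal usage sketch (illustrative only). The config keys mirror the ones read
# above ("main_llm", "plan_method", "use_detailed_query"); the thread_id, model
# choice and topic are placeholders, and the first message is assumed to carry
# the plan-generation prompt (e.g. PLAN_GEN_PROMPT from ki_gen.prompts):
#
#     from langgraph.checkpoint.memory import MemorySaver
#
#     memory = MemorySaver()
#     config = {
#         "main_llm": "llama3-70b-8192",
#         "configurable": {
#             "thread_id": "1",
#             "main_llm": "llama3-70b-8192",
#             "plan_method": "modification",
#             "use_detailed_query": False,
#         },
#     }
#     graph = build_planner_graph(memory, config)
#     for event in graph.stream(
#         {"messages": [HumanMessage(content="<topic to generate Key Issues for>")]},
#         config,
#         stream_mode="values",
#     ):
#         event["messages"][-1].pretty_print()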