import os
import re

from typing import Annotated
from typing_extensions import TypedDict

from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_community.graphs import Neo4jGraph

from langgraph.graph import StateGraph
from langgraph.graph import add_messages
from langgraph.checkpoint.sqlite import SqliteSaver

from ki_gen.prompts import PLAN_GEN_PROMPT, PLAN_MODIFICATION_PROMPT
from ki_gen.data_retriever import build_data_retriever_graph
from ki_gen.data_processor import build_data_processor_graph
from ki_gen.utils import ConfigSchema, State, HumanValidationState, DocProcessorState, DocRetrieverState
##########################################################################
######                      NODES DEFINITION                        ######
##########################################################################
def validate_node(state: State):
    """
    This node inserts the plan validation prompt. The "My plan is correct"
    sentinel it asks for is what validate_plan() matches on below.
    """
    prompt = """System : You only need to focus on Key Issues; no need to focus on solutions or stakeholders yet, and your plan should be concise.
If needed, give me an updated plan that follows this instruction. If your plan already follows the instruction, just say "My plan is correct"."""
    output = HumanMessage(content=prompt)
    return {"messages": [output]}
def error_chatbot_groq(error, model_name, query):  # Pass model_name instead of an llm_groq object
    """
    Fallback for Groq errors: rotate GROQ_API_KEY over the three configured
    keys, re-initialize the model, and retry the original query once.
    """
    print(f"Groq call failed ({error}), rotating API key and retrying...")
    # Round-robin over the three configured keys
    if os.environ["GROQ_API_KEY"] == os.getenv("groq_api_key"):
        os.environ["GROQ_API_KEY"] = os.getenv("groq_api_key2")
    elif os.environ["GROQ_API_KEY"] == os.getenv("groq_api_key2"):
        os.environ["GROQ_API_KEY"] = os.getenv("groq_api_key3")
    else:
        os.environ["GROQ_API_KEY"] = os.getenv("groq_api_key")

    # Re-initialize the model *after* switching the key
    try:
        llm_groq_retry = ChatGroq(model=model_name)
        # Pass the original query messages
        return {"messages": [llm_groq_retry.invoke(query)]}
    except Exception as retry_error:
        # The retry itself failed: surface the error in the conversation
        print(f"Error during retry: {retry_error}")
        return {"messages": [SystemMessage(content=f"Failed to process after retry: {retry_error}")]}
# Wrappers to call LLMs on the state's messages field
def chatbot_llama(state: State):
    model_name = "llama3-70b-8192"
    try:
        llm_llama = ChatGroq(model=model_name)
        return {"messages": [llm_llama.invoke(state["messages"])]}
    except Exception as error:
        # Pass the model name (not the llm object) and return the retry result
        return error_chatbot_groq(error, model_name, state["messages"])
def chatbot_mixtral(state: State):
    # Name kept for compatibility: the underlying model is now deepseek-r1-distill-llama-70b
    model_name = "deepseek-r1-distill-llama-70b"
    try:
        llm_mixtral = ChatGroq(model=model_name)
        return {"messages": [llm_mixtral.invoke(state["messages"])]}
    except Exception as error:
        return error_chatbot_groq(error, model_name, state["messages"])
def chatbot_openai(state: State):
    llm_openai = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
    return {"messages": [llm_openai.invoke(state["messages"])]}


# Registry mapping the configurable "main_llm" name to its node function
chatbots = {
    "gpt-4o": chatbot_openai,
    "deepseek-r1-distill-llama-70b": chatbot_mixtral,
    "llama3-70b-8192": chatbot_llama,
}
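
# Illustrative note: each chatbot node takes the graph state and returns a
# partial state update, e.g. (hypothetical input)
#   chatbot_llama({"messages": [HumanMessage(content="...")]})
#   -> {"messages": [AIMessage(...)]}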
def parse_plan(state: State):
    """
    This node parses the generated plan and writes it to the 'store_plan' field of the state.
    """
    plan = state["messages"][-3].content
    # Split on the numbered items ("1.", "2.", ...) that follow the "Plan:" header
    store_plan = re.split(r"\d\.", plan.split("Plan:\n")[1])[1:]
    try:
        store_plan[-1] = store_plan[-1].split("<END_OF_PLAN>")[0]
    except Exception as e:
        print(f"Error while removing <END_OF_PLAN>: {e}")
    return {"store_plan": store_plan}
def detail_step(state: State, config: ConfigSchema):
    """
    This node updates the value of the 'current_plan_step' field and defines the query to be used by the data_retriever.
    """
    # We just began a new step, so increment current_plan_step (or start at 0)
    previous_step = state.get("current_plan_step")
    current_plan_step = previous_step + 1 if previous_step is not None else 0

    if config["configurable"].get("use_detailed_query"):
        prompt = HumanMessage(f"""Specify what additional information you need to proceed with the next step of your plan :
Step {current_plan_step + 1} : {state['store_plan'][current_plan_step]}""")
        query = get_detailed_query(context=state["messages"] + [prompt], model=config["configurable"].get("main_llm"))
        # Store the query's text so 'query' is a string in both branches, and reset valid_docs as below
        return {"messages": [prompt, query], "current_plan_step": current_plan_step, "query": query.content, "valid_docs": []}
    return {"current_plan_step": current_plan_step, "query": state["store_plan"][current_plan_step], "valid_docs": []}
def get_detailed_query(context: list, model: str = "deepseek-r1-distill-llama-70b"):
    """
    Simple helper function for the detail_step node
    """
    if model == 'gpt-4o':
        llm = ChatOpenAI(model=model, base_url="https://llm.synapse.thalescloud.io/")
    else:
        llm = ChatGroq(model=model)
    return llm.invoke(context)
def concatenate_data(state: State):
    """
    This node concatenates all the data that was processed by the data_processor and inserts it in the state's messages.
    """
    prompt = f"""#########TECHNICAL INFORMATION ############
{str(state["valid_docs"])}
########END OF TECHNICAL INFORMATION#######

Using the information provided above, proceed with step {state['current_plan_step'] + 1} of your plan :
{state['store_plan'][state['current_plan_step']]}
"""
    return {"messages": [HumanMessage(content=prompt)]}
def human_validation(state: HumanValidationState) -> HumanValidationState:
    """
    Dummy node serving as the human-validation breakpoint: the graph pauses
    around it so 'human_validated' can be set from outside.
    """
    return {'process_steps': []}
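
# Assumption: while the graph is paused at this breakpoint, the caller flips
# the flag with the standard LangGraph state-update API, e.g.:
#   graph.update_state(thread_config, {"human_validated": True})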
def generate_ki(state: State):
    """
    This node inserts the prompt to begin Key Issues generation
    """
    prompt = f"""Using the information provided above, proceed with step {state['current_plan_step'] + 2} of your plan to provide the user with NEW and INNOVATIVE Key Issues :
{state['store_plan'][state['current_plan_step'] + 1]}"""
    return {"messages": [HumanMessage(content=prompt)]}


def detail_ki(state: State):
    """
    This node inserts the last prompt to detail the generated Key Issues
    """
    prompt = f"""Using the information provided above, proceed with step {state['current_plan_step'] + 3} of your plan to provide the user with NEW and INNOVATIVE Key Issues :
{state['store_plan'][state['current_plan_step'] + 2]}"""
    return {"messages": [HumanMessage(content=prompt)]}
##########################################################################
######                 CONDITIONAL EDGE FUNCTIONS                   ######
##########################################################################
def validate_plan(state: State):
    """
    Whether to regenerate the plan or to parse it
    """
    if "messages" in state and "My plan is correct" in state["messages"][-1].content:
        return "parse"
    return "validate"
def next_plan_step(state: State, config: ConfigSchema):
    """
    Proceed to next plan step (either generate KI or retrieve more data)
    """
    if (state["current_plan_step"] == 2) and (config["configurable"].get('plan_method') == "modification"):
        return "generate_key_issues"
    if state["current_plan_step"] == len(state["store_plan"]) - 1:
        return "generate_key_issues"
    return "detail_step"
def detail_or_data_retriever(state: State, config: ConfigSchema):
    """
    Whether to detail the query to use for data retrieval or not
    """
    if config["configurable"].get("use_detailed_query"):
        return "chatbot_detail"
    return "data_retriever"
def retrieve_or_process(state: State):
    """
    Process the retrieved docs or keep retrieving
    """
    if state['human_validated']:
        return "process"
    return "retrieve"
    # Interactive CLI fallback, kept for reference:
    # while True:
    #     user_input = input(f"{len(state['valid_docs'])} were retrieved. Do you want more documents (y/[n]) : ")
    #     if user_input.lower() == "y":
    #         return "retrieve"
    #     if not user_input or user_input.lower() == "n":
    #         return "process"
    #     print("Please answer with 'y' or 'n'.\n")
def build_planner_graph(memory, config):
    """
    Builds the planner graph
    """
    graph_builder = StateGraph(State)

    graph_doc_retriever = build_data_retriever_graph(memory)
    graph_doc_processor = build_data_processor_graph(memory)

    graph_builder.add_node("chatbot_planner", chatbots[config["main_llm"]])
    graph_builder.add_node("validate", validate_node)
    graph_builder.add_node("chatbot_detail", chatbot_llama)
    graph_builder.add_node("parse", parse_plan)
    graph_builder.add_node("detail_step", detail_step)
    graph_builder.add_node("data_retriever", graph_doc_retriever, input=DocRetrieverState)
    graph_builder.add_node("human_validation", human_validation)
    graph_builder.add_node("data_processor", graph_doc_processor, input=DocProcessorState)
    graph_builder.add_node("concatenate_data", concatenate_data)
    graph_builder.add_node("chatbot_exec_step", chatbots[config["main_llm"]])
    graph_builder.add_node("generate_ki", generate_ki)
    graph_builder.add_node("chatbot_ki", chatbots[config["main_llm"]])
    graph_builder.add_node("detail_ki", detail_ki)
    graph_builder.add_node("chatbot_final", chatbots[config["main_llm"]])

    graph_builder.add_edge("validate", "chatbot_planner")
    graph_builder.add_edge("parse", "detail_step")
    # graph_builder.add_edge("detail_step", "chatbot2")
    graph_builder.add_edge("chatbot_detail", "data_retriever")
    graph_builder.add_edge("data_retriever", "human_validation")
    graph_builder.add_edge("data_processor", "concatenate_data")
    graph_builder.add_edge("concatenate_data", "chatbot_exec_step")
    graph_builder.add_edge("generate_ki", "chatbot_ki")
    graph_builder.add_edge("chatbot_ki", "detail_ki")
    graph_builder.add_edge("detail_ki", "chatbot_final")
    graph_builder.add_edge("chatbot_final", "__end__")

    graph_builder.add_conditional_edges(
        "detail_step",
        detail_or_data_retriever,
        {"chatbot_detail": "chatbot_detail", "data_retriever": "data_retriever"},
    )
    graph_builder.add_conditional_edges(
        "human_validation",
        retrieve_or_process,
        {"retrieve": "data_retriever", "process": "data_processor"},
    )
    graph_builder.add_conditional_edges(
        "chatbot_planner",
        validate_plan,
        {"parse": "parse", "validate": "validate"},
    )
    graph_builder.add_conditional_edges(
        "chatbot_exec_step",
        next_plan_step,
        {"generate_key_issues": "generate_ki", "detail_step": "detail_step"},
    )

    graph_builder.set_entry_point("chatbot_planner")
    graph = graph_builder.compile(
        checkpointer=memory,
        interrupt_after=["parse", "chatbot_exec_step", "chatbot_final", "data_retriever"],
    )
    return graph
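

# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption, not part of the original pipeline):
# one plausible way to compile and drive the planner graph. "main_llm",
# "plan_method" and "use_detailed_query" are the config keys read above;
# "thread_id" is the standard LangGraph checkpointing key, and the user
# message is purely illustrative. Depending on the installed
# langgraph-checkpoint-sqlite version, from_conn_string may return the saver
# directly rather than acting as a context manager.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    with SqliteSaver.from_conn_string(":memory:") as memory:
        config = {
            "configurable": {
                "thread_id": "1",
                "main_llm": "llama3-70b-8192",
                "plan_method": "modification",
                "use_detailed_query": False,
            }
        }
        graph = build_planner_graph(memory, config["configurable"])
        # Stream until the first breakpoint (set via interrupt_after above)
        for event in graph.stream(
            {"messages": [HumanMessage(content="Generate Key Issues on 5G energy efficiency")]},
            config,
        ):
            print(event)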