Spaces:
Sleeping
Sleeping
Upload 5 files
Browse files- ki_gen/data_processor.py +185 -0
- ki_gen/data_retriever.py +406 -0
- ki_gen/planner.py +282 -0
- ki_gen/prompts.py +155 -0
- ki_gen/utils.py +152 -0
ki_gen/data_processor.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding: utf-8
|
3 |
+
|
4 |
+
from langchain_openai import ChatOpenAI
|
5 |
+
from langchain_core.output_parsers import StrOutputParser
|
6 |
+
from langchain_core.prompts import ChatPromptTemplate
|
7 |
+
from langchain_groq import ChatGroq
|
8 |
+
from langgraph.graph import StateGraph
|
9 |
+
from llmlingua import PromptCompressor
|
10 |
+
|
11 |
+
from ki_gen.utils import ConfigSchema, DocProcessorState, get_model, format_doc
|
12 |
+
from langgraph.checkpoint.sqlite import SqliteSaver
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
# compressed_prompt = llm_lingua.compress_prompt(prompt, instruction="", question="", target_token=200)
|
18 |
+
|
19 |
+
## Or use the quantation model, like TheBloke/Llama-2-7b-Chat-GPTQ, only need <8GB GPU memory.
|
20 |
+
## Before that, you need to pip install optimum auto-gptq
|
21 |
+
# llm_lingua = PromptCompressor("TheBloke/Llama-2-7b-Chat-GPTQ", model_config={"revision": "main"})
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
# Requires ~2GB of RAM
|
26 |
+
def get_llm_lingua(compress_method:str = "llm_lingua2"):
|
27 |
+
|
28 |
+
# Requires ~2GB memory
|
29 |
+
if compress_method == "llm_lingua2":
|
30 |
+
llm_lingua2 = PromptCompressor(
|
31 |
+
model_name="microsoft/llmlingua-2-xlm-roberta-large-meetingbank",
|
32 |
+
use_llmlingua2=True,
|
33 |
+
device_map="cpu"
|
34 |
+
)
|
35 |
+
return llm_lingua2
|
36 |
+
|
37 |
+
# Requires ~8GB memory
|
38 |
+
elif compress_method == "llm_lingua":
|
39 |
+
llm_lingua = PromptCompressor(
|
40 |
+
model_name="microsoft/phi-2",
|
41 |
+
device_map="cpu"
|
42 |
+
)
|
43 |
+
return llm_lingua
|
44 |
+
raise ValueError("Incorrect compression method, should be 'llm_lingua' or 'llm_lingua2'")
|
45 |
+
|
46 |
+
|
47 |
+
|
48 |
+
def compress(state: DocProcessorState, config: ConfigSchema):
|
49 |
+
"""
|
50 |
+
This node compresses last processing result for each doc using llm_lingua
|
51 |
+
"""
|
52 |
+
doc_process_histories = state["docs_in_processing"]
|
53 |
+
llm_lingua = get_llm_lingua(config["configurable"].get("compression_method") or "llm_lingua2")
|
54 |
+
for doc_process_history in doc_process_histories:
|
55 |
+
doc_process_history.append(llm_lingua.compress_prompt(
|
56 |
+
doc = str(doc_process_history[-1]),
|
57 |
+
rate=config["configurable"].get("compress_rate") or 0.33,
|
58 |
+
force_tokens=config["configurable"].get("force_tokens") or ['\n', '?', '.', '!', ',']
|
59 |
+
)["compressed_prompt"]
|
60 |
+
)
|
61 |
+
|
62 |
+
return {"docs_in_processing": doc_process_histories, "current_process_step" : state["current_process_step"] + 1}
|
63 |
+
|
64 |
+
def summarize_docs(state: DocProcessorState, config: ConfigSchema):
|
65 |
+
"""
|
66 |
+
This node summarizes all docs in state["valid_docs"]
|
67 |
+
"""
|
68 |
+
|
69 |
+
prompt = """You are a 3GPP standardization expert.
|
70 |
+
Summarize the provided document in simple technical English for other experts in the field.
|
71 |
+
|
72 |
+
Document:
|
73 |
+
{document}"""
|
74 |
+
sysmsg = ChatPromptTemplate.from_messages([
|
75 |
+
("system", prompt)
|
76 |
+
])
|
77 |
+
model = config["configurable"].get("summarize_model") or "deepseek-r1-distill-llama-70b"
|
78 |
+
doc_process_histories = state["docs_in_processing"]
|
79 |
+
if model == "gpt-4o":
|
80 |
+
llm_summarize = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
|
81 |
+
else:
|
82 |
+
llm_summarize = ChatGroq(model=model)
|
83 |
+
summarize_chain = sysmsg | llm_summarize | StrOutputParser()
|
84 |
+
|
85 |
+
for doc_process_history in doc_process_histories:
|
86 |
+
doc_process_history.append(summarize_chain.invoke({"document" : str(doc_process_history[-1])}))
|
87 |
+
|
88 |
+
return {"docs_in_processing": doc_process_histories, "current_process_step": state["current_process_step"] + 1}
|
89 |
+
|
90 |
+
def custom_process(state: DocProcessorState):
|
91 |
+
"""
|
92 |
+
Custom processing step, params are stored in a dict in state["process_steps"][state["current_process_step"]]
|
93 |
+
processing_model : the LLM which will perform the processing
|
94 |
+
context : the previous processing results to send as context to the LLM
|
95 |
+
user_prompt : the prompt/task which will be appended to the context before sending to the LLM
|
96 |
+
"""
|
97 |
+
|
98 |
+
processing_params = state["process_steps"][state["current_process_step"]]
|
99 |
+
model = processing_params.get("processing_model") or "deepseek-r1-distill-llama-70b"
|
100 |
+
user_prompt = processing_params["prompt"]
|
101 |
+
context = processing_params.get("context") or [0]
|
102 |
+
doc_process_histories = state["docs_in_processing"]
|
103 |
+
if not isinstance(context, list):
|
104 |
+
context = [context]
|
105 |
+
|
106 |
+
processing_chain = get_model(model=model) | StrOutputParser()
|
107 |
+
|
108 |
+
for doc_process_history in doc_process_histories:
|
109 |
+
context_str = ""
|
110 |
+
for i, context_element in enumerate(context):
|
111 |
+
context_str += f"### TECHNICAL INFORMATION {i+1} \n {doc_process_history[context_element]}\n\n"
|
112 |
+
doc_process_history.append(processing_chain.invoke(context_str + user_prompt))
|
113 |
+
|
114 |
+
return {"docs_in_processing" : doc_process_histories, "current_process_step" : state["current_process_step"] + 1}
|
115 |
+
|
116 |
+
def final(state: DocProcessorState):
|
117 |
+
"""
|
118 |
+
A node to store the final results of processing in the 'valid_docs' field
|
119 |
+
"""
|
120 |
+
return {"valid_docs" : [doc_process_history[-1] for doc_process_history in state["docs_in_processing"]]}
|
121 |
+
|
122 |
+
# TODO : remove this node and use conditional entry point instead
|
123 |
+
def get_process_steps(state: DocProcessorState, config: ConfigSchema):
|
124 |
+
"""
|
125 |
+
Dummy node
|
126 |
+
"""
|
127 |
+
# if not process_steps:
|
128 |
+
# process_steps = eval(input("Enter processing steps: "))
|
129 |
+
return {"current_process_step": 0, "docs_in_processing" : [[format_doc(doc)] for doc in state["valid_docs"]]}
|
130 |
+
|
131 |
+
|
132 |
+
def next_processor_step(state: DocProcessorState):
|
133 |
+
"""
|
134 |
+
Conditional edge function to go to next processing step
|
135 |
+
"""
|
136 |
+
process_steps = state["process_steps"]
|
137 |
+
if state["current_process_step"] < len(process_steps):
|
138 |
+
step = process_steps[state["current_process_step"]]
|
139 |
+
if isinstance(step, dict):
|
140 |
+
step = "custom"
|
141 |
+
else:
|
142 |
+
step = "final"
|
143 |
+
|
144 |
+
return step
|
145 |
+
|
146 |
+
|
147 |
+
def build_data_processor_graph(memory):
|
148 |
+
"""
|
149 |
+
Builds the data processor graph
|
150 |
+
"""
|
151 |
+
#with SqliteSaver.from_conn_string(":memory:") as memory :
|
152 |
+
|
153 |
+
graph_builder_doc_processor = StateGraph(DocProcessorState)
|
154 |
+
|
155 |
+
graph_builder_doc_processor.add_node("get_process_steps", get_process_steps)
|
156 |
+
graph_builder_doc_processor.add_node("summarize", summarize_docs)
|
157 |
+
graph_builder_doc_processor.add_node("compress", compress)
|
158 |
+
graph_builder_doc_processor.add_node("custom", custom_process)
|
159 |
+
graph_builder_doc_processor.add_node("final", final)
|
160 |
+
|
161 |
+
graph_builder_doc_processor.add_edge("__start__", "get_process_steps")
|
162 |
+
graph_builder_doc_processor.add_conditional_edges(
|
163 |
+
"get_process_steps",
|
164 |
+
next_processor_step,
|
165 |
+
{"compress" : "compress", "final": "final", "summarize": "summarize", "custom" : "custom"}
|
166 |
+
)
|
167 |
+
graph_builder_doc_processor.add_conditional_edges(
|
168 |
+
"summarize",
|
169 |
+
next_processor_step,
|
170 |
+
{"compress" : "compress", "final": "final", "custom" : "custom"}
|
171 |
+
)
|
172 |
+
graph_builder_doc_processor.add_conditional_edges(
|
173 |
+
"compress",
|
174 |
+
next_processor_step,
|
175 |
+
{"summarize" : "summarize", "final": "final", "custom" : "custom"}
|
176 |
+
)
|
177 |
+
graph_builder_doc_processor.add_conditional_edges(
|
178 |
+
"custom",
|
179 |
+
next_processor_step,
|
180 |
+
{"summarize" : "summarize", "final": "final", "compress" : "compress", "custom" : "custom"}
|
181 |
+
)
|
182 |
+
graph_builder_doc_processor.add_edge("final", "__end__")
|
183 |
+
|
184 |
+
graph_doc_processor = graph_builder_doc_processor.compile(checkpointer=memory)
|
185 |
+
return graph_doc_processor
|
ki_gen/data_retriever.py
ADDED
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding: utf-8
|
3 |
+
|
4 |
+
import re
|
5 |
+
import time
|
6 |
+
from random import shuffle, sample
|
7 |
+
from langgraph.checkpoint.sqlite import SqliteSaver
|
8 |
+
|
9 |
+
from langchain_groq import ChatGroq
|
10 |
+
from langchain_openai import ChatOpenAI
|
11 |
+
from langchain_core.messages import HumanMessage
|
12 |
+
from langchain_community.graphs import Neo4jGraph
|
13 |
+
from langchain_community.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema
|
14 |
+
from langchain_core.output_parsers import StrOutputParser
|
15 |
+
from langchain_core.prompts import ChatPromptTemplate
|
16 |
+
from langchain_core.pydantic_v1 import Field
|
17 |
+
from pydantic import BaseModel
|
18 |
+
from langchain_groq import ChatGroq
|
19 |
+
|
20 |
+
from langgraph.graph import StateGraph
|
21 |
+
|
22 |
+
from llmlingua import PromptCompressor
|
23 |
+
|
24 |
+
from ki_gen.prompts import (
|
25 |
+
CYPHER_GENERATION_PROMPT,
|
26 |
+
CONCEPT_SELECTION_PROMPT,
|
27 |
+
BINARY_GRADER_PROMPT,
|
28 |
+
SCORE_GRADER_PROMPT,
|
29 |
+
RELEVANT_CONCEPTS_PROMPT,
|
30 |
+
)
|
31 |
+
from ki_gen.utils import ConfigSchema, DocRetrieverState, get_model, format_doc
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
def extract_cypher(text: str) -> str:
|
37 |
+
"""Extract Cypher code from a text.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
text: Text to extract Cypher code from.
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
Cypher code extracted from the text.
|
44 |
+
"""
|
45 |
+
# The pattern to find Cypher code enclosed in triple backticks
|
46 |
+
pattern_1 = r"```cypher\n(.*?)```"
|
47 |
+
pattern_2 = r"```\n(.*?)```"
|
48 |
+
|
49 |
+
# Find all matches in the input text
|
50 |
+
matches_1 = re.findall(pattern_1, text, re.DOTALL)
|
51 |
+
matches_2 = re.findall(pattern_2, text, re.DOTALL)
|
52 |
+
return [
|
53 |
+
matches_1[0] if matches_1 else text,
|
54 |
+
matches_2[0] if matches_2 else text,
|
55 |
+
text
|
56 |
+
]
|
57 |
+
|
58 |
+
def get_cypher_gen_chain(model: str = "deepseek-r1-distill-llama-70b"):
|
59 |
+
"""
|
60 |
+
Returns cypher gen chain using specified model for generation
|
61 |
+
This is used when the 'auto' cypher generation method has been configured
|
62 |
+
"""
|
63 |
+
|
64 |
+
if model=="openai":
|
65 |
+
llm_cypher_gen = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
|
66 |
+
else:
|
67 |
+
llm_cypher_gen = ChatGroq(model = "deepseek-r1-distill-llama-70b")
|
68 |
+
cypher_gen_chain = CYPHER_GENERATION_PROMPT | llm_cypher_gen | StrOutputParser() | extract_cypher
|
69 |
+
return cypher_gen_chain
|
70 |
+
|
71 |
+
def get_concept_selection_chain(model: str = "deepseek-r1-distill-llama-70b"):
|
72 |
+
"""
|
73 |
+
Returns a chain to select the most relevant topic using specified model for generation.
|
74 |
+
This is used when the 'guided' cypher generation method has been configured
|
75 |
+
"""
|
76 |
+
|
77 |
+
if model == "openai":
|
78 |
+
llm_topic_selection = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
|
79 |
+
else:
|
80 |
+
llm_topic_selection = ChatGroq(model="deepseek-r1-distill-llama-70b")
|
81 |
+
print(f"FOUND LLM TOPIC SELECTION FOR THE CONCEPT SELECTION PROMPT : {llm_topic_selection}")
|
82 |
+
topic_selection_chain = CONCEPT_SELECTION_PROMPT | llm_topic_selection | StrOutputParser()
|
83 |
+
return topic_selection_chain
|
84 |
+
|
85 |
+
def get_concepts(graph: Neo4jGraph):
|
86 |
+
concept_cypher = "MATCH (c:Concept) return c"
|
87 |
+
if isinstance(graph, Neo4jGraph):
|
88 |
+
concepts = graph.query(concept_cypher)
|
89 |
+
else:
|
90 |
+
user_input = input("Topics : ")
|
91 |
+
concepts = eval(user_input)
|
92 |
+
|
93 |
+
concepts_name = [concept['c']['name'] for concept in concepts]
|
94 |
+
return concepts_name
|
95 |
+
|
96 |
+
def get_related_concepts(graph: Neo4jGraph, question: str):
|
97 |
+
concepts = get_concepts(graph)
|
98 |
+
llm = get_model()
|
99 |
+
print(f"this is the llm variable : {llm}")
|
100 |
+
def parse_answer(llm_answer : str):
|
101 |
+
try:
|
102 |
+
print(f"This the llm_answer : {llm_answer}")
|
103 |
+
return re.split("\n(?:\d)+\.\s", llm_answer.split("Concepts:")[1])[1:]
|
104 |
+
except:
|
105 |
+
return "No concept"
|
106 |
+
related_concepts_chain = RELEVANT_CONCEPTS_PROMPT | llm | StrOutputParser() | parse_answer
|
107 |
+
|
108 |
+
print(f"This is the question of the user : {question}")
|
109 |
+
print(f"This is the concepts of the user : {concepts}")
|
110 |
+
|
111 |
+
|
112 |
+
#groq.APIStatusError: Error code: 413 - {'error': {'message': 'Request too large for model `deepseek-r1-distill-llama-70b` in organization `org_01j6xywkndffv96m3wgh81jm49` on tokens per minute
|
113 |
+
# (TPM): Limit 5000, Requested 17099, please reduce your message size and try again. Visit https://console.groq.com/docs/rate-limits for more information.',
|
114 |
+
# 'type': 'tokens', 'code': 'rate_limit_exceeded'}}
|
115 |
+
|
116 |
+
try:
|
117 |
+
related_concepts_raw = related_concepts_chain.invoke({"user_query" : question, "concepts" : '\n'.join(concepts)})
|
118 |
+
print(f"related_concepts_raw : {related_concepts_raw}")
|
119 |
+
except Exception as e:
|
120 |
+
if e.status_code == 413:
|
121 |
+
msg = e.body["error"]["message"]
|
122 |
+
print(f"question is : {question}")
|
123 |
+
print(type(question))
|
124 |
+
error_question = ["user_query", question]
|
125 |
+
related_concepts_raw = error_concept_groq(msg,concepts,related_concepts_chain,error_question)
|
126 |
+
pass
|
127 |
+
|
128 |
+
# We clean up the list we received from the LLM in case there were some hallucinations
|
129 |
+
related_concepts_cleaned = []
|
130 |
+
for related_concept in related_concepts_raw:
|
131 |
+
# If the concept returned from the LLM is in the list we keep it
|
132 |
+
if related_concept in concepts:
|
133 |
+
related_concepts_cleaned.append(related_concept)
|
134 |
+
else:
|
135 |
+
# The LLM sometimes only forgets a few words from the concept name
|
136 |
+
# We check if the generated concept is a substring of an existing one and if it is the case add it to the list
|
137 |
+
for concept in concepts:
|
138 |
+
if related_concept in concept:
|
139 |
+
related_concepts_cleaned.append(concept)
|
140 |
+
break
|
141 |
+
|
142 |
+
# TODO : Add concepts found via similarity search
|
143 |
+
return related_concepts_cleaned
|
144 |
+
|
145 |
+
def build_concept_string(graph: Neo4jGraph, concept_list: list[str]):
|
146 |
+
concept_string = ""
|
147 |
+
for concept in concept_list:
|
148 |
+
concept_description_query = f"""
|
149 |
+
MATCH (c:Concept {{name: "{concept}" }}) RETURN c.description
|
150 |
+
"""
|
151 |
+
concept_description = graph.query(concept_description_query)[0]['c.description']
|
152 |
+
concept_string += f"name: {concept}\ndescription: {concept_description}\n\n"
|
153 |
+
return concept_string
|
154 |
+
|
155 |
+
def get_global_concepts(graph: Neo4jGraph):
|
156 |
+
concept_cypher = "MATCH (gc:GlobalConcept) return gc"
|
157 |
+
if isinstance(graph, Neo4jGraph):
|
158 |
+
concepts = graph.query(concept_cypher)
|
159 |
+
else:
|
160 |
+
user_input = input("Topics : ")
|
161 |
+
concepts = eval(user_input)
|
162 |
+
|
163 |
+
concepts_name = [concept['gc']['name'] for concept in concepts]
|
164 |
+
return concepts_name
|
165 |
+
|
166 |
+
def generate_cypher(state: DocRetrieverState, config: ConfigSchema):
|
167 |
+
"""
|
168 |
+
The node where the cypher is generated
|
169 |
+
"""
|
170 |
+
graph = config["configurable"].get("graph")
|
171 |
+
question = state['query']
|
172 |
+
related_concepts = get_related_concepts(graph, question)
|
173 |
+
cyphers = []
|
174 |
+
|
175 |
+
if config["configurable"].get("cypher_gen_method") == 'auto':
|
176 |
+
cypher_gen_chain = get_cypher_gen_chain()
|
177 |
+
cyphers = cypher_gen_chain.invoke({
|
178 |
+
"schema": graph.schema,
|
179 |
+
"question": question,
|
180 |
+
"concepts": related_concepts
|
181 |
+
})
|
182 |
+
|
183 |
+
try :
|
184 |
+
|
185 |
+
if config["configurable"].get("cypher_gen_method") == 'guided':
|
186 |
+
concept_selection_chain = get_concept_selection_chain()
|
187 |
+
print(f"Concept selection chain is : {concept_selection_chain}")
|
188 |
+
selected_topic = concept_selection_chain.invoke({"question" : question, "concepts": get_concepts(graph)})
|
189 |
+
print(f"Selected topic are : {selected_topic}")
|
190 |
+
|
191 |
+
except Exception as e:
|
192 |
+
error_question = ["question", question]
|
193 |
+
selected_topic = error_concept_groq(e.body["error"]["message"],get_concepts(graph),concept_selection_chain,error_question)
|
194 |
+
pass
|
195 |
+
|
196 |
+
if config["configurable"].get("cypher_gen_method") == 'guided':
|
197 |
+
cyphers = [generate_cypher_from_topic(selected_topic, state['current_plan_step'])]
|
198 |
+
print(f"Cyphers are : {cyphers}")
|
199 |
+
|
200 |
+
if config["configurable"].get("validate_cypher"):
|
201 |
+
corrector_schema = [Schema(el["start"], el["type"], el["end"]) for el in graph.structured_schema.get("relationships")]
|
202 |
+
cypher_corrector = CypherQueryCorrector(corrector_schema)
|
203 |
+
cyphers = [cypher_corrector(cypher) for cypher in cyphers]
|
204 |
+
|
205 |
+
return {"cyphers" : cyphers}
|
206 |
+
|
207 |
+
def generate_cypher_from_topic(selected_concept: str, plan_step: int):
|
208 |
+
"""
|
209 |
+
Helper function used when the 'guided' cypher generation method has been configured
|
210 |
+
"""
|
211 |
+
|
212 |
+
print(f"L.176 PLAN STEP : {plan_step}")
|
213 |
+
cypher_el = "(n) return n.title, n.description"
|
214 |
+
match plan_step:
|
215 |
+
case 0:
|
216 |
+
cypher_el = "(ts:TechnicalSpecification) RETURN ts.title, ts.scope, ts.description"
|
217 |
+
case 1:
|
218 |
+
cypher_el = "(rp:ResearchPaper) RETURN rp.title, rp.abstract"
|
219 |
+
case 2:
|
220 |
+
cypher_el = "(ki:KeyIssue) RETURN ki.description"
|
221 |
+
return f"MATCH (c:Concept {{name:'{selected_concept}'}})-[:RELATED_TO]-{cypher_el}"
|
222 |
+
|
223 |
+
def get_docs(state:DocRetrieverState, config:ConfigSchema):
|
224 |
+
"""
|
225 |
+
This node retrieves docs from the graph using the generated cypher
|
226 |
+
"""
|
227 |
+
graph = config["configurable"].get("graph")
|
228 |
+
output = []
|
229 |
+
if graph is not None:
|
230 |
+
for cypher in state["cyphers"]:
|
231 |
+
try:
|
232 |
+
output = graph.query(cypher)
|
233 |
+
break
|
234 |
+
except Exception as e:
|
235 |
+
print("Failed to retrieve docs : {e}")
|
236 |
+
|
237 |
+
# Clean up the docs we received as there may be duplicates depending on the cypher query
|
238 |
+
all_docs = []
|
239 |
+
for doc in output:
|
240 |
+
unwinded_doc = {}
|
241 |
+
for key in doc:
|
242 |
+
if isinstance(doc[key], dict):
|
243 |
+
all_docs.append(doc[key])
|
244 |
+
else:
|
245 |
+
unwinded_doc.update({key: doc[key]})
|
246 |
+
if unwinded_doc:
|
247 |
+
all_docs.append(unwinded_doc)
|
248 |
+
|
249 |
+
|
250 |
+
filtered_docs = []
|
251 |
+
for doc in all_docs:
|
252 |
+
if doc not in filtered_docs:
|
253 |
+
filtered_docs.append(doc)
|
254 |
+
|
255 |
+
return {"docs": filtered_docs}
|
256 |
+
|
257 |
+
|
258 |
+
|
259 |
+
|
260 |
+
|
261 |
+
# Data model
|
262 |
+
class GradeDocumentsBinary(BaseModel):
|
263 |
+
"""Binary score for relevance check on retrieved documents."""
|
264 |
+
|
265 |
+
binary_score: str = Field(
|
266 |
+
description="Documents are relevant to the question, 'yes' or 'no'"
|
267 |
+
)
|
268 |
+
|
269 |
+
# LLM with function call
|
270 |
+
# llm_grader_binary = ChatGroq(model="deepseek-r1-distill-llama-70b", temperature=0)
|
271 |
+
|
272 |
+
def get_binary_grader(model="deepseek-r1-distill-llama-70b"):
|
273 |
+
"""
|
274 |
+
Returns a binary grader to evaluate relevance of documents using specified model for generation
|
275 |
+
This is used when the 'binary' evaluation method has been configured
|
276 |
+
"""
|
277 |
+
|
278 |
+
|
279 |
+
if model == "gpt-4o":
|
280 |
+
llm_grader_binary = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/", temperature=0)
|
281 |
+
else:
|
282 |
+
llm_grader_binary = ChatGroq(model="deepseek-r1-distill-llama-70b", temperature=0)
|
283 |
+
structured_llm_grader_binary = llm_grader_binary.with_structured_output(GradeDocumentsBinary)
|
284 |
+
retrieval_grader_binary = BINARY_GRADER_PROMPT | structured_llm_grader_binary
|
285 |
+
return retrieval_grader_binary
|
286 |
+
|
287 |
+
|
288 |
+
class GradeDocumentsScore(BaseModel):
|
289 |
+
"""Score for relevance check on retrieved documents."""
|
290 |
+
|
291 |
+
score: float = Field(
|
292 |
+
description="Documents are relevant to the question, score between 0 (completely irrelevant) and 1 (perfectly relevant)"
|
293 |
+
)
|
294 |
+
|
295 |
+
def get_score_grader(model="deepseek-r1-distill-llama-70b"):
|
296 |
+
"""
|
297 |
+
Returns a score grader to evaluate relevance of documents using specified model for generation
|
298 |
+
This is used when the 'score' evaluation method has been configured
|
299 |
+
"""
|
300 |
+
if model == "gpt-4o":
|
301 |
+
llm_grader_score = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/", temperature=0)
|
302 |
+
else:
|
303 |
+
llm_grader_score = ChatGroq(model="deepseek-r1-distill-llama-70b", temperature = 0)
|
304 |
+
structured_llm_grader_score = llm_grader_score.with_structured_output(GradeDocumentsScore)
|
305 |
+
retrieval_grader_score = SCORE_GRADER_PROMPT | structured_llm_grader_score
|
306 |
+
return retrieval_grader_score
|
307 |
+
|
308 |
+
|
309 |
+
def eval_doc(doc, query, method="binary", threshold=0.7, eval_model="deepseek-r1-distill-llama-70b"):
|
310 |
+
'''
|
311 |
+
doc : the document to evaluate
|
312 |
+
query : the query to which to doc shoud be relevant
|
313 |
+
method : "binary" or "score"
|
314 |
+
threshold : for "score" method, score above which a doc is considered relevant
|
315 |
+
'''
|
316 |
+
if method == "binary":
|
317 |
+
retrieval_grader_binary = get_binary_grader(model=eval_model)
|
318 |
+
return 1 if (retrieval_grader_binary.invoke({"question": query, "document":doc}).binary_score == 'yes') else 0
|
319 |
+
elif method == "score":
|
320 |
+
retrieval_grader_score = get_score_grader(model=eval_model)
|
321 |
+
score = retrieval_grader_score.invoke({"query": query, "document":doc}).score or None
|
322 |
+
if score is not None:
|
323 |
+
return score if score >= threshold else 0
|
324 |
+
else:
|
325 |
+
# Couldn't parse score, marking document as relevant by default
|
326 |
+
return 1
|
327 |
+
else:
|
328 |
+
raise ValueError("Invalid method")
|
329 |
+
|
330 |
+
def eval_docs(state: DocRetrieverState, config: ConfigSchema):
|
331 |
+
"""
|
332 |
+
This node performs evaluation of the retrieved docs and
|
333 |
+
"""
|
334 |
+
|
335 |
+
eval_method = config["configurable"].get("eval_method") or "binary"
|
336 |
+
MAX_DOCS = config["configurable"].get("max_docs") or 15
|
337 |
+
valid_doc_scores = []
|
338 |
+
|
339 |
+
for doc in sample(state["docs"], min(25, len(state["docs"]))):
|
340 |
+
score = eval_doc(
|
341 |
+
doc=format_doc(doc),
|
342 |
+
query=state["query"],
|
343 |
+
method=eval_method,
|
344 |
+
threshold=config["configurable"].get("eval_threshold") or 0.7,
|
345 |
+
eval_model = config["configurable"].get("eval_model") or "deepseek-r1-distill-llama-70b"
|
346 |
+
)
|
347 |
+
if score:
|
348 |
+
valid_doc_scores.append((doc, score))
|
349 |
+
|
350 |
+
if eval_method == 'score':
|
351 |
+
# Get at most MAX_DOCS items with the highest score if score method was used
|
352 |
+
valid_docs = sorted(valid_doc_scores, key=lambda x: x[1])
|
353 |
+
valid_docs = [valid_doc[0] for valid_doc in valid_docs[:MAX_DOCS]]
|
354 |
+
else:
|
355 |
+
# Get at mots MAX_DOCS items at random if binary method was used
|
356 |
+
shuffle(valid_doc_scores)
|
357 |
+
valid_docs = [valid_doc[0] for valid_doc in valid_doc_scores[:MAX_DOCS]]
|
358 |
+
|
359 |
+
return {"valid_docs": valid_docs + (state["valid_docs"] or [])}
|
360 |
+
|
361 |
+
|
362 |
+
|
363 |
+
def build_data_retriever_graph(memory):
|
364 |
+
"""
|
365 |
+
Builds the data_retriever graph
|
366 |
+
"""
|
367 |
+
#with SqliteSaver.from_conn_string(":memory:") as memory :
|
368 |
+
|
369 |
+
graph_builder_doc_retriever = StateGraph(DocRetrieverState)
|
370 |
+
|
371 |
+
graph_builder_doc_retriever.add_node("generate_cypher", generate_cypher)
|
372 |
+
graph_builder_doc_retriever.add_node("get_docs", get_docs)
|
373 |
+
graph_builder_doc_retriever.add_node("eval_docs", eval_docs)
|
374 |
+
|
375 |
+
|
376 |
+
graph_builder_doc_retriever.add_edge("__start__", "generate_cypher")
|
377 |
+
graph_builder_doc_retriever.add_edge("generate_cypher", "get_docs")
|
378 |
+
graph_builder_doc_retriever.add_edge("get_docs", "eval_docs")
|
379 |
+
graph_builder_doc_retriever.add_edge("eval_docs", "__end__")
|
380 |
+
|
381 |
+
graph_doc_retriever = graph_builder_doc_retriever.compile(checkpointer=memory)
|
382 |
+
|
383 |
+
return graph_doc_retriever
|
384 |
+
|
385 |
+
def error_concept_groq(msg,concepts,groq,question):
|
386 |
+
try:
|
387 |
+
start = msg.find("Requested") + len("Requested ")
|
388 |
+
end = msg.find(",", start)
|
389 |
+
rate_limit = int(msg[start:end])
|
390 |
+
related_concepts = []
|
391 |
+
i = 0
|
392 |
+
start = 0
|
393 |
+
end = len(concepts) // (rate_limit // 5000 + (1 if rate_limit%4500 != 0 else 0))
|
394 |
+
while (i < rate_limit // 5000):
|
395 |
+
smaller_concepts = concepts[start:end]
|
396 |
+
start = end
|
397 |
+
end = end + len(concepts) // (rate_limit//5000 + (1 if rate_limit%4500 != 0 else 0))
|
398 |
+
res = groq.invoke({question[0] : question[1], "concepts" : '\n'.join(smaller_concepts)})
|
399 |
+
for r in res:
|
400 |
+
related_concepts.append(r)
|
401 |
+
i+=1
|
402 |
+
return related_concepts
|
403 |
+
except Exception as e:
|
404 |
+
if e.status_code == 419:
|
405 |
+
time.sleep(65)
|
406 |
+
error_concept_groq(msg,concepts,groq,question)
|
ki_gen/planner.py
ADDED
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
|
4 |
+
from typing import Annotated
|
5 |
+
from typing_extensions import TypedDict
|
6 |
+
|
7 |
+
from langchain_groq import ChatGroq
|
8 |
+
from langchain_openai import ChatOpenAI
|
9 |
+
from langchain_core.messages import SystemMessage, HumanMessage
|
10 |
+
from langchain_community.graphs import Neo4jGraph
|
11 |
+
|
12 |
+
from langgraph.graph import StateGraph
|
13 |
+
from langgraph.graph import add_messages
|
14 |
+
|
15 |
+
from ki_gen.prompts import PLAN_GEN_PROMPT, PLAN_MODIFICATION_PROMPT
|
16 |
+
from ki_gen.data_retriever import build_data_retriever_graph
|
17 |
+
from ki_gen.data_processor import build_data_processor_graph
|
18 |
+
from ki_gen.utils import ConfigSchema, State, HumanValidationState, DocProcessorState, DocRetrieverState
|
19 |
+
from langgraph.checkpoint.sqlite import SqliteSaver
|
20 |
+
|
21 |
+
|
22 |
+
|
23 |
+
##########################################################################
|
24 |
+
###### NODES DEFINITION ######
|
25 |
+
##########################################################################
|
26 |
+
|
27 |
+
def validate_node(state: State):
|
28 |
+
"""
|
29 |
+
This node inserts the plan validation prompt.
|
30 |
+
"""
|
31 |
+
prompt = """System : You only need to focus on Key Issues, no need to focus on solutions or stakeholders yet and your plan should be concise.
|
32 |
+
If needed, give me an updated plan to follow this instruction. If your plan already follows the instruction just say "My plan is correct"."""
|
33 |
+
output = HumanMessage(content=prompt)
|
34 |
+
return {"messages" : [output]}
|
35 |
+
|
36 |
+
|
37 |
+
def error_chatbot_groq(error, model_name, query): # Pass model_name instead of llm_groq object
|
38 |
+
# Switch API key logic...
|
39 |
+
if os.environ["GROQ_API_KEY"] == os.getenv("groq_api_key"):
|
40 |
+
os.environ["GROQ_API_KEY"] = os.getenv("groq_api_key2")
|
41 |
+
elif os.environ["GROQ_API_KEY"] == os.getenv("groq_api_key2"):
|
42 |
+
os.environ["GROQ_API_KEY"] = os.getenv("groq_api_key3")
|
43 |
+
else:
|
44 |
+
os.environ["GROQ_API_KEY"] = os.getenv("groq_api_key")
|
45 |
+
|
46 |
+
# Re-initialize the model *after* switching the key
|
47 |
+
try:
|
48 |
+
# Use the model_name passed in
|
49 |
+
llm_groq_retry = ChatGroq(model=model_name)
|
50 |
+
# Pass the original query messages
|
51 |
+
return {"messages" : [llm_groq_retry.invoke(query)]}
|
52 |
+
except Exception as retry_error:
|
53 |
+
# Handle potential error during retry
|
54 |
+
print(f"Error during retry: {retry_error}")
|
55 |
+
# Decide what to return or raise here
|
56 |
+
return {"messages": [SystemMessage(content=f"Failed to process after retry: {retry_error}")]}
|
57 |
+
|
58 |
+
|
59 |
+
# Wrappers to call LLMs on the state messsages field
|
60 |
+
def chatbot_llama(state: State):
|
61 |
+
try:
|
62 |
+
llm_llama = ChatGroq(model="llama3-70b-8192")
|
63 |
+
return {"messages" : [llm_llama.invoke(state["messages"])]}
|
64 |
+
except Exception as error:
|
65 |
+
error_chatbot_groq(error,llm_llama,state["messages"])
|
66 |
+
def chatbot_mixtral(state: State):
|
67 |
+
print(state)
|
68 |
+
llm_mixtral = ChatGroq(model="deepseek-r1-distill-llama-70b")
|
69 |
+
print(llm_mixtral)
|
70 |
+
return {"messages" : [llm_mixtral.invoke(state["messages"])]}
|
71 |
+
# except Exception as error:
|
72 |
+
# error_chatbot_groq(error,llm_mixtral,state["messages"])
|
73 |
+
def chatbot_openai(state: State):
|
74 |
+
llm_openai = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
|
75 |
+
return {"messages" : [llm_openai.invoke(state["messages"])]}
|
76 |
+
|
77 |
+
chatbots = {"gpt-4o" : chatbot_openai,
|
78 |
+
"deepseek-r1-distill-llama-70b" : chatbot_mixtral,
|
79 |
+
"llama3-70b-8192" : chatbot_llama
|
80 |
+
}
|
81 |
+
|
82 |
+
|
83 |
+
def parse_plan(state: State):
|
84 |
+
"""
|
85 |
+
This node parses the generated plan and writes in the 'store_plan' field of the state
|
86 |
+
"""
|
87 |
+
plan = state["messages"][-3].content
|
88 |
+
store_plan = re.split("\d\.", plan.split("Plan:\n")[1])[1:]
|
89 |
+
try:
|
90 |
+
store_plan[len(store_plan) - 1] = store_plan[len(store_plan) - 1].split("<END_OF_PLAN>")[0]
|
91 |
+
except Exception as e:
|
92 |
+
print(f"Error while removing <END_OF_PLAN> : {e}")
|
93 |
+
|
94 |
+
return {"store_plan" : store_plan}
|
95 |
+
|
96 |
+
def detail_step(state: State, config: ConfigSchema):
|
97 |
+
"""
|
98 |
+
This node updates the value of the 'current_plan_step' field and defines the query to be used for the data_retriever.
|
99 |
+
"""
|
100 |
+
print("test")
|
101 |
+
print(state)
|
102 |
+
|
103 |
+
if 'current_plan_step' in state.keys():
|
104 |
+
print("all good chief")
|
105 |
+
else:
|
106 |
+
state["current_plan_step"] = None
|
107 |
+
|
108 |
+
current_plan_step = state["current_plan_step"] + 1 if state["current_plan_step"] is not None else 0 # We just began a new step so we will increase current_plan_step at the end
|
109 |
+
if config["configurable"].get("use_detailed_query"):
|
110 |
+
prompt = HumanMessage(f"""Specify what additional information you need to proceed with the next step of your plan :
|
111 |
+
Step {current_plan_step + 1} : {state['store_plan'][current_plan_step]}""")
|
112 |
+
query = get_detailed_query(context = state["messages"] + [prompt], model=config["configurable"].get("main_llm"))
|
113 |
+
return {"messages" : [prompt, query], "current_plan_step": current_plan_step, 'query' : query}
|
114 |
+
|
115 |
+
return {"current_plan_step": current_plan_step, 'query' : state["store_plan"][current_plan_step], "valid_docs" : []}
|
116 |
+
|
117 |
+
def get_detailed_query(context : list, model : str = "deepseek-r1-distill-llama-70b"):
|
118 |
+
"""
|
119 |
+
Simple helper function for the detail_step node
|
120 |
+
"""
|
121 |
+
if model == 'gpt-4o':
|
122 |
+
llm = ChatOpenAI(model=model, base_url="https://llm.synapse.thalescloud.io/")
|
123 |
+
else:
|
124 |
+
llm = ChatGroq(model=model)
|
125 |
+
return llm.invoke(context)
|
126 |
+
|
127 |
+
def concatenate_data(state: State):
|
128 |
+
"""
|
129 |
+
This node concatenates all the data that was processed by the data_processor and inserts it in the state's messages
|
130 |
+
"""
|
131 |
+
prompt = f"""#########TECHNICAL INFORMATION ############
|
132 |
+
{str(state["valid_docs"])}
|
133 |
+
|
134 |
+
########END OF TECHNICAL INFORMATION#######
|
135 |
+
|
136 |
+
Using the information provided above, proceed with step {state['current_plan_step'] + 1} of your plan :
|
137 |
+
{state['store_plan'][state['current_plan_step']]}
|
138 |
+
"""
|
139 |
+
|
140 |
+
return {"messages": [HumanMessage(content=prompt)]}
|
141 |
+
|
142 |
+
|
143 |
+
def human_validation(state: HumanValidationState) -> HumanValidationState:
|
144 |
+
"""
|
145 |
+
Dummy node to interrupt before
|
146 |
+
"""
|
147 |
+
return {'process_steps' : []}
|
148 |
+
|
149 |
+
def generate_ki(state: State):
|
150 |
+
"""
|
151 |
+
This node inserts the prompt to begin Key Issues generation
|
152 |
+
"""
|
153 |
+
print(f"THIS IS THE STATE FOR CURRENT PLAN STEP IN GENERATE_KI : {state}")
|
154 |
+
|
155 |
+
prompt = f"""Using the information provided above, proceed with step 4 of your plan to provide the user with NEW and INNOVATIVE Key Issues :
|
156 |
+
{state['store_plan'][state['current_plan_step'] + 1]}"""
|
157 |
+
|
158 |
+
return {"messages" : [HumanMessage(content=prompt)]}
|
159 |
+
|
160 |
+
def detail_ki(state: State):
|
161 |
+
"""
|
162 |
+
This node inserts the last prompt to detail the generated Key Issues
|
163 |
+
"""
|
164 |
+
prompt = f"""Using the information provided above, proceed with step 5 of your plan to provide the user with NEW and INNOVATIVE Key Issues :
|
165 |
+
{state['store_plan'][state['current_plan_step'] + 2]}"""
|
166 |
+
|
167 |
+
return {"messages" : [HumanMessage(content=prompt)]}
|
168 |
+
|
169 |
+
##########################################################################
|
170 |
+
###### CONDITIONAL EDGE FUNCTIONS ######
|
171 |
+
##########################################################################
|
172 |
+
|
173 |
+
def validate_plan(state: State):
|
174 |
+
"""
|
175 |
+
Whether to regenerate the plan or to parse it
|
176 |
+
"""
|
177 |
+
if "messages" in state and "My plan is correct" in state["messages"][-1].content:
|
178 |
+
return "parse"
|
179 |
+
return "validate"
|
180 |
+
|
181 |
+
def next_plan_step(state: State, config: ConfigSchema):
|
182 |
+
"""
|
183 |
+
Proceed to next plan step (either generate KI or retrieve more data)
|
184 |
+
"""
|
185 |
+
if (state["current_plan_step"] == 2) and (config["configurable"].get('plan_method') == "modification"):
|
186 |
+
return "generate_key_issues"
|
187 |
+
if state["current_plan_step"] == len(state["store_plan"]) - 1:
|
188 |
+
return "generate_key_issues"
|
189 |
+
else:
|
190 |
+
return "detail_step"
|
191 |
+
|
192 |
+
def detail_or_data_retriever(state: State, config: ConfigSchema):
|
193 |
+
"""
|
194 |
+
Detail the query to use for data retrieval or not
|
195 |
+
"""
|
196 |
+
if config["configurable"].get("use_detailed_query"):
|
197 |
+
return "chatbot_detail"
|
198 |
+
else:
|
199 |
+
return "data_retriever"
|
200 |
+
|
201 |
+
def retrieve_or_process(state: State):
|
202 |
+
"""
|
203 |
+
Process the retrieved docs or keep retrieving
|
204 |
+
"""
|
205 |
+
if state['human_validated']:
|
206 |
+
return "process"
|
207 |
+
return "retrieve"
|
208 |
+
# while True:
|
209 |
+
# user_input = input(f"{len(state['valid_docs'])} were retreived. Do you want more documents (y/[n]) : ")
|
210 |
+
# if user_input.lower() == "y":
|
211 |
+
# return "retrieve"
|
212 |
+
# if not user_input or user_input.lower() == "n":
|
213 |
+
# return "process"
|
214 |
+
# print("Please answer with 'y' or 'n'.\n")
|
215 |
+
|
216 |
+
|
217 |
+
def build_planner_graph(memory, config):
|
218 |
+
"""
|
219 |
+
Builds the planner graph
|
220 |
+
"""
|
221 |
+
graph_builder = StateGraph(State)
|
222 |
+
|
223 |
+
graph_doc_retriever = build_data_retriever_graph(memory)
|
224 |
+
graph_doc_processor = build_data_processor_graph(memory)
|
225 |
+
graph_builder.add_node("chatbot_planner", chatbots[config["main_llm"]])
|
226 |
+
graph_builder.add_node("validate", validate_node)
|
227 |
+
graph_builder.add_node("chatbot_detail", chatbot_llama)
|
228 |
+
graph_builder.add_node("parse", parse_plan)
|
229 |
+
graph_builder.add_node("detail_step", detail_step)
|
230 |
+
graph_builder.add_node("data_retriever", graph_doc_retriever, input=DocRetrieverState)
|
231 |
+
graph_builder.add_node("human_validation", human_validation)
|
232 |
+
graph_builder.add_node("data_processor", graph_doc_processor, input=DocProcessorState)
|
233 |
+
graph_builder.add_node("concatenate_data", concatenate_data)
|
234 |
+
graph_builder.add_node("chatbot_exec_step", chatbots[config["main_llm"]])
|
235 |
+
graph_builder.add_node("generate_ki", generate_ki)
|
236 |
+
graph_builder.add_node("chatbot_ki", chatbots[config["main_llm"]])
|
237 |
+
graph_builder.add_node("detail_ki", detail_ki)
|
238 |
+
graph_builder.add_node("chatbot_final", chatbots[config["main_llm"]])
|
239 |
+
|
240 |
+
graph_builder.add_edge("validate", "chatbot_planner")
|
241 |
+
graph_builder.add_edge("parse", "detail_step")
|
242 |
+
|
243 |
+
|
244 |
+
# graph_builder.add_edge("detail_step", "chatbot2")
|
245 |
+
graph_builder.add_edge("chatbot_detail", "data_retriever")
|
246 |
+
graph_builder.add_edge("data_retriever", "human_validation")
|
247 |
+
|
248 |
+
|
249 |
+
graph_builder.add_edge("data_processor", "concatenate_data")
|
250 |
+
graph_builder.add_edge("concatenate_data", "chatbot_exec_step")
|
251 |
+
graph_builder.add_edge("generate_ki", "chatbot_ki")
|
252 |
+
graph_builder.add_edge("chatbot_ki", "detail_ki")
|
253 |
+
graph_builder.add_edge("detail_ki", "chatbot_final")
|
254 |
+
graph_builder.add_edge("chatbot_final", "__end__")
|
255 |
+
|
256 |
+
graph_builder.add_conditional_edges(
|
257 |
+
"detail_step",
|
258 |
+
detail_or_data_retriever,
|
259 |
+
{"chatbot_detail": "chatbot_detail", "data_retriever": "data_retriever"}
|
260 |
+
)
|
261 |
+
graph_builder.add_conditional_edges(
|
262 |
+
"human_validation",
|
263 |
+
retrieve_or_process,
|
264 |
+
{"retrieve" : "data_retriever", "process" : "data_processor"}
|
265 |
+
)
|
266 |
+
graph_builder.add_conditional_edges(
|
267 |
+
"chatbot_planner",
|
268 |
+
validate_plan,
|
269 |
+
{"parse" : "parse", "validate": "validate"}
|
270 |
+
)
|
271 |
+
graph_builder.add_conditional_edges(
|
272 |
+
"chatbot_exec_step",
|
273 |
+
next_plan_step,
|
274 |
+
{"generate_key_issues" : "generate_ki", "detail_step": "detail_step"}
|
275 |
+
)
|
276 |
+
|
277 |
+
graph_builder.set_entry_point("chatbot_planner")
|
278 |
+
graph = graph_builder.compile(
|
279 |
+
checkpointer=memory,
|
280 |
+
interrupt_after=["parse", "chatbot_exec_step", "chatbot_final", "data_retriever"],
|
281 |
+
)
|
282 |
+
return graph
|
ki_gen/prompts.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_core.prompts.prompt import PromptTemplate
|
2 |
+
from langchain_core.prompts import ChatPromptTemplate
|
3 |
+
from langchain_core.messages import SystemMessage, HumanMessage
|
4 |
+
from ki_gen.utils import ConfigSchema
|
5 |
+
|
6 |
+
CYPHER_GENERATION_TEMPLATE = """Task:Generate Cypher statement to query a graph database.
|
7 |
+
Instructions:
|
8 |
+
Use only the provided relationship types and properties in the schema.
|
9 |
+
Do not use any other relationship types or properties that are not provided.
|
10 |
+
Schema:
|
11 |
+
{schema}
|
12 |
+
|
13 |
+
|
14 |
+
Concepts:
|
15 |
+
{concepts}
|
16 |
+
|
17 |
+
|
18 |
+
Concept names can ONLY be selected from the above list
|
19 |
+
|
20 |
+
Note: Do not include any explanations or apologies in your responses.
|
21 |
+
Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.
|
22 |
+
Do not include any text except the generated Cypher statement.
|
23 |
+
|
24 |
+
The question is:
|
25 |
+
{question}"""
|
26 |
+
CYPHER_GENERATION_PROMPT = PromptTemplate(
|
27 |
+
input_variables=["schema", "question", "concepts"], template=CYPHER_GENERATION_TEMPLATE
|
28 |
+
)
|
29 |
+
|
30 |
+
CYPHER_QA_TEMPLATE = """You are an assistant that helps to form nice and human understandable answers.
|
31 |
+
The information part contains the provided information that you must use to construct an answer.
|
32 |
+
The provided information is authoritative, you must never doubt it or try to use your internal knowledge to correct it.
|
33 |
+
Make the answer sound as a response to the question. Do not mention that you based the result on the given information.
|
34 |
+
Here is an example:
|
35 |
+
|
36 |
+
Question: Which managers own Neo4j stocks?
|
37 |
+
Context:[manager:CTL LLC, manager:JANE STREET GROUP LLC]
|
38 |
+
Helpful Answer: CTL LLC, JANE STREET GROUP LLC owns Neo4j stocks.
|
39 |
+
|
40 |
+
Follow this example when generating answers.
|
41 |
+
If the provided information is empty, say that you don't know the answer.
|
42 |
+
Information:
|
43 |
+
{context}
|
44 |
+
|
45 |
+
Question: {question}
|
46 |
+
Helpful Answer:"""
|
47 |
+
CYPHER_QA_PROMPT = PromptTemplate(
|
48 |
+
input_variables=["context", "question"], template=CYPHER_QA_TEMPLATE
|
49 |
+
)
|
50 |
+
|
51 |
+
PLAN_GEN_PROMPT = """System : You are a standardization expert working for 3GPP. You are given a specific technical requirement regarding the deployment of 5G services. Your goal is to specify NEW and INNOVATIVE Key Issues that could occur while trying to fulfill this requirement
|
52 |
+
|
53 |
+
System : Let's first understand the problem and devise a plan to solve the problem.
|
54 |
+
Output the plan starting with the header 'Plan:' and then followed by a numbered list of steps.
|
55 |
+
Make the plan the minimum number of steps required to accurately provide the user with NEW and INNOVATIVE Key Issues related to the technical requirement.
|
56 |
+
At the end of your plan, say '<END_OF_PLAN>'"""
|
57 |
+
|
58 |
+
PLAN_MODIFICATION_PROMPT = """You are a standardization expert working for 3GPP. You are given a specific technical requirement regarding the deployment of 5G services. Your goal is to specify NEW and INNOVATIVE Key Issues that could occur while trying to fulfill this requirement.
|
59 |
+
To achieve this goal we are going to follow this generic plan :
|
60 |
+
|
61 |
+
###PLAN TEMPLATE###
|
62 |
+
|
63 |
+
Plan:
|
64 |
+
|
65 |
+
1. **Understanding the Problem**: Gather information from existing specifications and standards to thoroughly understand the technical requirement. This should help you understand the key aspects of the problem.
|
66 |
+
2. **Gather information about latest innovations** : Gather information about the latest innovations related to the problem by looking at the most relevant research papers and list the sources.
|
67 |
+
3. **Identifying NEW and INNOVATIVE Key Issues**: Based on the understanding of the problem, identify new and innovative key issues that could occur while trying to fulfill this requirement. Descripbe them in simple technical english. These key issues should be relevant, significant, and not yet addressed by existing solutions.
|
68 |
+
4. **Develop Detailed Descriptions for Each Key Issue**: For each identified key issue, provide a detailed description in simple technical english, including the specific challenges and areas requiring further study.
|
69 |
+
<END_OF_PLAN>
|
70 |
+
|
71 |
+
###END OF PLAN TEMPLATE###
|
72 |
+
|
73 |
+
Let's and devise a plan to solve the problem by adapting the PLAN TEMPLATE.
|
74 |
+
Output the plan starting with the header 'Plan:' and then followed by a numbered list of steps.
|
75 |
+
Make the plan the minimum number of steps required to accurately provide the user with NEW and INNOVATIVE Key Issues related to the technical requirement.
|
76 |
+
At the end of your plan, say '<END_OF_PLAN>' """
|
77 |
+
|
78 |
+
PLAN_GEN_PROMPT = PLAN_MODIFICATION_PROMPT
|
79 |
+
|
80 |
+
CONCEPT_SELECTION_TEMPLATE = """Task: Select the most relevant topic to the user question
|
81 |
+
Instructions:
|
82 |
+
Select the most relevant Concept to the user's question.
|
83 |
+
Concepts can ONLY be selected from the list below.
|
84 |
+
|
85 |
+
Concepts:
|
86 |
+
{concepts}
|
87 |
+
|
88 |
+
Note: Do not include any explanations or apologies in your responses.
|
89 |
+
Do not include any text except the selected concept.
|
90 |
+
|
91 |
+
The question is:
|
92 |
+
{question}"""
|
93 |
+
CONCEPT_SELECTION_PROMPT = PromptTemplate(
|
94 |
+
input_variables=["concepts", "question"], template=CONCEPT_SELECTION_TEMPLATE
|
95 |
+
)
|
96 |
+
|
97 |
+
RELEVANT_CONCEPTS_TEMPLATE = """
|
98 |
+
## CONCEPTS ##
|
99 |
+
{concepts}
|
100 |
+
## END OF CONCEPTS ##
|
101 |
+
|
102 |
+
Select the 20 most relevant concepts to the user query.
|
103 |
+
Output your answer as a numbered list preceeded with the header 'Concepts:'.
|
104 |
+
|
105 |
+
User query :
|
106 |
+
{user_query}
|
107 |
+
"""
|
108 |
+
RELEVANT_CONCEPTS_PROMPT = ChatPromptTemplate.from_messages([
|
109 |
+
("human", RELEVANT_CONCEPTS_TEMPLATE)
|
110 |
+
])
|
111 |
+
|
112 |
+
SUMMARIZER_TEMPLATE = """You are a 3GPP standardization expert.
|
113 |
+
Summarize the provided document in simple technical English for other experts in the field.
|
114 |
+
|
115 |
+
Document:
|
116 |
+
{document}"""
|
117 |
+
SUMMARIZER_PROMPT = ChatPromptTemplate.from_messages([
|
118 |
+
("system", SUMMARIZER_TEMPLATE)
|
119 |
+
])
|
120 |
+
|
121 |
+
|
122 |
+
BINARY_GRADER_TEMPLATE = """You are a grader assessing relevance of a retrieved document to a user question. \n
|
123 |
+
It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
|
124 |
+
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
|
125 |
+
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""
|
126 |
+
BINARY_GRADER_PROMPT = ChatPromptTemplate.from_messages(
|
127 |
+
[
|
128 |
+
("system", BINARY_GRADER_TEMPLATE),
|
129 |
+
("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
|
130 |
+
]
|
131 |
+
)
|
132 |
+
|
133 |
+
|
134 |
+
SCORE_GRADER_TEMPLATE = """Grasp and understand both the query and the document before score generation.
|
135 |
+
Then, based on your understanding and analysis quantify the relevance between the document and the query.
|
136 |
+
Give the rationale before answering.
|
137 |
+
Ouput your answer as a score ranging between 0 (irrelevant document) and 1 (completely relevant document)"""
|
138 |
+
|
139 |
+
SCORE_GRADER_PROMPT = ChatPromptTemplate.from_messages(
|
140 |
+
[
|
141 |
+
("system", SCORE_GRADER_TEMPLATE),
|
142 |
+
("human", "Passage: \n\n {document} \n\n User query: {query}")
|
143 |
+
]
|
144 |
+
)
|
145 |
+
|
146 |
+
def get_initial_prompt(config: ConfigSchema, user_query : str):
|
147 |
+
if config["configurable"].get("plan_method") == "generation":
|
148 |
+
prompt = PLAN_GEN_PROMPT
|
149 |
+
elif config["configurable"].get("plan_method") == "modification":
|
150 |
+
prompt = PLAN_MODIFICATION_PROMPT
|
151 |
+
else:
|
152 |
+
raise ValueError("Incorrect plan_method, should be 'generation' or 'modification'")
|
153 |
+
|
154 |
+
user_input = user_query or input("User :")
|
155 |
+
return {"messages" : [SystemMessage(content=prompt), HumanMessage(content=user_input)]}
|
ki_gen/utils.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import getpass
|
3 |
+
import html
|
4 |
+
|
5 |
+
|
6 |
+
from typing import Annotated, Union
|
7 |
+
from typing_extensions import TypedDict
|
8 |
+
|
9 |
+
from langchain_community.graphs import Neo4jGraph
|
10 |
+
from langchain_groq import ChatGroq
|
11 |
+
from langchain_openai import ChatOpenAI
|
12 |
+
|
13 |
+
from langgraph.checkpoint.sqlite import SqliteSaver
|
14 |
+
from langgraph.checkpoint.memory import MemorySaver
|
15 |
+
from langgraph.checkpoint import base
|
16 |
+
from langgraph.graph import add_messages
|
17 |
+
|
18 |
+
memory = MemorySaver()
|
19 |
+
|
20 |
+
def format_df(df):
|
21 |
+
"""
|
22 |
+
Used to display the generated plan in a nice format
|
23 |
+
Returns html code in a string
|
24 |
+
"""
|
25 |
+
def format_cell(cell):
|
26 |
+
if isinstance(cell, str):
|
27 |
+
# Encode special characters, but preserve line breaks
|
28 |
+
return html.escape(cell).replace('\n', '<br>')
|
29 |
+
return cell
|
30 |
+
# Convert the DataFrame to HTML with custom CSS
|
31 |
+
formatted_df = df.map(format_cell)
|
32 |
+
html_table = formatted_df.to_html(escape=False, index=False)
|
33 |
+
|
34 |
+
# Add custom CSS to allow multiple lines and scrolling in cells
|
35 |
+
css = """
|
36 |
+
<style>
|
37 |
+
table {
|
38 |
+
border-collapse: collapse;
|
39 |
+
width: 100%;
|
40 |
+
}
|
41 |
+
th, td {
|
42 |
+
border: 1px solid black;
|
43 |
+
padding: 8px;
|
44 |
+
text-align: left;
|
45 |
+
vertical-align: top;
|
46 |
+
white-space: pre-wrap;
|
47 |
+
max-width: 300px;
|
48 |
+
max-height: 100px;
|
49 |
+
overflow-y: auto;
|
50 |
+
}
|
51 |
+
th {
|
52 |
+
background-color: #f2f2f2;
|
53 |
+
}
|
54 |
+
</style>
|
55 |
+
"""
|
56 |
+
|
57 |
+
return css + html_table
|
58 |
+
|
59 |
+
def format_doc(doc: dict) -> str :
|
60 |
+
formatted_string = ""
|
61 |
+
for key in doc:
|
62 |
+
formatted_string += f"**{key}**: {doc[key]}\n"
|
63 |
+
return formatted_string
|
64 |
+
|
65 |
+
|
66 |
+
|
67 |
+
def _set_env(var: str, value: str = None):
|
68 |
+
if not os.environ.get(var):
|
69 |
+
if value:
|
70 |
+
os.environ[var] = value
|
71 |
+
else:
|
72 |
+
os.environ[var] = getpass.getpass(f"{var}: ")
|
73 |
+
|
74 |
+
|
75 |
+
def init_app(openai_key : str = None, groq_key : str = None, langsmith_key : str = None):
|
76 |
+
"""
|
77 |
+
Initialize app with user api keys and sets up proxy settings
|
78 |
+
"""
|
79 |
+
_set_env("GROQ_API_KEY", value=os.getenv("groq_api_key"))
|
80 |
+
_set_env("LANGSMITH_API_KEY", value=os.getenv("langsmith_api_key"))
|
81 |
+
_set_env("OPENAI_API_KEY", value=os.getenv("openai_api_key"))
|
82 |
+
os.environ["LANGSMITH_TRACING_V2"] = "true"
|
83 |
+
os.environ["LANGCHAIN_PROJECT"] = "3GPP Test"
|
84 |
+
|
85 |
+
|
86 |
+
def clear_memory(memory, thread_id: str = "") -> None:
|
87 |
+
"""
|
88 |
+
Clears checkpointer state for a given thread_id, broken for now
|
89 |
+
TODO : fix this
|
90 |
+
"""
|
91 |
+
memory = MemorySaver()
|
92 |
+
|
93 |
+
#checkpoint = base.empty_checkpoint()
|
94 |
+
#memory.put(config={"configurable": {"thread_id": thread_id}}, checkpoint=checkpoint, metadata={})
|
95 |
+
|
96 |
+
def get_model(model : str = "deepseek-r1-distill-llama-70b"):
|
97 |
+
"""
|
98 |
+
Wrapper to return the correct llm object depending on the 'model' param
|
99 |
+
"""
|
100 |
+
if model == "gpt-4o":
|
101 |
+
llm = ChatOpenAI(model=model, base_url="https://llm.synapse.thalescloud.io/")
|
102 |
+
else:
|
103 |
+
llm = ChatGroq(model=model)
|
104 |
+
return llm
|
105 |
+
|
106 |
+
|
107 |
+
class ConfigSchema(TypedDict):
|
108 |
+
graph: Neo4jGraph
|
109 |
+
plan_method: str
|
110 |
+
use_detailed_query: bool
|
111 |
+
|
112 |
+
class State(TypedDict):
|
113 |
+
messages : Annotated[list, add_messages]
|
114 |
+
store_plan : list[str]
|
115 |
+
current_plan_step : int
|
116 |
+
valid_docs : list[str]
|
117 |
+
|
118 |
+
class DocRetrieverState(TypedDict):
|
119 |
+
messages: Annotated[list, add_messages]
|
120 |
+
query: str
|
121 |
+
docs: list[dict]
|
122 |
+
cyphers: list[str]
|
123 |
+
current_plan_step : int
|
124 |
+
valid_docs: list[Union[str, dict]]
|
125 |
+
|
126 |
+
class HumanValidationState(TypedDict):
|
127 |
+
human_validated : bool
|
128 |
+
process_steps : list[str]
|
129 |
+
|
130 |
+
def update_doc_history(left : list | None, right : list | None) -> list:
|
131 |
+
"""
|
132 |
+
Reducer for the 'docs_in_processing' field.
|
133 |
+
Doesn't work currently because of bad handlinf of duplicates
|
134 |
+
TODO : make this work (reference : https://langchain-ai.github.io/langgraph/how-tos/subgraph/#custom-reducer-functions-to-manage-state)
|
135 |
+
"""
|
136 |
+
if not left:
|
137 |
+
# This shouldn't happen
|
138 |
+
left = [[]]
|
139 |
+
if not right:
|
140 |
+
right = []
|
141 |
+
|
142 |
+
for i in range(len(right)):
|
143 |
+
left[i].append(right[i])
|
144 |
+
return left
|
145 |
+
|
146 |
+
|
147 |
+
class DocProcessorState(TypedDict):
|
148 |
+
valid_docs : list[Union[str, dict]]
|
149 |
+
docs_in_processing : list
|
150 |
+
process_steps : list[Union[str,dict]]
|
151 |
+
current_process_step : int
|
152 |
+
|