adrienbrdne committed
Commit 70d06c8 · verified · 1 Parent(s): f8ac349

Upload 5 files

ki_gen/data_processor.py ADDED
@@ -0,0 +1,185 @@
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ from langchain_openai import ChatOpenAI
5
+ from langchain_core.output_parsers import StrOutputParser
6
+ from langchain_core.prompts import ChatPromptTemplate
7
+ from langchain_groq import ChatGroq
8
+ from langgraph.graph import StateGraph
9
+ from llmlingua import PromptCompressor
10
+
11
+ from ki_gen.utils import ConfigSchema, DocProcessorState, get_model, format_doc
12
+ from langgraph.checkpoint.sqlite import SqliteSaver
13
+
14
+
15
+
16
+
17
+ # compressed_prompt = llm_lingua.compress_prompt(prompt, instruction="", question="", target_token=200)
18
+
19
+ ## Or use a quantized model, like TheBloke/Llama-2-7b-Chat-GPTQ, which only needs <8GB of GPU memory.
20
+ ## Before that, you need to pip install optimum auto-gptq
21
+ # llm_lingua = PromptCompressor("TheBloke/Llama-2-7b-Chat-GPTQ", model_config={"revision": "main"})
22
+
23
+
24
+
25
+ # Requires ~2GB of RAM
26
+ def get_llm_lingua(compress_method:str = "llm_lingua2"):
27
+
28
+ # Requires ~2GB memory
29
+ if compress_method == "llm_lingua2":
30
+ llm_lingua2 = PromptCompressor(
31
+ model_name="microsoft/llmlingua-2-xlm-roberta-large-meetingbank",
32
+ use_llmlingua2=True,
33
+ device_map="cpu"
34
+ )
35
+ return llm_lingua2
36
+
37
+ # Requires ~8GB memory
38
+ elif compress_method == "llm_lingua":
39
+ llm_lingua = PromptCompressor(
40
+ model_name="microsoft/phi-2",
41
+ device_map="cpu"
42
+ )
43
+ return llm_lingua
44
+ raise ValueError("Incorrect compression method, should be 'llm_lingua' or 'llm_lingua2'")
45
+
46
+
47
+
48
+ def compress(state: DocProcessorState, config: ConfigSchema):
49
+ """
50
+ This node compresses the last processing result for each doc using llm_lingua
51
+ """
52
+ doc_process_histories = state["docs_in_processing"]
53
+ llm_lingua = get_llm_lingua(config["configurable"].get("compression_method") or "llm_lingua2")
54
+ for doc_process_history in doc_process_histories:
55
+ doc_process_history.append(llm_lingua.compress_prompt(
56
+ context=str(doc_process_history[-1]),
57
+ rate=config["configurable"].get("compress_rate") or 0.33,
58
+ force_tokens=config["configurable"].get("force_tokens") or ['\n', '?', '.', '!', ',']
59
+ )["compressed_prompt"]
60
+ )
61
+
62
+ return {"docs_in_processing": doc_process_histories, "current_process_step" : state["current_process_step"] + 1}
63
+
64
+ def summarize_docs(state: DocProcessorState, config: ConfigSchema):
65
+ """
66
+ This node summarizes the last processing result of each doc in state["docs_in_processing"]
67
+ """
68
+
69
+ prompt = """You are a 3GPP standardization expert.
70
+ Summarize the provided document in simple technical English for other experts in the field.
71
+
72
+ Document:
73
+ {document}"""
74
+ sysmsg = ChatPromptTemplate.from_messages([
75
+ ("system", prompt)
76
+ ])
77
+ model = config["configurable"].get("summarize_model") or "deepseek-r1-distill-llama-70b"
78
+ doc_process_histories = state["docs_in_processing"]
79
+ if model == "gpt-4o":
80
+ llm_summarize = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
81
+ else:
82
+ llm_summarize = ChatGroq(model=model)
83
+ summarize_chain = sysmsg | llm_summarize | StrOutputParser()
84
+
85
+ for doc_process_history in doc_process_histories:
86
+ doc_process_history.append(summarize_chain.invoke({"document" : str(doc_process_history[-1])}))
87
+
88
+ return {"docs_in_processing": doc_process_histories, "current_process_step": state["current_process_step"] + 1}
89
+
90
+ def custom_process(state: DocProcessorState):
91
+ """
92
+ Custom processing step, params are stored in a dict in state["process_steps"][state["current_process_step"]]
93
+ processing_model : the LLM which will perform the processing
94
+ context : the previous processing results to send as context to the LLM
95
+ prompt : the prompt/task which will be appended to the context before sending to the LLM
96
+ """
97
+
98
+ processing_params = state["process_steps"][state["current_process_step"]]
99
+ model = processing_params.get("processing_model") or "deepseek-r1-distill-llama-70b"
100
+ user_prompt = processing_params["prompt"]
101
+ context = processing_params.get("context") or [0]
102
+ doc_process_histories = state["docs_in_processing"]
103
+ if not isinstance(context, list):
104
+ context = [context]
105
+
106
+ processing_chain = get_model(model=model) | StrOutputParser()
107
+
108
+ for doc_process_history in doc_process_histories:
109
+ context_str = ""
110
+ for i, context_element in enumerate(context):
111
+ context_str += f"### TECHNICAL INFORMATION {i+1} \n {doc_process_history[context_element]}\n\n"
112
+ doc_process_history.append(processing_chain.invoke(context_str + user_prompt))
113
+
114
+ return {"docs_in_processing" : doc_process_histories, "current_process_step" : state["current_process_step"] + 1}
115
+
116
+ def final(state: DocProcessorState):
117
+ """
118
+ A node to store the final results of processing in the 'valid_docs' field
119
+ """
120
+ return {"valid_docs" : [doc_process_history[-1] for doc_process_history in state["docs_in_processing"]]}
121
+
122
+ # TODO : remove this node and use conditional entry point instead
123
+ def get_process_steps(state: DocProcessorState, config: ConfigSchema):
124
+ """
125
+ Dummy node
126
+ """
127
+ # if not process_steps:
128
+ # process_steps = eval(input("Enter processing steps: "))
129
+ return {"current_process_step": 0, "docs_in_processing" : [[format_doc(doc)] for doc in state["valid_docs"]]}
130
+
131
+
132
+ def next_processor_step(state: DocProcessorState):
133
+ """
134
+ Conditional edge function to go to next processing step
135
+ """
136
+ process_steps = state["process_steps"]
137
+ if state["current_process_step"] < len(process_steps):
138
+ step = process_steps[state["current_process_step"]]
139
+ if isinstance(step, dict):
140
+ step = "custom"
141
+ else:
142
+ step = "final"
143
+
144
+ return step
145
+
146
+
147
+ def build_data_processor_graph(memory):
148
+ """
149
+ Builds the data processor graph
150
+ """
151
+ #with SqliteSaver.from_conn_string(":memory:") as memory :
152
+
153
+ graph_builder_doc_processor = StateGraph(DocProcessorState)
154
+
155
+ graph_builder_doc_processor.add_node("get_process_steps", get_process_steps)
156
+ graph_builder_doc_processor.add_node("summarize", summarize_docs)
157
+ graph_builder_doc_processor.add_node("compress", compress)
158
+ graph_builder_doc_processor.add_node("custom", custom_process)
159
+ graph_builder_doc_processor.add_node("final", final)
160
+
161
+ graph_builder_doc_processor.add_edge("__start__", "get_process_steps")
162
+ graph_builder_doc_processor.add_conditional_edges(
163
+ "get_process_steps",
164
+ next_processor_step,
165
+ {"compress" : "compress", "final": "final", "summarize": "summarize", "custom" : "custom"}
166
+ )
167
+ graph_builder_doc_processor.add_conditional_edges(
168
+ "summarize",
169
+ next_processor_step,
170
+ {"compress" : "compress", "final": "final", "custom" : "custom"}
171
+ )
172
+ graph_builder_doc_processor.add_conditional_edges(
173
+ "compress",
174
+ next_processor_step,
175
+ {"summarize" : "summarize", "final": "final", "custom" : "custom"}
176
+ )
177
+ graph_builder_doc_processor.add_conditional_edges(
178
+ "custom",
179
+ next_processor_step,
180
+ {"summarize" : "summarize", "final": "final", "compress" : "compress", "custom" : "custom"}
181
+ )
182
+ graph_builder_doc_processor.add_edge("final", "__end__")
183
+
184
+ graph_doc_processor = graph_builder_doc_processor.compile(checkpointer=memory)
185
+ return graph_doc_processor
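
A minimal usage sketch for this processor subgraph, assuming the Groq/OpenAI keys are already set in the environment; the document content, thread_id and step list below are illustrative placeholders, not values from the commit:

from langgraph.checkpoint.memory import MemorySaver
from ki_gen.data_processor import build_data_processor_graph

memory = MemorySaver()
processor = build_data_processor_graph(memory)

# Each doc in 'valid_docs' is summarized and then compressed, as listed in 'process_steps'.
final_state = processor.invoke(
    {
        "valid_docs": [{"title": "TS 22.261", "description": "Service requirements for the 5G system"}],
        "process_steps": ["summarize", "compress"],
    },
    config={"configurable": {"thread_id": "demo", "compression_method": "llm_lingua2"}},
)
print(final_state["valid_docs"])
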
ki_gen/data_retriever.py ADDED
@@ -0,0 +1,406 @@
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ import re
5
+ import time
6
+ from random import shuffle, sample
7
+ from langgraph.checkpoint.sqlite import SqliteSaver
8
+
9
+ from langchain_groq import ChatGroq
10
+ from langchain_openai import ChatOpenAI
11
+ from langchain_core.messages import HumanMessage
12
+ from langchain_community.graphs import Neo4jGraph
13
+ from langchain_community.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema
14
+ from langchain_core.output_parsers import StrOutputParser
15
+ from langchain_core.prompts import ChatPromptTemplate
16
+ from pydantic import BaseModel, Field
17
+
18
+
19
+
20
+ from langgraph.graph import StateGraph
21
+
22
+ from llmlingua import PromptCompressor
23
+
24
+ from ki_gen.prompts import (
25
+ CYPHER_GENERATION_PROMPT,
26
+ CONCEPT_SELECTION_PROMPT,
27
+ BINARY_GRADER_PROMPT,
28
+ SCORE_GRADER_PROMPT,
29
+ RELEVANT_CONCEPTS_PROMPT,
30
+ )
31
+ from ki_gen.utils import ConfigSchema, DocRetrieverState, get_model, format_doc
32
+
33
+
34
+
35
+
36
+ def extract_cypher(text: str) -> list[str]:
37
+ """Extract Cypher code from a text.
38
+
39
+ Args:
40
+ text: Text to extract Cypher code from.
41
+
42
+ Returns:
43
+ A list of candidate Cypher snippets: the first ```cypher``` block, the first generic ``` block, and the raw text as a fallback.
44
+ """
45
+ # The pattern to find Cypher code enclosed in triple backticks
46
+ pattern_1 = r"```cypher\n(.*?)```"
47
+ pattern_2 = r"```\n(.*?)```"
48
+
49
+ # Find all matches in the input text
50
+ matches_1 = re.findall(pattern_1, text, re.DOTALL)
51
+ matches_2 = re.findall(pattern_2, text, re.DOTALL)
52
+ return [
53
+ matches_1[0] if matches_1 else text,
54
+ matches_2[0] if matches_2 else text,
55
+ text
56
+ ]
57
+
58
+ def get_cypher_gen_chain(model: str = "deepseek-r1-distill-llama-70b"):
59
+ """
60
+ Returns cypher gen chain using specified model for generation
61
+ This is used when the 'auto' cypher generation method has been configured
62
+ """
63
+
64
+ if model=="openai":
65
+ llm_cypher_gen = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
66
+ else:
67
+ llm_cypher_gen = ChatGroq(model=model)
68
+ cypher_gen_chain = CYPHER_GENERATION_PROMPT | llm_cypher_gen | StrOutputParser() | extract_cypher
69
+ return cypher_gen_chain
70
+
71
+ def get_concept_selection_chain(model: str = "deepseek-r1-distill-llama-70b"):
72
+ """
73
+ Returns a chain to select the most relevant topic using specified model for generation.
74
+ This is used when the 'guided' cypher generation method has been configured
75
+ """
76
+
77
+ if model == "openai":
78
+ llm_topic_selection = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
79
+ else:
80
+ llm_topic_selection = ChatGroq(model=model)
81
+ print(f"FOUND LLM TOPIC SELECTION FOR THE CONCEPT SELECTION PROMPT : {llm_topic_selection}")
82
+ topic_selection_chain = CONCEPT_SELECTION_PROMPT | llm_topic_selection | StrOutputParser()
83
+ return topic_selection_chain
84
+
85
+ def get_concepts(graph: Neo4jGraph):
86
+ concept_cypher = "MATCH (c:Concept) return c"
87
+ if isinstance(graph, Neo4jGraph):
88
+ concepts = graph.query(concept_cypher)
89
+ else:
90
+ user_input = input("Topics : ")
91
+ concepts = eval(user_input)
92
+
93
+ concepts_name = [concept['c']['name'] for concept in concepts]
94
+ return concepts_name
95
+
96
+ def get_related_concepts(graph: Neo4jGraph, question: str):
97
+ concepts = get_concepts(graph)
98
+ llm = get_model()
99
+ print(f"this is the llm variable : {llm}")
100
+ def parse_answer(llm_answer : str):
101
+ try:
102
+ print(f"This the llm_answer : {llm_answer}")
103
+ return re.split(r"\n(?:\d)+\.\s", llm_answer.split("Concepts:")[1])[1:]
104
+ except Exception:
105
+ return []
106
+ related_concepts_chain = RELEVANT_CONCEPTS_PROMPT | llm | StrOutputParser() | parse_answer
107
+
108
+ print(f"This is the question of the user : {question}")
109
+ print(f"This is the concepts of the user : {concepts}")
110
+
111
+
112
+ #groq.APIStatusError: Error code: 413 - {'error': {'message': 'Request too large for model `deepseek-r1-distill-llama-70b` in organization `org_01j6xywkndffv96m3wgh81jm49` on tokens per minute
113
+ # (TPM): Limit 5000, Requested 17099, please reduce your message size and try again. Visit https://console.groq.com/docs/rate-limits for more information.',
114
+ # 'type': 'tokens', 'code': 'rate_limit_exceeded'}}
115
+
116
+ try:
117
+ related_concepts_raw = related_concepts_chain.invoke({"user_query" : question, "concepts" : '\n'.join(concepts)})
118
+ print(f"related_concepts_raw : {related_concepts_raw}")
119
+ except Exception as e:
120
+ if getattr(e, "status_code", None) == 413:
121
+ msg = e.body["error"]["message"]
122
+ print(f"question is : {question}")
123
+ print(type(question))
124
+ error_question = ["user_query", question]
125
+ related_concepts_raw = error_concept_groq(msg,concepts,related_concepts_chain,error_question)
126
+ pass
127
+
128
+ # We clean up the list we received from the LLM in case there were some hallucinations
129
+ related_concepts_cleaned = []
130
+ for related_concept in related_concepts_raw:
131
+ # If the concept returned from the LLM is in the list we keep it
132
+ if related_concept in concepts:
133
+ related_concepts_cleaned.append(related_concept)
134
+ else:
135
+ # The LLM sometimes only forgets a few words from the concept name
136
+ # We check if the generated concept is a substring of an existing one and if it is the case add it to the list
137
+ for concept in concepts:
138
+ if related_concept in concept:
139
+ related_concepts_cleaned.append(concept)
140
+ break
141
+
142
+ # TODO : Add concepts found via similarity search
143
+ return related_concepts_cleaned
144
+
145
+ def build_concept_string(graph: Neo4jGraph, concept_list: list[str]):
146
+ concept_string = ""
147
+ for concept in concept_list:
148
+ concept_description_query = f"""
149
+ MATCH (c:Concept {{name: "{concept}" }}) RETURN c.description
150
+ """
151
+ concept_description = graph.query(concept_description_query)[0]['c.description']
152
+ concept_string += f"name: {concept}\ndescription: {concept_description}\n\n"
153
+ return concept_string
154
+
155
+ def get_global_concepts(graph: Neo4jGraph):
156
+ concept_cypher = "MATCH (gc:GlobalConcept) return gc"
157
+ if isinstance(graph, Neo4jGraph):
158
+ concepts = graph.query(concept_cypher)
159
+ else:
160
+ user_input = input("Topics : ")
161
+ concepts = eval(user_input)
162
+
163
+ concepts_name = [concept['gc']['name'] for concept in concepts]
164
+ return concepts_name
165
+
166
+ def generate_cypher(state: DocRetrieverState, config: ConfigSchema):
167
+ """
168
+ The node where the cypher is generated
169
+ """
170
+ graph = config["configurable"].get("graph")
171
+ question = state['query']
172
+ related_concepts = get_related_concepts(graph, question)
173
+ cyphers = []
174
+
175
+ if config["configurable"].get("cypher_gen_method") == 'auto':
176
+ cypher_gen_chain = get_cypher_gen_chain()
177
+ cyphers = cypher_gen_chain.invoke({
178
+ "schema": graph.schema,
179
+ "question": question,
180
+ "concepts": related_concepts
181
+ })
182
+
183
+ try :
184
+
185
+ if config["configurable"].get("cypher_gen_method") == 'guided':
186
+ concept_selection_chain = get_concept_selection_chain()
187
+ print(f"Concept selection chain is : {concept_selection_chain}")
188
+ selected_topic = concept_selection_chain.invoke({"question" : question, "concepts": get_concepts(graph)})
189
+ print(f"Selected topic are : {selected_topic}")
190
+
191
+ except Exception as e:
192
+ error_question = ["question", question]
193
+ selected_topic = error_concept_groq(e.body["error"]["message"],get_concepts(graph),concept_selection_chain,error_question)
194
+ pass
195
+
196
+ if config["configurable"].get("cypher_gen_method") == 'guided':
197
+ cyphers = [generate_cypher_from_topic(selected_topic, state['current_plan_step'])]
198
+ print(f"Cyphers are : {cyphers}")
199
+
200
+ if config["configurable"].get("validate_cypher"):
201
+ corrector_schema = [Schema(el["start"], el["type"], el["end"]) for el in graph.structured_schema.get("relationships")]
202
+ cypher_corrector = CypherQueryCorrector(corrector_schema)
203
+ cyphers = [cypher_corrector(cypher) for cypher in cyphers]
204
+
205
+ return {"cyphers" : cyphers}
206
+
207
+ def generate_cypher_from_topic(selected_concept: str, plan_step: int):
208
+ """
209
+ Helper function used when the 'guided' cypher generation method has been configured
210
+ """
211
+
212
+ print(f"L.176 PLAN STEP : {plan_step}")
213
+ cypher_el = "(n) return n.title, n.description"
214
+ match plan_step:
215
+ case 0:
216
+ cypher_el = "(ts:TechnicalSpecification) RETURN ts.title, ts.scope, ts.description"
217
+ case 1:
218
+ cypher_el = "(rp:ResearchPaper) RETURN rp.title, rp.abstract"
219
+ case 2:
220
+ cypher_el = "(ki:KeyIssue) RETURN ki.description"
221
+ return f"MATCH (c:Concept {{name:'{selected_concept}'}})-[:RELATED_TO]-{cypher_el}"
222
+
223
+ def get_docs(state:DocRetrieverState, config:ConfigSchema):
224
+ """
225
+ This node retrieves docs from the graph using the generated cypher
226
+ """
227
+ graph = config["configurable"].get("graph")
228
+ output = []
229
+ if graph is not None:
230
+ for cypher in state["cyphers"]:
231
+ try:
232
+ output = graph.query(cypher)
233
+ break
234
+ except Exception as e:
235
+ print(f"Failed to retrieve docs : {e}")
236
+
237
+ # Clean up the docs we received as there may be duplicates depending on the cypher query
238
+ all_docs = []
239
+ for doc in output:
240
+ unwinded_doc = {}
241
+ for key in doc:
242
+ if isinstance(doc[key], dict):
243
+ all_docs.append(doc[key])
244
+ else:
245
+ unwinded_doc.update({key: doc[key]})
246
+ if unwinded_doc:
247
+ all_docs.append(unwinded_doc)
248
+
249
+
250
+ filtered_docs = []
251
+ for doc in all_docs:
252
+ if doc not in filtered_docs:
253
+ filtered_docs.append(doc)
254
+
255
+ return {"docs": filtered_docs}
256
+
257
+
258
+
259
+
260
+
261
+ # Data model
262
+ class GradeDocumentsBinary(BaseModel):
263
+ """Binary score for relevance check on retrieved documents."""
264
+
265
+ binary_score: str = Field(
266
+ description="Documents are relevant to the question, 'yes' or 'no'"
267
+ )
268
+
269
+ # LLM with function call
270
+ # llm_grader_binary = ChatGroq(model="deepseek-r1-distill-llama-70b", temperature=0)
271
+
272
+ def get_binary_grader(model="deepseek-r1-distill-llama-70b"):
273
+ """
274
+ Returns a binary grader to evaluate relevance of documents using specified model for generation
275
+ This is used when the 'binary' evaluation method has been configured
276
+ """
277
+
278
+
279
+ if model == "gpt-4o":
280
+ llm_grader_binary = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/", temperature=0)
281
+ else:
282
+ llm_grader_binary = ChatGroq(model=model, temperature=0)
283
+ structured_llm_grader_binary = llm_grader_binary.with_structured_output(GradeDocumentsBinary)
284
+ retrieval_grader_binary = BINARY_GRADER_PROMPT | structured_llm_grader_binary
285
+ return retrieval_grader_binary
286
+
287
+
288
+ class GradeDocumentsScore(BaseModel):
289
+ """Score for relevance check on retrieved documents."""
290
+
291
+ score: float = Field(
292
+ description="Documents are relevant to the question, score between 0 (completely irrelevant) and 1 (perfectly relevant)"
293
+ )
294
+
295
+ def get_score_grader(model="deepseek-r1-distill-llama-70b"):
296
+ """
297
+ Returns a score grader to evaluate relevance of documents using specified model for generation
298
+ This is used when the 'score' evaluation method has been configured
299
+ """
300
+ if model == "gpt-4o":
301
+ llm_grader_score = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/", temperature=0)
302
+ else:
303
+ llm_grader_score = ChatGroq(model=model, temperature=0)
304
+ structured_llm_grader_score = llm_grader_score.with_structured_output(GradeDocumentsScore)
305
+ retrieval_grader_score = SCORE_GRADER_PROMPT | structured_llm_grader_score
306
+ return retrieval_grader_score
307
+
308
+
309
+ def eval_doc(doc, query, method="binary", threshold=0.7, eval_model="deepseek-r1-distill-llama-70b"):
310
+ '''
311
+ doc : the document to evaluate
312
+ query : the query to which the doc should be relevant
313
+ method : "binary" or "score"
314
+ threshold : for "score" method, score above which a doc is considered relevant
315
+ '''
316
+ if method == "binary":
317
+ retrieval_grader_binary = get_binary_grader(model=eval_model)
318
+ return 1 if (retrieval_grader_binary.invoke({"question": query, "document":doc}).binary_score == 'yes') else 0
319
+ elif method == "score":
320
+ retrieval_grader_score = get_score_grader(model=eval_model)
321
+ score = retrieval_grader_score.invoke({"query": query, "document":doc}).score or None
322
+ if score is not None:
323
+ return score if score >= threshold else 0
324
+ else:
325
+ # Couldn't parse score, marking document as relevant by default
326
+ return 1
327
+ else:
328
+ raise ValueError("Invalid method")
329
+
330
+ def eval_docs(state: DocRetrieverState, config: ConfigSchema):
331
+ """
332
+ This node evaluates the retrieved docs and keeps only the relevant ones in 'valid_docs'
333
+ """
334
+
335
+ eval_method = config["configurable"].get("eval_method") or "binary"
336
+ MAX_DOCS = config["configurable"].get("max_docs") or 15
337
+ valid_doc_scores = []
338
+
339
+ for doc in sample(state["docs"], min(25, len(state["docs"]))):
340
+ score = eval_doc(
341
+ doc=format_doc(doc),
342
+ query=state["query"],
343
+ method=eval_method,
344
+ threshold=config["configurable"].get("eval_threshold") or 0.7,
345
+ eval_model = config["configurable"].get("eval_model") or "deepseek-r1-distill-llama-70b"
346
+ )
347
+ if score:
348
+ valid_doc_scores.append((doc, score))
349
+
350
+ if eval_method == 'score':
351
+ # Get at most MAX_DOCS items with the highest score if score method was used
352
+ valid_docs = sorted(valid_doc_scores, key=lambda x: x[1], reverse=True)
353
+ valid_docs = [valid_doc[0] for valid_doc in valid_docs[:MAX_DOCS]]
354
+ else:
355
+ # Get at most MAX_DOCS items at random if binary method was used
356
+ shuffle(valid_doc_scores)
357
+ valid_docs = [valid_doc[0] for valid_doc in valid_doc_scores[:MAX_DOCS]]
358
+
359
+ return {"valid_docs": valid_docs + (state["valid_docs"] or [])}
360
+
361
+
362
+
363
+ def build_data_retriever_graph(memory):
364
+ """
365
+ Builds the data_retriever graph
366
+ """
367
+ #with SqliteSaver.from_conn_string(":memory:") as memory :
368
+
369
+ graph_builder_doc_retriever = StateGraph(DocRetrieverState)
370
+
371
+ graph_builder_doc_retriever.add_node("generate_cypher", generate_cypher)
372
+ graph_builder_doc_retriever.add_node("get_docs", get_docs)
373
+ graph_builder_doc_retriever.add_node("eval_docs", eval_docs)
374
+
375
+
376
+ graph_builder_doc_retriever.add_edge("__start__", "generate_cypher")
377
+ graph_builder_doc_retriever.add_edge("generate_cypher", "get_docs")
378
+ graph_builder_doc_retriever.add_edge("get_docs", "eval_docs")
379
+ graph_builder_doc_retriever.add_edge("eval_docs", "__end__")
380
+
381
+ graph_doc_retriever = graph_builder_doc_retriever.compile(checkpointer=memory)
382
+
383
+ return graph_doc_retriever
384
+
385
+ def error_concept_groq(msg,concepts,groq,question):
386
+ try:
387
+ start = msg.find("Requested") + len("Requested ")
388
+ end = msg.find(",", start)
389
+ rate_limit = int(msg[start:end])
390
+ related_concepts = []
391
+ i = 0
392
+ start = 0
393
+ end = len(concepts) // (rate_limit // 5000 + (1 if rate_limit%4500 != 0 else 0))
394
+ while (i < rate_limit // 5000):
395
+ smaller_concepts = concepts[start:end]
396
+ start = end
397
+ end = end + len(concepts) // (rate_limit//5000 + (1 if rate_limit%4500 != 0 else 0))
398
+ res = groq.invoke({question[0] : question[1], "concepts" : '\n'.join(smaller_concepts)})
399
+ for r in res:
400
+ related_concepts.append(r)
401
+ i+=1
402
+ return related_concepts
403
+ except Exception as e:
404
+ if getattr(e, "status_code", None) == 429:
405
+ time.sleep(65)
406
+ return error_concept_groq(msg, concepts, groq, question)
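
A minimal sketch of running the retriever subgraph on its own, assuming a reachable Neo4j instance; the credentials, thread_id and query are placeholders, and 'current_plan_step' is set explicitly because the 'guided' mode reads it:

from langgraph.checkpoint.memory import MemorySaver
from langchain_community.graphs import Neo4jGraph
from ki_gen.data_retriever import build_data_retriever_graph

graph_db = Neo4jGraph(url="bolt://localhost:7687", username="neo4j", password="password")
retriever = build_data_retriever_graph(MemorySaver())

result = retriever.invoke(
    {
        "query": "Architecture enablers for energy-efficient network slicing",
        "valid_docs": [],
        "current_plan_step": 0,  # 'guided' mode uses this to pick which node labels to query
    },
    config={"configurable": {
        "thread_id": "demo",
        "graph": graph_db,
        "cypher_gen_method": "guided",  # or "auto" to let the LLM write the Cypher
        "eval_method": "binary",
        "max_docs": 10,
    }},
)
print(result["valid_docs"])
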
ki_gen/planner.py ADDED
@@ -0,0 +1,282 @@
1
+ import os
2
+ import re
3
+
4
+ from typing import Annotated
5
+ from typing_extensions import TypedDict
6
+
7
+ from langchain_groq import ChatGroq
8
+ from langchain_openai import ChatOpenAI
9
+ from langchain_core.messages import SystemMessage, HumanMessage
10
+ from langchain_community.graphs import Neo4jGraph
11
+
12
+ from langgraph.graph import StateGraph
13
+ from langgraph.graph import add_messages
14
+
15
+ from ki_gen.prompts import PLAN_GEN_PROMPT, PLAN_MODIFICATION_PROMPT
16
+ from ki_gen.data_retriever import build_data_retriever_graph
17
+ from ki_gen.data_processor import build_data_processor_graph
18
+ from ki_gen.utils import ConfigSchema, State, HumanValidationState, DocProcessorState, DocRetrieverState
19
+ from langgraph.checkpoint.sqlite import SqliteSaver
20
+
21
+
22
+
23
+ ##########################################################################
24
+ ###### NODES DEFINITION ######
25
+ ##########################################################################
26
+
27
+ def validate_node(state: State):
28
+ """
29
+ This node inserts the plan validation prompt.
30
+ """
31
+ prompt = """System : You only need to focus on Key Issues, no need to focus on solutions or stakeholders yet and your plan should be concise.
32
+ If needed, give me an updated plan to follow this instruction. If your plan already follows the instruction just say "My plan is correct"."""
33
+ output = HumanMessage(content=prompt)
34
+ return {"messages" : [output]}
35
+
36
+
37
+ def error_chatbot_groq(error, model_name, query): # Pass model_name instead of llm_groq object
38
+ # Switch API key logic...
39
+ if os.environ["GROQ_API_KEY"] == os.getenv("groq_api_key"):
40
+ os.environ["GROQ_API_KEY"] = os.getenv("groq_api_key2")
41
+ elif os.environ["GROQ_API_KEY"] == os.getenv("groq_api_key2"):
42
+ os.environ["GROQ_API_KEY"] = os.getenv("groq_api_key3")
43
+ else:
44
+ os.environ["GROQ_API_KEY"] = os.getenv("groq_api_key")
45
+
46
+ # Re-initialize the model *after* switching the key
47
+ try:
48
+ # Use the model_name passed in
49
+ llm_groq_retry = ChatGroq(model=model_name)
50
+ # Pass the original query messages
51
+ return {"messages" : [llm_groq_retry.invoke(query)]}
52
+ except Exception as retry_error:
53
+ # Handle potential error during retry
54
+ print(f"Error during retry: {retry_error}")
55
+ # Decide what to return or raise here
56
+ return {"messages": [SystemMessage(content=f"Failed to process after retry: {retry_error}")]}
57
+
58
+
59
+ # Wrappers to call LLMs on the state messages field
60
+ def chatbot_llama(state: State):
61
+ try:
62
+ llm_llama = ChatGroq(model="llama3-70b-8192")
63
+ return {"messages" : [llm_llama.invoke(state["messages"])]}
64
+ except Exception as error:
65
+ return error_chatbot_groq(error, "llama3-70b-8192", state["messages"])
66
+ def chatbot_mixtral(state: State):
67
+ print(state)
68
+ llm_mixtral = ChatGroq(model="deepseek-r1-distill-llama-70b")
69
+ print(llm_mixtral)
70
+ return {"messages" : [llm_mixtral.invoke(state["messages"])]}
71
+ # except Exception as error:
72
+ # error_chatbot_groq(error,llm_mixtral,state["messages"])
73
+ def chatbot_openai(state: State):
74
+ llm_openai = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
75
+ return {"messages" : [llm_openai.invoke(state["messages"])]}
76
+
77
+ chatbots = {"gpt-4o" : chatbot_openai,
78
+ "deepseek-r1-distill-llama-70b" : chatbot_mixtral,
79
+ "llama3-70b-8192" : chatbot_llama
80
+ }
81
+
82
+
83
+ def parse_plan(state: State):
84
+ """
85
+ This node parses the generated plan and writes in the 'store_plan' field of the state
86
+ """
87
+ plan = state["messages"][-3].content
88
+ store_plan = re.split(r"\d\.", plan.split("Plan:\n")[1])[1:]
89
+ try:
90
+ store_plan[len(store_plan) - 1] = store_plan[len(store_plan) - 1].split("<END_OF_PLAN>")[0]
91
+ except Exception as e:
92
+ print(f"Error while removing <END_OF_PLAN> : {e}")
93
+
94
+ return {"store_plan" : store_plan}
95
+
96
+ def detail_step(state: State, config: ConfigSchema):
97
+ """
98
+ This node updates the value of the 'current_plan_step' field and defines the query to be used for the data_retriever.
99
+ """
100
+ print("test")
101
+ print(state)
102
+
103
+ if 'current_plan_step' in state.keys():
104
+ print("all good chief")
105
+ else:
106
+ state["current_plan_step"] = None
107
+
108
+ current_plan_step = state["current_plan_step"] + 1 if state["current_plan_step"] is not None else 0 # We just began a new step so we will increase current_plan_step at the end
109
+ if config["configurable"].get("use_detailed_query"):
110
+ prompt = HumanMessage(f"""Specify what additional information you need to proceed with the next step of your plan :
111
+ Step {current_plan_step + 1} : {state['store_plan'][current_plan_step]}""")
112
+ query = get_detailed_query(context = state["messages"] + [prompt], model=config["configurable"].get("main_llm"))
113
+ return {"messages" : [prompt, query], "current_plan_step": current_plan_step, 'query' : query}
114
+
115
+ return {"current_plan_step": current_plan_step, 'query' : state["store_plan"][current_plan_step], "valid_docs" : []}
116
+
117
+ def get_detailed_query(context : list, model : str = "deepseek-r1-distill-llama-70b"):
118
+ """
119
+ Simple helper function for the detail_step node
120
+ """
121
+ if model == 'gpt-4o':
122
+ llm = ChatOpenAI(model=model, base_url="https://llm.synapse.thalescloud.io/")
123
+ else:
124
+ llm = ChatGroq(model=model)
125
+ return llm.invoke(context)
126
+
127
+ def concatenate_data(state: State):
128
+ """
129
+ This node concatenates all the data that was processed by the data_processor and inserts it in the state's messages
130
+ """
131
+ prompt = f"""#########TECHNICAL INFORMATION ############
132
+ {str(state["valid_docs"])}
133
+
134
+ ########END OF TECHNICAL INFORMATION#######
135
+
136
+ Using the information provided above, proceed with step {state['current_plan_step'] + 1} of your plan :
137
+ {state['store_plan'][state['current_plan_step']]}
138
+ """
139
+
140
+ return {"messages": [HumanMessage(content=prompt)]}
141
+
142
+
143
+ def human_validation(state: HumanValidationState) -> HumanValidationState:
144
+ """
145
+ Dummy node to interrupt before
146
+ """
147
+ return {'process_steps' : []}
148
+
149
+ def generate_ki(state: State):
150
+ """
151
+ This node inserts the prompt to begin Key Issues generation
152
+ """
153
+ print(f"THIS IS THE STATE FOR CURRENT PLAN STEP IN GENERATE_KI : {state}")
154
+
155
+ prompt = f"""Using the information provided above, proceed with step 4 of your plan to provide the user with NEW and INNOVATIVE Key Issues :
156
+ {state['store_plan'][state['current_plan_step'] + 1]}"""
157
+
158
+ return {"messages" : [HumanMessage(content=prompt)]}
159
+
160
+ def detail_ki(state: State):
161
+ """
162
+ This node inserts the last prompt to detail the generated Key Issues
163
+ """
164
+ prompt = f"""Using the information provided above, proceed with step 5 of your plan to provide the user with NEW and INNOVATIVE Key Issues :
165
+ {state['store_plan'][state['current_plan_step'] + 2]}"""
166
+
167
+ return {"messages" : [HumanMessage(content=prompt)]}
168
+
169
+ ##########################################################################
170
+ ###### CONDITIONAL EDGE FUNCTIONS ######
171
+ ##########################################################################
172
+
173
+ def validate_plan(state: State):
174
+ """
175
+ Whether to regenerate the plan or to parse it
176
+ """
177
+ if "messages" in state and "My plan is correct" in state["messages"][-1].content:
178
+ return "parse"
179
+ return "validate"
180
+
181
+ def next_plan_step(state: State, config: ConfigSchema):
182
+ """
183
+ Proceed to next plan step (either generate KI or retrieve more data)
184
+ """
185
+ if (state["current_plan_step"] == 2) and (config["configurable"].get('plan_method') == "modification"):
186
+ return "generate_key_issues"
187
+ if state["current_plan_step"] == len(state["store_plan"]) - 1:
188
+ return "generate_key_issues"
189
+ else:
190
+ return "detail_step"
191
+
192
+ def detail_or_data_retriever(state: State, config: ConfigSchema):
193
+ """
194
+ Detail the query to use for data retrieval or not
195
+ """
196
+ if config["configurable"].get("use_detailed_query"):
197
+ return "chatbot_detail"
198
+ else:
199
+ return "data_retriever"
200
+
201
+ def retrieve_or_process(state: State):
202
+ """
203
+ Process the retrieved docs or keep retrieving
204
+ """
205
+ if state['human_validated']:
206
+ return "process"
207
+ return "retrieve"
208
+ # while True:
209
+ # user_input = input(f"{len(state['valid_docs'])} were retreived. Do you want more documents (y/[n]) : ")
210
+ # if user_input.lower() == "y":
211
+ # return "retrieve"
212
+ # if not user_input or user_input.lower() == "n":
213
+ # return "process"
214
+ # print("Please answer with 'y' or 'n'.\n")
215
+
216
+
217
+ def build_planner_graph(memory, config):
218
+ """
219
+ Builds the planner graph
220
+ """
221
+ graph_builder = StateGraph(State)
222
+
223
+ graph_doc_retriever = build_data_retriever_graph(memory)
224
+ graph_doc_processor = build_data_processor_graph(memory)
225
+ graph_builder.add_node("chatbot_planner", chatbots[config["main_llm"]])
226
+ graph_builder.add_node("validate", validate_node)
227
+ graph_builder.add_node("chatbot_detail", chatbot_llama)
228
+ graph_builder.add_node("parse", parse_plan)
229
+ graph_builder.add_node("detail_step", detail_step)
230
+ graph_builder.add_node("data_retriever", graph_doc_retriever, input=DocRetrieverState)
231
+ graph_builder.add_node("human_validation", human_validation)
232
+ graph_builder.add_node("data_processor", graph_doc_processor, input=DocProcessorState)
233
+ graph_builder.add_node("concatenate_data", concatenate_data)
234
+ graph_builder.add_node("chatbot_exec_step", chatbots[config["main_llm"]])
235
+ graph_builder.add_node("generate_ki", generate_ki)
236
+ graph_builder.add_node("chatbot_ki", chatbots[config["main_llm"]])
237
+ graph_builder.add_node("detail_ki", detail_ki)
238
+ graph_builder.add_node("chatbot_final", chatbots[config["main_llm"]])
239
+
240
+ graph_builder.add_edge("validate", "chatbot_planner")
241
+ graph_builder.add_edge("parse", "detail_step")
242
+
243
+
244
+ # graph_builder.add_edge("detail_step", "chatbot2")
245
+ graph_builder.add_edge("chatbot_detail", "data_retriever")
246
+ graph_builder.add_edge("data_retriever", "human_validation")
247
+
248
+
249
+ graph_builder.add_edge("data_processor", "concatenate_data")
250
+ graph_builder.add_edge("concatenate_data", "chatbot_exec_step")
251
+ graph_builder.add_edge("generate_ki", "chatbot_ki")
252
+ graph_builder.add_edge("chatbot_ki", "detail_ki")
253
+ graph_builder.add_edge("detail_ki", "chatbot_final")
254
+ graph_builder.add_edge("chatbot_final", "__end__")
255
+
256
+ graph_builder.add_conditional_edges(
257
+ "detail_step",
258
+ detail_or_data_retriever,
259
+ {"chatbot_detail": "chatbot_detail", "data_retriever": "data_retriever"}
260
+ )
261
+ graph_builder.add_conditional_edges(
262
+ "human_validation",
263
+ retrieve_or_process,
264
+ {"retrieve" : "data_retriever", "process" : "data_processor"}
265
+ )
266
+ graph_builder.add_conditional_edges(
267
+ "chatbot_planner",
268
+ validate_plan,
269
+ {"parse" : "parse", "validate": "validate"}
270
+ )
271
+ graph_builder.add_conditional_edges(
272
+ "chatbot_exec_step",
273
+ next_plan_step,
274
+ {"generate_key_issues" : "generate_ki", "detail_step": "detail_step"}
275
+ )
276
+
277
+ graph_builder.set_entry_point("chatbot_planner")
278
+ graph = graph_builder.compile(
279
+ checkpointer=memory,
280
+ interrupt_after=["parse", "chatbot_exec_step", "chatbot_final", "data_retriever"],
281
+ )
282
+ return graph
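
A minimal end-to-end sketch for the planner graph; the thread_id, requirement text and configuration values are illustrative, and it assumes init_app() can find the API keys in the environment:

from langgraph.checkpoint.memory import MemorySaver
from ki_gen.planner import build_planner_graph
from ki_gen.prompts import get_initial_prompt
from ki_gen.utils import init_app

init_app()
config = {"configurable": {
    "thread_id": "demo",
    "main_llm": "deepseek-r1-distill-llama-70b",
    "plan_method": "modification",
    "use_detailed_query": False,
}}

# build_planner_graph expects the flat dict that holds "main_llm".
graph = build_planner_graph(MemorySaver(), config["configurable"])

# Stream events until the first interrupt (after "parse"), starting from the initial system + user messages.
for event in graph.stream(get_initial_prompt(config, "Enable energy-efficient massive IoT deployments"), config):
    print(event)
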
ki_gen/prompts.py ADDED
@@ -0,0 +1,155 @@
1
+ from langchain_core.prompts.prompt import PromptTemplate
2
+ from langchain_core.prompts import ChatPromptTemplate
3
+ from langchain_core.messages import SystemMessage, HumanMessage
4
+ from ki_gen.utils import ConfigSchema
5
+
6
+ CYPHER_GENERATION_TEMPLATE = """Task:Generate Cypher statement to query a graph database.
7
+ Instructions:
8
+ Use only the provided relationship types and properties in the schema.
9
+ Do not use any other relationship types or properties that are not provided.
10
+ Schema:
11
+ {schema}
12
+
13
+
14
+ Concepts:
15
+ {concepts}
16
+
17
+
18
+ Concept names can ONLY be selected from the above list
19
+
20
+ Note: Do not include any explanations or apologies in your responses.
21
+ Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.
22
+ Do not include any text except the generated Cypher statement.
23
+
24
+ The question is:
25
+ {question}"""
26
+ CYPHER_GENERATION_PROMPT = PromptTemplate(
27
+ input_variables=["schema", "question", "concepts"], template=CYPHER_GENERATION_TEMPLATE
28
+ )
29
+
30
+ CYPHER_QA_TEMPLATE = """You are an assistant that helps to form nice and human understandable answers.
31
+ The information part contains the provided information that you must use to construct an answer.
32
+ The provided information is authoritative, you must never doubt it or try to use your internal knowledge to correct it.
33
+ Make the answer sound as a response to the question. Do not mention that you based the result on the given information.
34
+ Here is an example:
35
+
36
+ Question: Which managers own Neo4j stocks?
37
+ Context:[manager:CTL LLC, manager:JANE STREET GROUP LLC]
38
+ Helpful Answer: CTL LLC, JANE STREET GROUP LLC owns Neo4j stocks.
39
+
40
+ Follow this example when generating answers.
41
+ If the provided information is empty, say that you don't know the answer.
42
+ Information:
43
+ {context}
44
+
45
+ Question: {question}
46
+ Helpful Answer:"""
47
+ CYPHER_QA_PROMPT = PromptTemplate(
48
+ input_variables=["context", "question"], template=CYPHER_QA_TEMPLATE
49
+ )
50
+
51
+ PLAN_GEN_PROMPT = """System : You are a standardization expert working for 3GPP. You are given a specific technical requirement regarding the deployment of 5G services. Your goal is to specify NEW and INNOVATIVE Key Issues that could occur while trying to fulfill this requirement
52
+
53
+ System : Let's first understand the problem and devise a plan to solve the problem.
54
+ Output the plan starting with the header 'Plan:' and then followed by a numbered list of steps.
55
+ Make the plan the minimum number of steps required to accurately provide the user with NEW and INNOVATIVE Key Issues related to the technical requirement.
56
+ At the end of your plan, say '<END_OF_PLAN>'"""
57
+
58
+ PLAN_MODIFICATION_PROMPT = """You are a standardization expert working for 3GPP. You are given a specific technical requirement regarding the deployment of 5G services. Your goal is to specify NEW and INNOVATIVE Key Issues that could occur while trying to fulfill this requirement.
59
+ To achieve this goal we are going to follow this generic plan :
60
+
61
+ ###PLAN TEMPLATE###
62
+
63
+ Plan:
64
+
65
+ 1. **Understanding the Problem**: Gather information from existing specifications and standards to thoroughly understand the technical requirement. This should help you understand the key aspects of the problem.
66
+ 2. **Gather information about latest innovations** : Gather information about the latest innovations related to the problem by looking at the most relevant research papers and list the sources.
67
+ 3. **Identifying NEW and INNOVATIVE Key Issues**: Based on the understanding of the problem, identify new and innovative key issues that could occur while trying to fulfill this requirement. Describe them in simple technical English. These key issues should be relevant, significant, and not yet addressed by existing solutions.
68
+ 4. **Develop Detailed Descriptions for Each Key Issue**: For each identified key issue, provide a detailed description in simple technical English, including the specific challenges and areas requiring further study.
69
+ <END_OF_PLAN>
70
+
71
+ ###END OF PLAN TEMPLATE###
72
+
73
+ Let's devise a plan to solve the problem by adapting the PLAN TEMPLATE.
74
+ Output the plan starting with the header 'Plan:' and then followed by a numbered list of steps.
75
+ Make the plan the minimum number of steps required to accurately provide the user with NEW and INNOVATIVE Key Issues related to the technical requirement.
76
+ At the end of your plan, say '<END_OF_PLAN>' """
77
+
78
+ PLAN_GEN_PROMPT = PLAN_MODIFICATION_PROMPT
79
+
80
+ CONCEPT_SELECTION_TEMPLATE = """Task: Select the most relevant topic to the user question
81
+ Instructions:
82
+ Select the most relevant Concept to the user's question.
83
+ Concepts can ONLY be selected from the list below.
84
+
85
+ Concepts:
86
+ {concepts}
87
+
88
+ Note: Do not include any explanations or apologies in your responses.
89
+ Do not include any text except the selected concept.
90
+
91
+ The question is:
92
+ {question}"""
93
+ CONCEPT_SELECTION_PROMPT = PromptTemplate(
94
+ input_variables=["concepts", "question"], template=CONCEPT_SELECTION_TEMPLATE
95
+ )
96
+
97
+ RELEVANT_CONCEPTS_TEMPLATE = """
98
+ ## CONCEPTS ##
99
+ {concepts}
100
+ ## END OF CONCEPTS ##
101
+
102
+ Select the 20 most relevant concepts to the user query.
103
+ Output your answer as a numbered list preceded with the header 'Concepts:'.
104
+
105
+ User query :
106
+ {user_query}
107
+ """
108
+ RELEVANT_CONCEPTS_PROMPT = ChatPromptTemplate.from_messages([
109
+ ("human", RELEVANT_CONCEPTS_TEMPLATE)
110
+ ])
111
+
112
+ SUMMARIZER_TEMPLATE = """You are a 3GPP standardization expert.
113
+ Summarize the provided document in simple technical English for other experts in the field.
114
+
115
+ Document:
116
+ {document}"""
117
+ SUMMARIZER_PROMPT = ChatPromptTemplate.from_messages([
118
+ ("system", SUMMARIZER_TEMPLATE)
119
+ ])
120
+
121
+
122
+ BINARY_GRADER_TEMPLATE = """You are a grader assessing relevance of a retrieved document to a user question. \n
123
+ It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
124
+ If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
125
+ Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question."""
126
+ BINARY_GRADER_PROMPT = ChatPromptTemplate.from_messages(
127
+ [
128
+ ("system", BINARY_GRADER_TEMPLATE),
129
+ ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
130
+ ]
131
+ )
132
+
133
+
134
+ SCORE_GRADER_TEMPLATE = """Grasp and understand both the query and the document before score generation.
135
+ Then, based on your understanding and analysis quantify the relevance between the document and the query.
136
+ Give the rationale before answering.
137
+ Output your answer as a score ranging between 0 (irrelevant document) and 1 (completely relevant document)."""
138
+
139
+ SCORE_GRADER_PROMPT = ChatPromptTemplate.from_messages(
140
+ [
141
+ ("system", SCORE_GRADER_TEMPLATE),
142
+ ("human", "Passage: \n\n {document} \n\n User query: {query}")
143
+ ]
144
+ )
145
+
146
+ def get_initial_prompt(config: ConfigSchema, user_query : str):
147
+ if config["configurable"].get("plan_method") == "generation":
148
+ prompt = PLAN_GEN_PROMPT
149
+ elif config["configurable"].get("plan_method") == "modification":
150
+ prompt = PLAN_MODIFICATION_PROMPT
151
+ else:
152
+ raise ValueError("Incorrect plan_method, should be 'generation' or 'modification'")
153
+
154
+ user_input = user_query or input("User :")
155
+ return {"messages" : [SystemMessage(content=prompt), HumanMessage(content=user_input)]}
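
A small sketch of how one of these templates is filled in; the schema string, concept list and question below are made-up values:

from ki_gen.prompts import CYPHER_GENERATION_PROMPT

text = CYPHER_GENERATION_PROMPT.format(
    schema="(:Concept)-[:RELATED_TO]-(:ResearchPaper)",
    concepts="Network slicing\nEdge computing",
    question="Which research papers relate to network slicing?",
)
print(text)  # the fully rendered prompt that is piped into the Cypher-generation LLM
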
ki_gen/utils.py ADDED
@@ -0,0 +1,152 @@
1
+ import os
2
+ import getpass
3
+ import html
4
+
5
+
6
+ from typing import Annotated, Union
7
+ from typing_extensions import TypedDict
8
+
9
+ from langchain_community.graphs import Neo4jGraph
10
+ from langchain_groq import ChatGroq
11
+ from langchain_openai import ChatOpenAI
12
+
13
+ from langgraph.checkpoint.sqlite import SqliteSaver
14
+ from langgraph.checkpoint.memory import MemorySaver
15
+ from langgraph.checkpoint import base
16
+ from langgraph.graph import add_messages
17
+
18
+ memory = MemorySaver()
19
+
20
+ def format_df(df):
21
+ """
22
+ Used to display the generated plan in a nice format
23
+ Returns html code in a string
24
+ """
25
+ def format_cell(cell):
26
+ if isinstance(cell, str):
27
+ # Encode special characters, but preserve line breaks
28
+ return html.escape(cell).replace('\n', '<br>')
29
+ return cell
30
+ # Convert the DataFrame to HTML with custom CSS
31
+ formatted_df = df.map(format_cell)
32
+ html_table = formatted_df.to_html(escape=False, index=False)
33
+
34
+ # Add custom CSS to allow multiple lines and scrolling in cells
35
+ css = """
36
+ <style>
37
+ table {
38
+ border-collapse: collapse;
39
+ width: 100%;
40
+ }
41
+ th, td {
42
+ border: 1px solid black;
43
+ padding: 8px;
44
+ text-align: left;
45
+ vertical-align: top;
46
+ white-space: pre-wrap;
47
+ max-width: 300px;
48
+ max-height: 100px;
49
+ overflow-y: auto;
50
+ }
51
+ th {
52
+ background-color: #f2f2f2;
53
+ }
54
+ </style>
55
+ """
56
+
57
+ return css + html_table
58
+
59
+ def format_doc(doc: dict) -> str :
60
+ formatted_string = ""
61
+ for key in doc:
62
+ formatted_string += f"**{key}**: {doc[key]}\n"
63
+ return formatted_string
64
+
65
+
66
+
67
+ def _set_env(var: str, value: str = None):
68
+ if not os.environ.get(var):
69
+ if value:
70
+ os.environ[var] = value
71
+ else:
72
+ os.environ[var] = getpass.getpass(f"{var}: ")
73
+
74
+
75
+ def init_app(openai_key : str = None, groq_key : str = None, langsmith_key : str = None):
76
+ """
77
+ Initialize the app with user API keys (falling back to environment variables) and enable LangSmith tracing
78
+ """
79
+ _set_env("GROQ_API_KEY", value=groq_key or os.getenv("groq_api_key"))
80
+ _set_env("LANGSMITH_API_KEY", value=langsmith_key or os.getenv("langsmith_api_key"))
81
+ _set_env("OPENAI_API_KEY", value=openai_key or os.getenv("openai_api_key"))
82
+ os.environ["LANGCHAIN_TRACING_V2"] = "true"
83
+ os.environ["LANGCHAIN_PROJECT"] = "3GPP Test"
84
+
85
+
86
+ def clear_memory(memory, thread_id: str = "") -> None:
87
+ """
88
+ Clears checkpointer state for a given thread_id, broken for now
89
+ TODO : fix this
90
+ """
91
+ memory = MemorySaver()
92
+
93
+ #checkpoint = base.empty_checkpoint()
94
+ #memory.put(config={"configurable": {"thread_id": thread_id}}, checkpoint=checkpoint, metadata={})
95
+
96
+ def get_model(model : str = "deepseek-r1-distill-llama-70b"):
97
+ """
98
+ Wrapper to return the correct llm object depending on the 'model' param
99
+ """
100
+ if model == "gpt-4o":
101
+ llm = ChatOpenAI(model=model, base_url="https://llm.synapse.thalescloud.io/")
102
+ else:
103
+ llm = ChatGroq(model=model)
104
+ return llm
105
+
106
+
107
+ class ConfigSchema(TypedDict):
108
+ graph: Neo4jGraph
109
+ plan_method: str
110
+ use_detailed_query: bool
111
+
112
+ class State(TypedDict):
113
+ messages : Annotated[list, add_messages]
114
+ store_plan : list[str]
115
+ current_plan_step : int
116
+ valid_docs : list[str]
117
+
118
+ class DocRetrieverState(TypedDict):
119
+ messages: Annotated[list, add_messages]
120
+ query: str
121
+ docs: list[dict]
122
+ cyphers: list[str]
123
+ current_plan_step : int
124
+ valid_docs: list[Union[str, dict]]
125
+
126
+ class HumanValidationState(TypedDict):
127
+ human_validated : bool
128
+ process_steps : list[str]
129
+
130
+ def update_doc_history(left : list | None, right : list | None) -> list:
131
+ """
132
+ Reducer for the 'docs_in_processing' field.
133
+ Doesn't work currently because of bad handling of duplicates
134
+ TODO : make this work (reference : https://langchain-ai.github.io/langgraph/how-tos/subgraph/#custom-reducer-functions-to-manage-state)
135
+ """
136
+ if not left:
137
+ # This shouldn't happen
138
+ left = [[]]
139
+ if not right:
140
+ right = []
141
+
142
+ for i in range(len(right)):
143
+ left[i].append(right[i])
144
+ return left
145
+
146
+
147
+ class DocProcessorState(TypedDict):
148
+ valid_docs : list[Union[str, dict]]
149
+ docs_in_processing : list
150
+ process_steps : list[Union[str,dict]]
151
+ current_process_step : int
152
+