Upload 15 files
Browse files- app.py +555 -0
- doc_explorer/embeddings_full.npy +3 -0
- doc_explorer/explorer.py +335 -0
- doc_explorer/exported_docs/.gitkeep +1 -0
- doc_explorer/vectorstore.py +151 -0
- flagged/log.csv +39 -0
- images/flowchart_graphrag.png +0 -0
- images/flowchart_graphrag_dark.png +0 -0
- images/flowchart_graphrag_final.png +0 -0
- images/graph_png.png +0 -0
- ki_gen/data_processor.py +183 -0
- ki_gen/data_retriever.py +351 -0
- ki_gen/planner.py +253 -0
- ki_gen/prompts.py +155 -0
- ki_gen/utils.py +155 -0
app.py
ADDED
@@ -0,0 +1,555 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from langchain_community.graphs import Neo4jGraph
|
3 |
+
import pandas as pd
|
4 |
+
import json
|
5 |
+
|
6 |
+
from ki_gen.planner import build_planner_graph
|
7 |
+
from ki_gen.utils import clear_memory, init_app, format_df, memory
|
8 |
+
from ki_gen.prompts import get_initial_prompt
|
9 |
+
|
10 |
+
MAX_PROCESSING_STEPS = 10
|
11 |
+
|
12 |
+
print(f"WHEREVER YOU ARE THIS IS THE MEMORY INSTANCE !!!! : {type(memory)} !!!!")
|
13 |
+
|
14 |
+
def start_inference(data):
|
15 |
+
"""
|
16 |
+
Starts plan generation with user_query as input which gets displayed after
|
17 |
+
"""
|
18 |
+
config = data[config_state]
|
19 |
+
init_app(
|
20 |
+
openai_key=data[openai_api_key],
|
21 |
+
groq_key=data[groq_api_key],
|
22 |
+
langsmith_key=data[langsmith_api_key]
|
23 |
+
)
|
24 |
+
|
25 |
+
clear_memory(memory, config["configurable"].get("thread_id"))
|
26 |
+
|
27 |
+
graph = build_planner_graph(memory, config["configurable"])
|
28 |
+
with open("images/graph_png.png", "wb") as f:
|
29 |
+
f.write(graph.get_graph(xray=1).draw_mermaid_png())
|
30 |
+
|
31 |
+
print("here !")
|
32 |
+
for event in graph.stream(get_initial_prompt(config, data[user_query]), config, stream_mode="values"):
|
33 |
+
if "messages" in event:
|
34 |
+
event["messages"][-1].pretty_print()
|
35 |
+
|
36 |
+
state = graph.get_state(config)
|
37 |
+
steps = [i for i in range(1,len(state.values['store_plan'])+1)]
|
38 |
+
df = pd.DataFrame({'Plan steps': steps, 'Description': state.values['store_plan']})
|
39 |
+
return [df, graph]
|
40 |
+
|
41 |
+
def update_display(df):
|
42 |
+
"""
|
43 |
+
Displays the df after it has been generated
|
44 |
+
"""
|
45 |
+
formatted_html = format_df(df)
|
46 |
+
return {
|
47 |
+
plan_display : gr.update(visible=True, value = formatted_html),
|
48 |
+
select_step_to_modify : gr.update(visible=True, value=0),
|
49 |
+
enter_new_step : gr.update(visible=True),
|
50 |
+
submit_new_step : gr.update(visible=True),
|
51 |
+
continue_inference_btn : gr.update(visible=True)
|
52 |
+
}
|
53 |
+
|
54 |
+
def format_docs(docs: list[dict]):
|
55 |
+
formatted_results = ""
|
56 |
+
for i, doc in enumerate(docs):
|
57 |
+
formatted_results += f"\n### Document {i}\n"
|
58 |
+
for key in doc:
|
59 |
+
formatted_results += f"**{key}**: {doc[key]}\n"
|
60 |
+
return formatted_results
|
61 |
+
|
62 |
+
def continue_inference(data):
|
63 |
+
"""
|
64 |
+
Proceeds to next plan step
|
65 |
+
"""
|
66 |
+
graph = data[graph_state]
|
67 |
+
config = data[config_state]
|
68 |
+
|
69 |
+
for event in graph.stream(None, config, stream_mode="values"):
|
70 |
+
if "messages" in event:
|
71 |
+
event["messages"][-1].pretty_print()
|
72 |
+
|
73 |
+
snapshot = graph.get_state(config)
|
74 |
+
print(f"DEBUG INFO : next : {snapshot.next}")
|
75 |
+
print(f"DEBUG INFO ++ L.75: {snapshot}")
|
76 |
+
|
77 |
+
if snapshot.next and snapshot.next[0] == "human_validation":
|
78 |
+
return {
|
79 |
+
continue_inference_btn : gr.update(visible=False),
|
80 |
+
graph_state : graph,
|
81 |
+
retrieve_more_docs_btn : gr.update(visible=True),
|
82 |
+
continue_to_processing_btn : gr.update(visible=True),
|
83 |
+
human_validation_title : gr.update(visible=True, value=f"**{len(snapshot.values['valid_docs'])} documents retrieved.** Retrieve more or continue ?"),
|
84 |
+
retrieved_docs_state : snapshot.values['valid_docs']
|
85 |
+
}
|
86 |
+
|
87 |
+
return {
|
88 |
+
plan_result : snapshot.values["messages"][-1].content,
|
89 |
+
graph_state : graph,
|
90 |
+
continue_inference_btn : gr.update(visible=False)
|
91 |
+
}
|
92 |
+
|
93 |
+
def continue_to_processing():
|
94 |
+
"""
|
95 |
+
Continue to doc processing configuration
|
96 |
+
"""
|
97 |
+
return {
|
98 |
+
retrieve_more_docs_btn : gr.update(visible=False),
|
99 |
+
continue_to_processing_btn : gr.update(visible=False),
|
100 |
+
human_validation_title : gr.update(visible=False),
|
101 |
+
process_data_btn : gr.update(visible=True),
|
102 |
+
process_steps_nb : gr.update(visible=True),
|
103 |
+
process_steps_title : gr.update(visible=True)
|
104 |
+
}
|
105 |
+
|
106 |
+
def retrieve_more_docs(data):
|
107 |
+
"""
|
108 |
+
Restart doc retrieval
|
109 |
+
For now we simply regenerate the cypher, it may be different because temperature != 0
|
110 |
+
"""
|
111 |
+
graph = data[graph_state]
|
112 |
+
config = data[config_state]
|
113 |
+
graph.update_state(config, {'human_validated' : False}, as_node="human_validation")
|
114 |
+
|
115 |
+
for event in graph.stream(None, config, stream_mode="values"):
|
116 |
+
if "messages" in event:
|
117 |
+
event["messages"][-1].pretty_print()
|
118 |
+
|
119 |
+
snapshot = graph.get_state(config)
|
120 |
+
print(f"DEBUG INFO : next : {snapshot.next}")
|
121 |
+
print(f"DEBUG INFO ++ L.121: {snapshot}")
|
122 |
+
|
123 |
+
return {
|
124 |
+
graph_state : graph,
|
125 |
+
human_validation_title : gr.update(visible=True, value=f"**{len(snapshot.values['valid_docs'])} documents retrieved.** Retrieve more or continue ?"),
|
126 |
+
retrieved_docs_display : format_docs(snapshot.values['valid_docs'])
|
127 |
+
}
|
128 |
+
|
129 |
+
def execute_processing(*args):
|
130 |
+
"""
|
131 |
+
Execute doc processing
|
132 |
+
Args are passed as a list and not a dict for syntax convenience
|
133 |
+
"""
|
134 |
+
graph = args[-2]
|
135 |
+
config = args[-1]
|
136 |
+
nb_process_steps = args[-3]
|
137 |
+
|
138 |
+
process_steps = []
|
139 |
+
for i in range (nb_process_steps):
|
140 |
+
if args[i] == "custom":
|
141 |
+
process_steps.append({"prompt" : args[nb_process_steps + i], "context" : args[2*nb_process_steps + i], "processing_model" : args[3*nb_process_steps + i]})
|
142 |
+
else:
|
143 |
+
process_steps.append(args[i])
|
144 |
+
|
145 |
+
graph.update_state(config, {'human_validated' : True, 'process_steps' : process_steps}, as_node="human_validation")
|
146 |
+
|
147 |
+
for event in graph.stream(None, config, stream_mode="values"):
|
148 |
+
if "messages" in event:
|
149 |
+
event["messages"][-1].pretty_print()
|
150 |
+
|
151 |
+
snapshot = graph.get_state(config)
|
152 |
+
print(f"DEBUG INFO : next : {snapshot.next}")
|
153 |
+
print(f"DEBUG INFO ++ L.153: {snapshot}")
|
154 |
+
|
155 |
+
return {
|
156 |
+
plan_result : snapshot.values["messages"][-1].content,
|
157 |
+
processed_docs_state : snapshot.values["valid_docs"],
|
158 |
+
graph_state : graph,
|
159 |
+
continue_inference_btn : gr.update(visible=True),
|
160 |
+
process_steps_nb : gr.update(value=0, visible=False),
|
161 |
+
process_steps_title : gr.update(visible=False),
|
162 |
+
process_data_btn : gr.update(visible=False),
|
163 |
+
}
|
164 |
+
|
165 |
+
|
166 |
+
|
167 |
+
def update_config_display():
|
168 |
+
"""
|
169 |
+
Called after loading the config.json file
|
170 |
+
TODO : allow the user to specify a path to the config file
|
171 |
+
"""
|
172 |
+
with open("config.json", "r") as config_file:
|
173 |
+
config = json.load(config_file)
|
174 |
+
|
175 |
+
return {
|
176 |
+
main_llm : config["main_llm"],
|
177 |
+
plan_method : config["plan_method"],
|
178 |
+
use_detailed_query : config["use_detailed_query"],
|
179 |
+
cypher_gen_method : config["cypher_gen_method"],
|
180 |
+
validate_cypher : config["validate_cypher"],
|
181 |
+
summarization_model : config["summarize_model"],
|
182 |
+
eval_method : config["eval_method"],
|
183 |
+
eval_threshold : config["eval_threshold"],
|
184 |
+
max_docs : config["max_docs"],
|
185 |
+
compression_method : config["compression_method"],
|
186 |
+
compress_rate : config["compress_rate"],
|
187 |
+
force_tokens : config["force_tokens"],
|
188 |
+
eval_model : config["eval_model"],
|
189 |
+
srv_addr : config["graph"]["address"],
|
190 |
+
srv_usr : config["graph"]["username"],
|
191 |
+
srv_pwd : config["graph"]["password"],
|
192 |
+
openai_api_key : config["openai_api_key"],
|
193 |
+
groq_api_key : config["groq_api_key"],
|
194 |
+
langsmith_api_key : config["langsmith_api_key"]
|
195 |
+
}
|
196 |
+
|
197 |
+
|
198 |
+
def build_config(data):
|
199 |
+
"""
|
200 |
+
Build the config variable using the values inputted by the user
|
201 |
+
"""
|
202 |
+
config = {}
|
203 |
+
config["main_llm"] = data[main_llm]
|
204 |
+
config["plan_method"] = data[plan_method]
|
205 |
+
config["use_detailed_query"] = data[use_detailed_query]
|
206 |
+
config["cypher_gen_method"] = data[cypher_gen_method]
|
207 |
+
config["validate_cypher"] = data[validate_cypher]
|
208 |
+
config["summarize_model"] = data[summarization_model]
|
209 |
+
config["eval_method"] = data[eval_method]
|
210 |
+
config["eval_threshold"] = data[eval_threshold]
|
211 |
+
config["max_docs"] = data[max_docs]
|
212 |
+
config["compression_method"] = data[compression_method]
|
213 |
+
config["compress_rate"] = data[compress_rate]
|
214 |
+
config["force_tokens"] = data[force_tokens]
|
215 |
+
config["eval_model"] = data[eval_model]
|
216 |
+
config["thread_id"] = "3"
|
217 |
+
try:
|
218 |
+
neograph = Neo4jGraph(url=data[srv_addr], username=data[srv_usr], password=data[srv_pwd])
|
219 |
+
config["graph"] = neograph
|
220 |
+
except Exception as e:
|
221 |
+
raise gr.Error(f"Error when configuring the neograph server : {e}", duration=5)
|
222 |
+
gr.Info("Succesfully updated configuration !", duration=5)
|
223 |
+
return {"configurable" : config}
|
224 |
+
|
225 |
+
with gr.Blocks() as demo:
|
226 |
+
with gr.Tab("Config"):
|
227 |
+
|
228 |
+
### The config tab
|
229 |
+
|
230 |
+
gr.Markdown("## Config options setup")
|
231 |
+
|
232 |
+
gr.Markdown("### API Keys")
|
233 |
+
|
234 |
+
with gr.Row():
|
235 |
+
openai_api_key = gr.Textbox(
|
236 |
+
label="OpenAI API Key",
|
237 |
+
type="password"
|
238 |
+
)
|
239 |
+
|
240 |
+
groq_api_key = gr.Textbox(
|
241 |
+
label="Groq API Key",
|
242 |
+
type='password'
|
243 |
+
)
|
244 |
+
|
245 |
+
langsmith_api_key = gr.Textbox(
|
246 |
+
label="LangSmith API Key",
|
247 |
+
type="password"
|
248 |
+
)
|
249 |
+
|
250 |
+
gr.Markdown('### Planner options')
|
251 |
+
with gr.Row():
|
252 |
+
main_llm = gr.Dropdown(
|
253 |
+
choices=["gpt-4o", "claude-3-5-sonnet", "mixtral-8x7b-32768"],
|
254 |
+
label="Main LLM",
|
255 |
+
info="Choose the LLM which will perform the generation",
|
256 |
+
value="gpt-4o"
|
257 |
+
)
|
258 |
+
with gr.Column(scale=1, min_width=600):
|
259 |
+
plan_method = gr.Dropdown(
|
260 |
+
choices=["generation", "modification"],
|
261 |
+
label="Planning method",
|
262 |
+
info="Choose how the main LLM will generate its plan",
|
263 |
+
value="modification"
|
264 |
+
)
|
265 |
+
use_detailed_query = gr.Checkbox(
|
266 |
+
label="Detail each plan step",
|
267 |
+
info="Detail each plan step before passing it for data query"
|
268 |
+
)
|
269 |
+
|
270 |
+
gr.Markdown("### Data query options")
|
271 |
+
|
272 |
+
# The options for the data processor
|
273 |
+
# TODO : remove the options for summarize and compress and let the user choose them when specifying processing steps
|
274 |
+
# (similarly to what is done for custom processing step)
|
275 |
+
|
276 |
+
with gr.Row():
|
277 |
+
with gr.Column(scale=1, min_width=300):
|
278 |
+
# Neo4j Server parameters
|
279 |
+
|
280 |
+
srv_addr = gr.Textbox(
|
281 |
+
label="Neo4j server address",
|
282 |
+
placeholder="localhost:7687"
|
283 |
+
)
|
284 |
+
srv_usr = gr.Textbox(
|
285 |
+
label="Neo4j username",
|
286 |
+
placeholder="neo4j"
|
287 |
+
)
|
288 |
+
srv_pwd = gr.Textbox(
|
289 |
+
label="Neo4j password",
|
290 |
+
placeholder="<Password>"
|
291 |
+
)
|
292 |
+
|
293 |
+
with gr.Column(scale=1, min_width=300):
|
294 |
+
cypher_gen_method = gr.Dropdown(
|
295 |
+
choices=["auto", "guided"],
|
296 |
+
label="Cypher generation method",
|
297 |
+
)
|
298 |
+
validate_cypher = gr.Checkbox(
|
299 |
+
label="Validate cypher using graph Schema"
|
300 |
+
)
|
301 |
+
|
302 |
+
summarization_model = gr.Dropdown(
|
303 |
+
choices=["gpt-4o", "claude-3-5-sonnet", "mixtral-8x7b-32768", "llama3-70b-8192"],
|
304 |
+
label="Summarization LLM",
|
305 |
+
info="Choose the LLM which will perform the summaries"
|
306 |
+
)
|
307 |
+
|
308 |
+
with gr.Column(scale=1, min_width=300):
|
309 |
+
eval_method = gr.Dropdown(
|
310 |
+
choices=["binary", "score"],
|
311 |
+
label="Retrieved docs evaluation method",
|
312 |
+
info="Evaluation method of retrieved docs"
|
313 |
+
)
|
314 |
+
|
315 |
+
eval_model = gr.Dropdown(
|
316 |
+
choices = ["gpt-4o", "mixtral-8x7b-32768"],
|
317 |
+
label = "Evaluation model",
|
318 |
+
info = "The LLM to use to evaluate the relevance of retrieved docs",
|
319 |
+
value = "mixtral-8x7b-32768"
|
320 |
+
)
|
321 |
+
|
322 |
+
eval_threshold = gr.Slider(
|
323 |
+
minimum=0,
|
324 |
+
maximum=1,
|
325 |
+
value=0.7,
|
326 |
+
label="Eval threshold",
|
327 |
+
info="Score above which a doc is considered relevant",
|
328 |
+
step=0.01,
|
329 |
+
visible=False
|
330 |
+
)
|
331 |
+
|
332 |
+
def eval_method_changed(selection):
|
333 |
+
if selection == "score":
|
334 |
+
return gr.update(visible=True)
|
335 |
+
return gr.update(visible=False)
|
336 |
+
eval_method.change(eval_method_changed, inputs=eval_method, outputs=eval_threshold)
|
337 |
+
|
338 |
+
max_docs= gr.Slider(
|
339 |
+
minimum=0,
|
340 |
+
maximum = 30,
|
341 |
+
value = 15,
|
342 |
+
label="Max docs",
|
343 |
+
info="Maximum number of docs to be retrieved at each query",
|
344 |
+
step=0.01
|
345 |
+
)
|
346 |
+
|
347 |
+
with gr.Column(scale=1, min_width=300):
|
348 |
+
compression_method = gr.Dropdown(
|
349 |
+
choices=["llm_lingua2", "llm_lingua"],
|
350 |
+
label="Compression method",
|
351 |
+
value="llm_lingua2"
|
352 |
+
)
|
353 |
+
|
354 |
+
with gr.Row():
|
355 |
+
|
356 |
+
# Add compression rate configuration with a gr.slider
|
357 |
+
compress_rate = gr.Slider(
|
358 |
+
minimum = 0,
|
359 |
+
maximum = 1,
|
360 |
+
value = 0.33,
|
361 |
+
label="Compression rate",
|
362 |
+
info="Compression rate",
|
363 |
+
step = 0.01
|
364 |
+
)
|
365 |
+
|
366 |
+
# Add gr.CheckboxGroup to choose force_tokens
|
367 |
+
force_tokens = gr.CheckboxGroup(
|
368 |
+
choices=['\n', '?', '.', '!', ','],
|
369 |
+
value=[],
|
370 |
+
label="Force tokens",
|
371 |
+
info="Tokens to keep during compression",
|
372 |
+
)
|
373 |
+
|
374 |
+
with gr.Row():
|
375 |
+
btn_update_config = gr.Button(value="Update config")
|
376 |
+
load_config_json = gr.Button(value="Load config from JSON")
|
377 |
+
|
378 |
+
with gr.Row():
|
379 |
+
debug_info = gr.Button(value="Print debug info")
|
380 |
+
|
381 |
+
config_state = gr.State(value={})
|
382 |
+
|
383 |
+
|
384 |
+
btn_update_config.click(
|
385 |
+
build_config,
|
386 |
+
inputs={main_llm, plan_method, use_detailed_query, srv_addr, srv_pwd, srv_usr, compression_method, eval_model, \
|
387 |
+
compress_rate, force_tokens, cypher_gen_method, validate_cypher, summarization_model, eval_method, eval_threshold, max_docs},
|
388 |
+
outputs=config_state
|
389 |
+
)
|
390 |
+
load_config_json.click(
|
391 |
+
update_config_display,
|
392 |
+
outputs={main_llm, plan_method, use_detailed_query, cypher_gen_method, validate_cypher, summarization_model, eval_method, eval_threshold, \
|
393 |
+
max_docs, compress_rate, compression_method, force_tokens, eval_model, srv_addr, srv_usr, srv_pwd, openai_api_key, langsmith_api_key, groq_api_key}
|
394 |
+
).then(
|
395 |
+
build_config,
|
396 |
+
inputs={main_llm, plan_method, use_detailed_query, srv_addr, srv_pwd, srv_usr, compression_method, eval_model, \
|
397 |
+
compress_rate, force_tokens, cypher_gen_method, validate_cypher, summarization_model, eval_method, eval_threshold, max_docs},
|
398 |
+
outputs=config_state
|
399 |
+
)
|
400 |
+
|
401 |
+
# Print config variable in the terminal
|
402 |
+
debug_info.click(lambda x : print(x), inputs=config_state)
|
403 |
+
|
404 |
+
with gr.Tab("Inference"):
|
405 |
+
### Inference tab
|
406 |
+
|
407 |
+
graph_state = gr.State()
|
408 |
+
user_query = gr.Textbox(label = "Your query")
|
409 |
+
launch_inference = gr.Button(value="Generate plan")
|
410 |
+
|
411 |
+
with gr.Row():
|
412 |
+
dataframe_plan = gr.Dataframe(visible = False)
|
413 |
+
plan_display = gr.HTML(visible = False, label="Generated plan")
|
414 |
+
|
415 |
+
with gr.Column():
|
416 |
+
|
417 |
+
# Lets the user modify steps of the plan. Underlying logic not implemented yet
|
418 |
+
# TODO : implement this
|
419 |
+
with gr.Row():
|
420 |
+
select_step_to_modify = gr.Number(visible= False, label="Select a plan step to modify", value=0)
|
421 |
+
submit_new_step = gr.Button(visible = False, value="Submit new step")
|
422 |
+
enter_new_step = gr.Textbox(visible=False, label="Modify the plan step")
|
423 |
+
|
424 |
+
with gr.Row():
|
425 |
+
human_validation_title = gr.Markdown(visible=False)
|
426 |
+
retrieve_more_docs_btn = gr.Button(value="Retrieve more docs", visible=False)
|
427 |
+
continue_to_processing_btn = gr.Button(value="Proceed to data processing", visible=False)
|
428 |
+
|
429 |
+
with gr.Row():
|
430 |
+
with gr.Column():
|
431 |
+
|
432 |
+
process_steps_title = gr.Markdown("#### Data processing steps", visible=False)
|
433 |
+
process_steps_nb = gr.Number(label="Number of processing steps", value = 0, precision=0, step = 1, visible=False)
|
434 |
+
|
435 |
+
def get_process_step_names():
|
436 |
+
return ["summarize", "compress", "custom"]
|
437 |
+
|
438 |
+
# The gr.render decorator allows the code inside the following function to be rerun everytime the 'inputs' variable is modified
|
439 |
+
# /!\ All event listeners that use variables defined inside a gr.render function must be defined inside that same function
|
440 |
+
# ref : https://www.gradio.app/docs/gradio/render
|
441 |
+
@gr.render(inputs=process_steps_nb)
|
442 |
+
def processing(nb):
|
443 |
+
with gr.Row():
|
444 |
+
process_step_names = get_process_step_names()
|
445 |
+
dropdowns = []
|
446 |
+
textboxes = []
|
447 |
+
usable_elements = []
|
448 |
+
processing_models = []
|
449 |
+
for i in range(nb):
|
450 |
+
with gr.Column():
|
451 |
+
dropdown = gr.Dropdown(key = f"d{i}", choices=process_step_names, label=f"Data processing step {i+1}")
|
452 |
+
dropdowns.append(dropdown)
|
453 |
+
|
454 |
+
textbox = gr.Textbox(
|
455 |
+
key = f"t{i}",
|
456 |
+
value="",
|
457 |
+
placeholder="Your custom prompt",
|
458 |
+
visible=True, min_width=300)
|
459 |
+
textboxes.append(textbox)
|
460 |
+
|
461 |
+
usable_element = gr.Dropdown(
|
462 |
+
key = f"u{i}",
|
463 |
+
choices = [(j) for j in range(i+1)],
|
464 |
+
label="Elements passed to the LLM for this process step",
|
465 |
+
multiselect=True,
|
466 |
+
)
|
467 |
+
usable_elements.append(usable_element)
|
468 |
+
|
469 |
+
processing_model = gr.Dropdown(
|
470 |
+
key = f"m{i}",
|
471 |
+
label="The LLM that will execute this step",
|
472 |
+
visible=True,
|
473 |
+
choices=["gpt-4o", "mixtral-8x7b-32768", "llama3-70b-8182"]
|
474 |
+
)
|
475 |
+
processing_models.append(processing_model)
|
476 |
+
|
477 |
+
dropdown.change(
|
478 |
+
fn=lambda process_name : [gr.update(visible=(process_name=="custom")), gr.update(visible=(process_name=='custom')), gr.update(visible=(process_name=='custom'))],
|
479 |
+
inputs=dropdown,
|
480 |
+
outputs=[textbox, usable_element, processing_model]
|
481 |
+
)
|
482 |
+
|
483 |
+
process_data_btn.click(
|
484 |
+
execute_processing,
|
485 |
+
inputs= dropdowns + textboxes + usable_elements + processing_models + [process_steps_nb, graph_state, config_state],
|
486 |
+
outputs={plan_result, processed_docs_state, graph_state, continue_inference_btn, process_steps_nb, process_steps_title, process_data_btn}
|
487 |
+
)
|
488 |
+
|
489 |
+
process_data_btn = gr.Button(value="Process retrieved docs", visible=False)
|
490 |
+
|
491 |
+
continue_inference_btn = gr.Button(value="Proceed to next plan step", visible=False)
|
492 |
+
plan_result = gr.Markdown(visible = True, label="Result of last plan step")
|
493 |
+
|
494 |
+
with gr.Tab("Retrieved Docs"):
|
495 |
+
retrieved_docs_state = gr.State([])
|
496 |
+
with gr.Row():
|
497 |
+
gr.Markdown("# Retrieved Docs")
|
498 |
+
retrieved_docs_btn = gr.Button("Display retrieved docs")
|
499 |
+
retrieved_docs_display = gr.Markdown()
|
500 |
+
|
501 |
+
processed_docs_state = gr.State([])
|
502 |
+
with gr.Row():
|
503 |
+
gr.Markdown("# Processed Docs")
|
504 |
+
processed_docs_btn = gr.Button("Display processed docs")
|
505 |
+
processed_docs_display = gr.Markdown()
|
506 |
+
|
507 |
+
continue_inference_btn.click(
|
508 |
+
continue_inference,
|
509 |
+
inputs={graph_state, config_state},
|
510 |
+
outputs={continue_inference_btn, graph_state, retrieve_more_docs_btn, continue_to_processing_btn, human_validation_title, plan_result, retrieved_docs_state}
|
511 |
+
)
|
512 |
+
|
513 |
+
launch_inference.click(
|
514 |
+
start_inference,
|
515 |
+
inputs={config_state, user_query, openai_api_key, groq_api_key, langsmith_api_key},
|
516 |
+
outputs=[dataframe_plan, graph_state]
|
517 |
+
).then(
|
518 |
+
update_display,
|
519 |
+
inputs=dataframe_plan,
|
520 |
+
outputs={plan_display, select_step_to_modify, enter_new_step, submit_new_step, continue_inference_btn}
|
521 |
+
)
|
522 |
+
|
523 |
+
retrieve_more_docs_btn.click(
|
524 |
+
retrieve_more_docs,
|
525 |
+
inputs={graph_state, config_state},
|
526 |
+
outputs={graph_state, human_validation_title, retrieved_docs_display}
|
527 |
+
)
|
528 |
+
continue_to_processing_btn.click(
|
529 |
+
continue_to_processing,
|
530 |
+
outputs={retrieve_more_docs_btn, continue_to_processing_btn, human_validation_title, process_data_btn, process_steps_nb, process_steps_title}
|
531 |
+
)
|
532 |
+
retrieved_docs_btn.click(
|
533 |
+
fn=lambda docs : format_docs(docs),
|
534 |
+
inputs=retrieved_docs_state,
|
535 |
+
outputs=retrieved_docs_display
|
536 |
+
)
|
537 |
+
processed_docs_btn.click(
|
538 |
+
fn=lambda docs : format_docs(docs),
|
539 |
+
inputs=processed_docs_state,
|
540 |
+
outputs=processed_docs_display
|
541 |
+
)
|
542 |
+
|
543 |
+
|
544 |
+
test_process_steps = gr.Button(value="Test process steps")
|
545 |
+
test_process_steps.click(
|
546 |
+
lambda : [gr.update(visible = True), gr.update(visible=True)],
|
547 |
+
outputs=[process_steps_nb, process_steps_title]
|
548 |
+
)
|
549 |
+
|
550 |
+
|
551 |
+
|
552 |
+
|
553 |
+
|
554 |
+
demo.launch()
|
555 |
+
|
doc_explorer/embeddings_full.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cf8ae23f82d734adab5810858bb55c2f13edb06795637b6fd85ada823d722527
|
3 |
+
size 55693440
|
doc_explorer/explorer.py
ADDED
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from vectorstore import FAISSVectorStore
|
3 |
+
from langchain_community.graphs import Neo4jGraph
|
4 |
+
import os
|
5 |
+
import json
|
6 |
+
import html
|
7 |
+
import pandas as pd
|
8 |
+
import time
|
9 |
+
|
10 |
+
time.sleep(30)
|
11 |
+
|
12 |
+
os.environ["http_proxy"] = "185.46.212.98:80"
|
13 |
+
os.environ["https_proxy"] = "185.46.212.98:80"
|
14 |
+
os.environ["NO_PROXY"] = "localhost"
|
15 |
+
|
16 |
+
neo4j_graph = Neo4jGraph(
|
17 |
+
url=os.getenv("NEO4J_URI", "bolt://localhost:7999"),
|
18 |
+
username=os.getenv("NEO4J_USERNAME", "neo4j"),
|
19 |
+
password=os.getenv("NEO4J_PASSWORD", "graph_test")
|
20 |
+
)
|
21 |
+
|
22 |
+
# Requires ~1GB RAM
|
23 |
+
vector_store = FAISSVectorStore(model_name='Alibaba-NLP/gte-large-en-v1.5', dimension=1024, trust_remote_code=True, embedding_file="/usr/src/app/doc_explorer/embeddings_full.npy")
|
24 |
+
|
25 |
+
# Get document types from Neo4j database
|
26 |
+
def get_document_types():
|
27 |
+
query = """
|
28 |
+
MATCH (n)
|
29 |
+
RETURN DISTINCT labels(n) AS document_type
|
30 |
+
"""
|
31 |
+
result = neo4j_graph.query(query)
|
32 |
+
return [row["document_type"][0] for row in result]
|
33 |
+
|
34 |
+
def search(query, doc_types, use_mmr, lambda_param, top_k):
|
35 |
+
results, node_ids = vector_store.similarity_search(
|
36 |
+
query,
|
37 |
+
k=top_k,
|
38 |
+
use_mmr=use_mmr,
|
39 |
+
lambda_param=lambda_param if use_mmr else None,
|
40 |
+
doc_types=doc_types,
|
41 |
+
neo4j_graph=neo4j_graph
|
42 |
+
)
|
43 |
+
|
44 |
+
formatted_results = []
|
45 |
+
formatted_choices = []
|
46 |
+
for i, result in enumerate(results):
|
47 |
+
formatted_results.append(f"{i}. {result['document']} (Score: {result['score']:.4f})")
|
48 |
+
formatted_choices.append(f"{i}. {str(result['document'])[:100]} (Score: {result['score']:.4f})")
|
49 |
+
return formatted_results, gr.update(choices=formatted_choices, value=[]), node_ids
|
50 |
+
|
51 |
+
def get_docs_from_ids(graph_data : dict):
|
52 |
+
node_ids = [node["id"] for node in graph_data["nodes"]]
|
53 |
+
print(node_ids)
|
54 |
+
query = """
|
55 |
+
MATCH (n)
|
56 |
+
WHERE n.id IN $node_ids
|
57 |
+
RETURN n.id AS id, n AS doc, labels(n) AS category
|
58 |
+
"""
|
59 |
+
|
60 |
+
return neo4j_graph.query(query, {"node_ids" : node_ids}), graph_data["edges"]
|
61 |
+
|
62 |
+
def get_neighbors_and_graph_data(selected_documents, node_ids, graph_data):
|
63 |
+
if not selected_documents:
|
64 |
+
return "No documents selected.", json.dumps(graph_data), graph_data
|
65 |
+
|
66 |
+
selected_indices = [int(doc.split('.')[0]) - 1 for doc in selected_documents]
|
67 |
+
selected_node_ids = [node_ids[i] for i in selected_indices]
|
68 |
+
|
69 |
+
query = """
|
70 |
+
MATCH (n)-[r]-(neighbor)
|
71 |
+
WHERE n.id IN $node_ids
|
72 |
+
RETURN n.id AS source_id, n AS source_text, labels(n) AS source_type,
|
73 |
+
neighbor.id AS neighbor_id, neighbor AS neighbor_text,
|
74 |
+
labels(neighbor) AS neighbor_type, type(r) AS relationship_type
|
75 |
+
"""
|
76 |
+
results = neo4j_graph.query(query, {"node_ids": selected_node_ids})
|
77 |
+
|
78 |
+
if not results:
|
79 |
+
return "No neighbors found for the selected documents.", "[]"
|
80 |
+
|
81 |
+
neighbor_info = {}
|
82 |
+
node_set = set([node["id"] for node in graph_data["nodes"]])
|
83 |
+
|
84 |
+
for row in results:
|
85 |
+
source_id = row['source_id']
|
86 |
+
if source_id not in neighbor_info:
|
87 |
+
neighbor_info[source_id] = {
|
88 |
+
'source_type': row["source_type"][0],
|
89 |
+
'source_text': row['source_text'],
|
90 |
+
'neighbors': []
|
91 |
+
}
|
92 |
+
if source_id not in node_set:
|
93 |
+
graph_data["nodes"].append({
|
94 |
+
"id": source_id,
|
95 |
+
"label": str(row['source_text'])[:30] + "...",
|
96 |
+
"group": row['source_type'][0],
|
97 |
+
"title": f"<div class='node-tooltip'><h3>{row['source_type'][0]}</h3><p>{row['source_text']}</p></div>",
|
98 |
+
})
|
99 |
+
node_set.add(source_id)
|
100 |
+
|
101 |
+
neighbor_info[source_id]['neighbors'].append(
|
102 |
+
f"[{row['relationship_type']}] [{row['neighbor_type'][0]}] {str(row['neighbor_text'])[:200]}"
|
103 |
+
)
|
104 |
+
|
105 |
+
if row['neighbor_id'] not in node_set:
|
106 |
+
graph_data["nodes"].append({
|
107 |
+
"id": row['neighbor_id'],
|
108 |
+
"label": str(row['neighbor_text'])[:30] + "...",
|
109 |
+
"group": row['neighbor_type'][0],
|
110 |
+
"title": f"<div class='node-tooltip'><h3>{row['neighbor_type'][0]}</h3><p>{html.escape(str(row['neighbor_text']))}</p></div>",
|
111 |
+
})
|
112 |
+
node_set.add(row['neighbor_id'])
|
113 |
+
|
114 |
+
edge = {
|
115 |
+
"from": source_id,
|
116 |
+
"to" : row['neighbor_id'],
|
117 |
+
"label": row['relationship_type']
|
118 |
+
}
|
119 |
+
if edge not in graph_data['edges']:
|
120 |
+
graph_data['edges'].append(edge)
|
121 |
+
|
122 |
+
output = []
|
123 |
+
for source_id, info in neighbor_info.items():
|
124 |
+
output.append(f"Neighbors for: [{info['source_type']}] {str(info['source_text'])[:100]}")
|
125 |
+
output.extend(info['neighbors'])
|
126 |
+
output.append("\n\n") # Empty line for separation
|
127 |
+
|
128 |
+
formatted_choices = []
|
129 |
+
node_ids = []
|
130 |
+
for i, node in enumerate(graph_data['nodes']):
|
131 |
+
formatted_choices.append(f"{i+1}. {str(node['label'])})")
|
132 |
+
node_ids.append(node['id'])
|
133 |
+
|
134 |
+
return "\n".join(output), json.dumps(graph_data), graph_data, gr.update(choices=formatted_choices, value=[]), node_ids
|
135 |
+
|
136 |
+
def save_docs_to_excel(exported_docs : list[dict], exported_relationships : list[dict]):
|
137 |
+
cleaned_docs = [dict(doc['doc'], **{'id': doc['id'], 'category': doc['category'][0], "relationships" : ""}) for doc in exported_docs]
|
138 |
+
for relationship in exported_relationships:
|
139 |
+
for doc in cleaned_docs:
|
140 |
+
if doc['id'] == relationship['from']:
|
141 |
+
doc["relationships"] += f"[{relationship['label']}] {relationship['to']}\n"
|
142 |
+
|
143 |
+
df = pd.DataFrame(cleaned_docs)
|
144 |
+
df.to_excel("doc_explorer/exported_docs/docs.xlsx")
|
145 |
+
return gr.update(value="doc_explorer/exported_docs/docs.xlsx", visible=True)
|
146 |
+
|
147 |
+
# JavaScript code for graph visualization
|
148 |
+
js_code = """
|
149 |
+
function(graph_data_str) {
|
150 |
+
if (!graph_data_str) return;
|
151 |
+
const container = document.getElementById('graph-container');
|
152 |
+
container.innerHTML = '';
|
153 |
+
let data;
|
154 |
+
try {
|
155 |
+
data = JSON.parse(graph_data_str);
|
156 |
+
} catch (error) {
|
157 |
+
console.error("Failed to parse graph data:", error);
|
158 |
+
container.innerHTML = "Error: Failed to load graph data.";
|
159 |
+
return;
|
160 |
+
}
|
161 |
+
|
162 |
+
data.nodes.forEach(node => {
|
163 |
+
const div = document.createElement('div');
|
164 |
+
div.innerHTML = node.title;
|
165 |
+
node.title = div.firstChild;
|
166 |
+
});
|
167 |
+
|
168 |
+
const nodes = new vis.DataSet(data.nodes);
|
169 |
+
const edges = new vis.DataSet(data.edges);
|
170 |
+
const options = {
|
171 |
+
nodes: {
|
172 |
+
shape: 'dot',
|
173 |
+
size: 16,
|
174 |
+
font: {
|
175 |
+
size: 12,
|
176 |
+
color: '#000000'
|
177 |
+
},
|
178 |
+
borderWidth: 2
|
179 |
+
},
|
180 |
+
edges: {
|
181 |
+
width: 1,
|
182 |
+
font: {
|
183 |
+
size: 10,
|
184 |
+
align: 'middle'
|
185 |
+
},
|
186 |
+
color: { color: '#7A7A7A', hover: '#2B7CE9' }
|
187 |
+
},
|
188 |
+
physics: {
|
189 |
+
forceAtlas2Based: {
|
190 |
+
gravitationalConstant: -26,
|
191 |
+
centralGravity: 0.005,
|
192 |
+
springLength: 230,
|
193 |
+
springConstant: 0.18
|
194 |
+
},
|
195 |
+
maxVelocity: 146,
|
196 |
+
solver: 'forceAtlas2Based',
|
197 |
+
timestep: 0.35,
|
198 |
+
stabilization: { iterations: 150 }
|
199 |
+
},
|
200 |
+
interaction: {
|
201 |
+
hover: true,
|
202 |
+
tooltipDelay: 200
|
203 |
+
}
|
204 |
+
};
|
205 |
+
const network = new vis.Network(container, { nodes: nodes, edges: edges }, options);
|
206 |
+
}
|
207 |
+
"""
|
208 |
+
|
209 |
+
head = """
|
210 |
+
<script type="text/javascript" src="https://unpkg.com/vis-network/standalone/umd/vis-network.min.js"></script>
|
211 |
+
<link href="https://unpkg.com/vis-network/styles/vis-network.min.css" rel="stylesheet" type="text/css" />
|
212 |
+
"""
|
213 |
+
|
214 |
+
custom_css = """
|
215 |
+
#graph-container {
|
216 |
+
border: 1px solid #ddd;
|
217 |
+
border-radius: 4px;
|
218 |
+
}
|
219 |
+
.vis-tooltip {
|
220 |
+
font-family: Arial, sans-serif;
|
221 |
+
padding: 10px;
|
222 |
+
border-radius: 4px;
|
223 |
+
background-color: rgba(255, 255, 255, 0.9);
|
224 |
+
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
|
225 |
+
max-width: 300px;
|
226 |
+
color: #333;
|
227 |
+
word-wrap: break-word;
|
228 |
+
overflow-wrap: break-word;
|
229 |
+
}
|
230 |
+
.node-tooltip {
|
231 |
+
width: 100%;
|
232 |
+
}
|
233 |
+
.node-tooltip h3 {
|
234 |
+
margin: 0 0 5px 0;
|
235 |
+
font-size: 14px;
|
236 |
+
color: #333;
|
237 |
+
}
|
238 |
+
.node-tooltip p {
|
239 |
+
margin: 0;
|
240 |
+
font-size: 12px;
|
241 |
+
color: #666;
|
242 |
+
white-space: normal;
|
243 |
+
}
|
244 |
+
"""
|
245 |
+
|
246 |
+
|
247 |
+
with gr.Blocks(head=head, css=custom_css) as demo:
|
248 |
+
with gr.Tab("Search"):
|
249 |
+
|
250 |
+
gr.Markdown("# Document Search Engine")
|
251 |
+
gr.Markdown("Enter a query to search for similar documents. You can filter by document type and use MMR for diverse results.")
|
252 |
+
|
253 |
+
with gr.Row():
|
254 |
+
with gr.Column(scale=3):
|
255 |
+
query_input = gr.Textbox(label="Enter your query")
|
256 |
+
doc_type_input = gr.Dropdown(choices=get_document_types(), label="Select document type", multiselect=True)
|
257 |
+
with gr.Column(scale=2):
|
258 |
+
mmr_input = gr.Checkbox(label="Use MMR for diverse results")
|
259 |
+
lambda_input = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5, label="Lambda parameter (MMR diversity)", visible=False)
|
260 |
+
top_k_input = gr.Slider(minimum=1, maximum=20, step=1, value=5, label="Number of results")
|
261 |
+
|
262 |
+
search_button = gr.Button("Search")
|
263 |
+
results_output = gr.Textbox(label="Search Results")
|
264 |
+
|
265 |
+
selected_documents = gr.Dropdown(label="Select documents to view their neighbors", choices=[], multiselect=True, interactive=True)
|
266 |
+
|
267 |
+
with gr.Row():
|
268 |
+
neighbor_search_button = gr.Button("Find Neighbors")
|
269 |
+
send_to_export = gr.Button("Send docs to export Tab")
|
270 |
+
|
271 |
+
neighbors_output = gr.Textbox(label="Document Neighbors")
|
272 |
+
|
273 |
+
graph_data_state = gr.State({"nodes": [], "edges": []})
|
274 |
+
graph_data_str = gr.Textbox(visible=False)
|
275 |
+
graph_container = gr.HTML('<div id="graph-container" style="height: 600px;"> Hey ! </div>')
|
276 |
+
|
277 |
+
node_ids = gr.State([])
|
278 |
+
exported_docs = gr.State([])
|
279 |
+
exported_relationships = gr.State([])
|
280 |
+
|
281 |
+
def update_lambda_visibility(use_mmr):
|
282 |
+
return gr.update(visible=use_mmr)
|
283 |
+
|
284 |
+
mmr_input.change(fn=update_lambda_visibility, inputs=mmr_input, outputs=lambda_input)
|
285 |
+
|
286 |
+
search_button.click(
|
287 |
+
fn=search,
|
288 |
+
inputs=[query_input, doc_type_input, mmr_input, lambda_input, top_k_input],
|
289 |
+
outputs=[results_output, selected_documents, node_ids]
|
290 |
+
)
|
291 |
+
|
292 |
+
neighbor_search_button.click(
|
293 |
+
fn=get_neighbors_and_graph_data,
|
294 |
+
inputs=[selected_documents, node_ids, graph_data_state],
|
295 |
+
outputs=[neighbors_output, graph_data_str, graph_data_state, selected_documents, node_ids]
|
296 |
+
).then(
|
297 |
+
fn=None,
|
298 |
+
inputs=graph_data_str,
|
299 |
+
outputs=None,
|
300 |
+
js=js_code,
|
301 |
+
)
|
302 |
+
|
303 |
+
send_to_export.click(
|
304 |
+
fn=get_docs_from_ids,
|
305 |
+
inputs=graph_data_state,
|
306 |
+
outputs=[exported_docs, exported_relationships]
|
307 |
+
)
|
308 |
+
# gr.Examples(
|
309 |
+
# examples=[
|
310 |
+
# ["What is machine learning?", "Article", True, 0.5, 5],
|
311 |
+
# ["How to implement a neural network?", "Tutorial", False, 0.5, 3],
|
312 |
+
# ["Latest advancements in NLP", "Research Paper", True, 0.7, 10]
|
313 |
+
# ],
|
314 |
+
# inputs=[query_input, doc_type_input, mmr_input, lambda_input, top_k_input]
|
315 |
+
# )
|
316 |
+
with gr.Tab("Export"):
|
317 |
+
with gr.Row():
|
318 |
+
exported_docs_btn = gr.Button("Display exported docs")
|
319 |
+
exported_excel_btn = gr.Button("Export to excel")
|
320 |
+
exported_excel = gr.File(visible=False)
|
321 |
+
|
322 |
+
exported_docs_display = gr.Markdown(visible=False)
|
323 |
+
|
324 |
+
exported_docs_btn.click(
|
325 |
+
fn= lambda docs: gr.update(value='\n\n'.join([f"[{doc['category'][0]}]\n{doc['doc']}\n\n" for doc in docs]), visible=True),
|
326 |
+
inputs=exported_docs,
|
327 |
+
outputs=exported_docs_display
|
328 |
+
)
|
329 |
+
exported_excel_btn.click(
|
330 |
+
fn=save_docs_to_excel,
|
331 |
+
inputs=[exported_docs, exported_relationships],
|
332 |
+
outputs=exported_excel
|
333 |
+
)
|
334 |
+
|
335 |
+
demo.launch()
|
doc_explorer/exported_docs/.gitkeep
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
a
|
doc_explorer/vectorstore.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import faiss
|
2 |
+
import numpy as np
|
3 |
+
from sentence_transformers import SentenceTransformer
|
4 |
+
from typing import List, Optional, Tuple
|
5 |
+
from langchain_community.graphs import Neo4jGraph
|
6 |
+
import pickle
|
7 |
+
|
8 |
+
class FAISSVectorStore:
|
9 |
+
def __init__(self, model_name: str = None, dimension: int = 384, embedding_file: str = None, trust_remote_code = False):
|
10 |
+
self.model = SentenceTransformer(model_name, trust_remote_code=trust_remote_code) if model_name is not None else None
|
11 |
+
self.index = faiss.IndexFlatIP(dimension)
|
12 |
+
self.dimension = dimension
|
13 |
+
if embedding_file:
|
14 |
+
self.load_embeddings(embedding_file)
|
15 |
+
|
16 |
+
def load_embeddings(self, file_path: str):
|
17 |
+
if file_path.endswith('.pkl'):
|
18 |
+
with open(file_path, 'rb') as f:
|
19 |
+
embeddings = pickle.load(f)
|
20 |
+
elif file_path.endswith('.npy'):
|
21 |
+
embeddings = np.load(file_path)
|
22 |
+
else:
|
23 |
+
raise ValueError("Unsupported file format. Use .pkl or .npy")
|
24 |
+
|
25 |
+
self.add_embeddings(embeddings)
|
26 |
+
|
27 |
+
def add_embeddings(self, embeddings: np.ndarray):
|
28 |
+
faiss.normalize_L2(embeddings)
|
29 |
+
self.index.add(embeddings)
|
30 |
+
|
31 |
+
def similarity_search(self, query: str, k: int = 5, use_mmr: bool = False, lambda_param: float = 0.5, doc_types: list[str] = None, neo4j_graph: Neo4jGraph = None):
|
32 |
+
query_vector = self.model.encode([query])
|
33 |
+
faiss.normalize_L2(query_vector)
|
34 |
+
|
35 |
+
if use_mmr:
|
36 |
+
return self._mmr_search(query_vector, k, lambda_param, neo4j_graph, doc_types)
|
37 |
+
else:
|
38 |
+
return self._simple_search(query_vector, k, neo4j_graph, doc_types)
|
39 |
+
|
40 |
+
def _simple_search(self, query_vector: np.ndarray, k: int, neo4j_graph: Neo4jGraph, doc_types : list[str] = None) -> List[dict]:
|
41 |
+
distances, indices = self.index.search(query_vector, k)
|
42 |
+
|
43 |
+
results = []
|
44 |
+
results_idx = []
|
45 |
+
for i, idx in enumerate(indices[0]):
|
46 |
+
document = self._get_text_by_index(neo4j_graph, idx, doc_types)
|
47 |
+
if document is not None:
|
48 |
+
results.append({
|
49 |
+
'document': document,
|
50 |
+
'score': distances[0][i]
|
51 |
+
})
|
52 |
+
results_idx.append(idx)
|
53 |
+
|
54 |
+
return results, results_idx
|
55 |
+
|
56 |
+
def _mmr_search(self, query_vector: np.ndarray, k: int, lambda_param: float, neo4j_graph: Neo4jGraph, doc_types: list[str] = None) -> Tuple[List[dict], List[int]]:
|
57 |
+
initial_k = min(k * 2, self.index.ntotal)
|
58 |
+
distances, indices = self.index.search(query_vector, initial_k)
|
59 |
+
|
60 |
+
# Reconstruct embeddings for the initial results
|
61 |
+
initial_embeddings = self._reconstruct_embeddings(indices[0])
|
62 |
+
|
63 |
+
selected_indices = []
|
64 |
+
unselected_indices = list(range(len(indices[0])))
|
65 |
+
|
66 |
+
for _ in range(min(k, len(indices[0]))):
|
67 |
+
mmr_scores = []
|
68 |
+
for i in unselected_indices:
|
69 |
+
if not selected_indices:
|
70 |
+
mmr_scores.append((i, distances[0][i]))
|
71 |
+
else:
|
72 |
+
embedding_i = initial_embeddings[i]
|
73 |
+
redundancy = max(self._cosine_similarity(embedding_i, initial_embeddings[j]) for j in selected_indices)
|
74 |
+
mmr_scores.append((i, lambda_param * distances[0][i] - (1 - lambda_param) * redundancy))
|
75 |
+
|
76 |
+
selected_idx = max(mmr_scores, key=lambda x: x[1])[0]
|
77 |
+
selected_indices.append(selected_idx)
|
78 |
+
unselected_indices.remove(selected_idx)
|
79 |
+
|
80 |
+
results = []
|
81 |
+
results_idx = []
|
82 |
+
for idx in selected_indices:
|
83 |
+
document = self._get_text_by_index(neo4j_graph, indices[0][idx], doc_types)
|
84 |
+
if document is not None:
|
85 |
+
results.append({
|
86 |
+
'document': document,
|
87 |
+
'score': distances[0][idx]
|
88 |
+
})
|
89 |
+
results_idx.append(idx)
|
90 |
+
|
91 |
+
return results, results_idx
|
92 |
+
|
93 |
+
def _reconstruct_embeddings(self, indices: np.ndarray) -> np.ndarray:
|
94 |
+
return self.index.reconstruct_batch(indices)
|
95 |
+
|
96 |
+
def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
|
97 |
+
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
|
98 |
+
|
99 |
+
def _get_text_by_index(self, neo4j_graph, index, doc_types):
|
100 |
+
if doc_types is None:
|
101 |
+
query = f"""
|
102 |
+
MATCH (n)
|
103 |
+
WHERE n.id = $index
|
104 |
+
RETURN n AS document, labels(n) AS document_type, n.id AS node_id
|
105 |
+
"""
|
106 |
+
result = neo4j_graph.query(query, {"index": index})
|
107 |
+
else:
|
108 |
+
for doc_type in doc_types:
|
109 |
+
query = f"""
|
110 |
+
MATCH (n:{doc_type})
|
111 |
+
WHERE n.id = $index
|
112 |
+
RETURN n AS document, labels(n) AS document_type, n.id AS node_id
|
113 |
+
"""
|
114 |
+
result = neo4j_graph.query(query, {"index": index})
|
115 |
+
if result:
|
116 |
+
break
|
117 |
+
|
118 |
+
if result:
|
119 |
+
return f"[{result[0]['document_type'][0]}] {result[0]['document']}"
|
120 |
+
return None
|
121 |
+
|
122 |
+
|
123 |
+
# Example usage
|
124 |
+
if __name__ == "__main__":
|
125 |
+
# Initialize the vector store with embedding file
|
126 |
+
vector_store = FAISSVectorStore(dimension=384, embedding_file="path/to/your/embeddings.pkl") # or .npy file
|
127 |
+
|
128 |
+
# Initialize Neo4jGraph (replace with your actual Neo4j connection details)
|
129 |
+
neo4j_graph = Neo4jGraph(
|
130 |
+
url="bolt://localhost:7687",
|
131 |
+
username="neo4j",
|
132 |
+
password="password"
|
133 |
+
)
|
134 |
+
|
135 |
+
# Perform a similarity search with and without MMR
|
136 |
+
query = "How to start a long journey"
|
137 |
+
results_simple = vector_store.similarity_search(query, k=5, use_mmr=False, neo4j_graph=neo4j_graph)
|
138 |
+
results_mmr = vector_store.similarity_search(query, k=5, use_mmr=True, lambda_param=0.5, neo4j_graph=neo4j_graph)
|
139 |
+
|
140 |
+
# Print the results
|
141 |
+
print(f"Top 5 similar texts for query: '{query}' (without MMR)")
|
142 |
+
for i, result in enumerate(results_simple, 1):
|
143 |
+
print(f"{i}. Text: {result['text']}")
|
144 |
+
print(f" Score: {result['score']}")
|
145 |
+
print()
|
146 |
+
|
147 |
+
print(f"Top 5 similar texts for query: '{query}' (with MMR)")
|
148 |
+
for i, result in enumerate(results_mmr, 1):
|
149 |
+
print(f"{i}. Text: {result['text']}")
|
150 |
+
print(f" Score: {result['score']}")
|
151 |
+
print()
|
flagged/log.csv
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
output,flag,username,timestamp
|
2 |
+
"'
|
3 |
+
<style>
|
4 |
+
table {
|
5 |
+
border-collapse: collapse;
|
6 |
+
width: 100%;
|
7 |
+
}
|
8 |
+
th, td {
|
9 |
+
border: 1px solid black;
|
10 |
+
padding: 8px;
|
11 |
+
text-align: left;
|
12 |
+
vertical-align: top;
|
13 |
+
white-space: pre-wrap;
|
14 |
+
max-width: 300px;
|
15 |
+
max-height: 100px;
|
16 |
+
overflow-y: auto;
|
17 |
+
}
|
18 |
+
th {
|
19 |
+
background-color: #f2f2f2;
|
20 |
+
}
|
21 |
+
</style>
|
22 |
+
<table border=""1"" class=""dataframe"">
|
23 |
+
<thead>
|
24 |
+
<tr style=""text-align: right;"">
|
25 |
+
<th>Column1</th>
|
26 |
+
<th>Column2</th>
|
27 |
+
</tr>
|
28 |
+
</thead>
|
29 |
+
<tbody>
|
30 |
+
<tr>
|
31 |
+
<td>Line 1\nLine 2\nLine 3</td>
|
32 |
+
<td>Short text</td>
|
33 |
+
</tr>
|
34 |
+
<tr>
|
35 |
+
<td>Single line</td>
|
36 |
+
<td>Very long text that goes on and on and might need scrolling in the cell</td>
|
37 |
+
</tr>
|
38 |
+
</tbody>
|
39 |
+
</table>",,,2024-07-29 15:18:50.387842
|
images/flowchart_graphrag.png
ADDED
![]() |
images/flowchart_graphrag_dark.png
ADDED
![]() |
images/flowchart_graphrag_final.png
ADDED
![]() |
images/graph_png.png
ADDED
![]() |
ki_gen/data_processor.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding: utf-8
|
3 |
+
|
4 |
+
from langchain_openai import ChatOpenAI
|
5 |
+
from langchain_core.output_parsers import StrOutputParser
|
6 |
+
from langchain_core.prompts import ChatPromptTemplate
|
7 |
+
from langchain_groq import ChatGroq
|
8 |
+
from langgraph.graph import StateGraph
|
9 |
+
from llmlingua import PromptCompressor
|
10 |
+
|
11 |
+
from ki_gen.utils import ConfigSchema, DocProcessorState, get_model, format_doc
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
# compressed_prompt = llm_lingua.compress_prompt(prompt, instruction="", question="", target_token=200)
|
17 |
+
|
18 |
+
## Or use the quantation model, like TheBloke/Llama-2-7b-Chat-GPTQ, only need <8GB GPU memory.
|
19 |
+
## Before that, you need to pip install optimum auto-gptq
|
20 |
+
# llm_lingua = PromptCompressor("TheBloke/Llama-2-7b-Chat-GPTQ", model_config={"revision": "main"})
|
21 |
+
|
22 |
+
|
23 |
+
|
24 |
+
# Requires ~2GB of RAM
|
25 |
+
def get_llm_lingua(compress_method:str = "llm_lingua2"):
|
26 |
+
|
27 |
+
# Requires ~2GB memory
|
28 |
+
if compress_method == "llm_lingua2":
|
29 |
+
llm_lingua2 = PromptCompressor(
|
30 |
+
model_name="microsoft/llmlingua-2-xlm-roberta-large-meetingbank",
|
31 |
+
use_llmlingua2=True,
|
32 |
+
device_map="cpu"
|
33 |
+
)
|
34 |
+
return llm_lingua2
|
35 |
+
|
36 |
+
# Requires ~8GB memory
|
37 |
+
elif compress_method == "llm_lingua":
|
38 |
+
llm_lingua = PromptCompressor(
|
39 |
+
model_name="microsoft/phi-2",
|
40 |
+
device_map="cpu"
|
41 |
+
)
|
42 |
+
return llm_lingua
|
43 |
+
raise ValueError("Incorrect compression method, should be 'llm_lingua' or 'llm_lingua2'")
|
44 |
+
|
45 |
+
|
46 |
+
|
47 |
+
def compress(state: DocProcessorState, config: ConfigSchema):
|
48 |
+
"""
|
49 |
+
This node compresses last processing result for each doc using llm_lingua
|
50 |
+
"""
|
51 |
+
doc_process_histories = state["docs_in_processing"]
|
52 |
+
llm_lingua = get_llm_lingua(config["configurable"].get("compression_method") or "llm_lingua2")
|
53 |
+
for doc_process_history in doc_process_histories:
|
54 |
+
doc_process_history.append(llm_lingua.compress_prompt(
|
55 |
+
doc = str(doc_process_history[-1]),
|
56 |
+
rate=config["configurable"].get("compress_rate") or 0.33,
|
57 |
+
force_tokens=config["configurable"].get("force_tokens") or ['\n', '?', '.', '!', ',']
|
58 |
+
)["compressed_prompt"]
|
59 |
+
)
|
60 |
+
|
61 |
+
return {"docs_in_processing": doc_process_histories, "current_process_step" : state["current_process_step"] + 1}
|
62 |
+
|
63 |
+
def summarize_docs(state: DocProcessorState, config: ConfigSchema):
|
64 |
+
"""
|
65 |
+
This node summarizes all docs in state["valid_docs"]
|
66 |
+
"""
|
67 |
+
|
68 |
+
prompt = """You are a 3GPP standardization expert.
|
69 |
+
Summarize the provided document in simple technical English for other experts in the field.
|
70 |
+
|
71 |
+
Document:
|
72 |
+
{document}"""
|
73 |
+
sysmsg = ChatPromptTemplate.from_messages([
|
74 |
+
("system", prompt)
|
75 |
+
])
|
76 |
+
model = config["configurable"].get("summarize_model") or "mixtral-8x7b-32768"
|
77 |
+
doc_process_histories = state["docs_in_processing"]
|
78 |
+
if model == "gpt-4o":
|
79 |
+
llm_summarize = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
|
80 |
+
else:
|
81 |
+
llm_summarize = ChatGroq(model=model)
|
82 |
+
summarize_chain = sysmsg | llm_summarize | StrOutputParser()
|
83 |
+
|
84 |
+
for doc_process_history in doc_process_histories:
|
85 |
+
doc_process_history.append(summarize_chain.invoke({"document" : str(doc_process_history[-1])}))
|
86 |
+
|
87 |
+
return {"docs_in_processing": doc_process_histories, "current_process_step": state["current_process_step"] + 1}
|
88 |
+
|
89 |
+
def custom_process(state: DocProcessorState):
|
90 |
+
"""
|
91 |
+
Custom processing step, params are stored in a dict in state["process_steps"][state["current_process_step"]]
|
92 |
+
processing_model : the LLM which will perform the processing
|
93 |
+
context : the previous processing results to send as context to the LLM
|
94 |
+
user_prompt : the prompt/task which will be appended to the context before sending to the LLM
|
95 |
+
"""
|
96 |
+
|
97 |
+
processing_params = state["process_steps"][state["current_process_step"]]
|
98 |
+
model = processing_params.get("processing_model") or "mixtral-8x7b-32768"
|
99 |
+
user_prompt = processing_params["prompt"]
|
100 |
+
context = processing_params.get("context") or [0]
|
101 |
+
doc_process_histories = state["docs_in_processing"]
|
102 |
+
if not isinstance(context, list):
|
103 |
+
context = [context]
|
104 |
+
|
105 |
+
processing_chain = get_model(model=model) | StrOutputParser()
|
106 |
+
|
107 |
+
for doc_process_history in doc_process_histories:
|
108 |
+
context_str = ""
|
109 |
+
for i, context_element in enumerate(context):
|
110 |
+
context_str += f"### TECHNICAL INFORMATION {i+1} \n {doc_process_history[context_element]}\n\n"
|
111 |
+
doc_process_history.append(processing_chain.invoke(context_str + user_prompt))
|
112 |
+
|
113 |
+
return {"docs_in_processing" : doc_process_histories, "current_process_step" : state["current_process_step"] + 1}
|
114 |
+
|
115 |
+
def final(state: DocProcessorState):
|
116 |
+
"""
|
117 |
+
A node to store the final results of processing in the 'valid_docs' field
|
118 |
+
"""
|
119 |
+
return {"valid_docs" : [doc_process_history[-1] for doc_process_history in state["docs_in_processing"]]}
|
120 |
+
|
121 |
+
# TODO : remove this node and use conditional entry point instead
|
122 |
+
def get_process_steps(state: DocProcessorState, config: ConfigSchema):
|
123 |
+
"""
|
124 |
+
Dummy node
|
125 |
+
"""
|
126 |
+
# if not process_steps:
|
127 |
+
# process_steps = eval(input("Enter processing steps: "))
|
128 |
+
return {"current_process_step": 0, "docs_in_processing" : [[format_doc(doc)] for doc in state["valid_docs"]]}
|
129 |
+
|
130 |
+
|
131 |
+
def next_processor_step(state: DocProcessorState):
|
132 |
+
"""
|
133 |
+
Conditional edge function to go to next processing step
|
134 |
+
"""
|
135 |
+
process_steps = state["process_steps"]
|
136 |
+
if state["current_process_step"] < len(process_steps):
|
137 |
+
step = process_steps[state["current_process_step"]]
|
138 |
+
if isinstance(step, dict):
|
139 |
+
step = "custom"
|
140 |
+
else:
|
141 |
+
step = "final"
|
142 |
+
|
143 |
+
return step
|
144 |
+
|
145 |
+
|
146 |
+
def build_data_processor_graph(memory):
|
147 |
+
"""
|
148 |
+
Builds the data processor graph
|
149 |
+
"""
|
150 |
+
|
151 |
+
graph_builder_doc_processor = StateGraph(DocProcessorState)
|
152 |
+
|
153 |
+
graph_builder_doc_processor.add_node("get_process_steps", get_process_steps)
|
154 |
+
graph_builder_doc_processor.add_node("summarize", summarize_docs)
|
155 |
+
graph_builder_doc_processor.add_node("compress", compress)
|
156 |
+
graph_builder_doc_processor.add_node("custom", custom_process)
|
157 |
+
graph_builder_doc_processor.add_node("final", final)
|
158 |
+
|
159 |
+
graph_builder_doc_processor.add_edge("__start__", "get_process_steps")
|
160 |
+
graph_builder_doc_processor.add_conditional_edges(
|
161 |
+
"get_process_steps",
|
162 |
+
next_processor_step,
|
163 |
+
{"compress" : "compress", "final": "final", "summarize": "summarize", "custom" : "custom"}
|
164 |
+
)
|
165 |
+
graph_builder_doc_processor.add_conditional_edges(
|
166 |
+
"summarize",
|
167 |
+
next_processor_step,
|
168 |
+
{"compress" : "compress", "final": "final", "custom" : "custom"}
|
169 |
+
)
|
170 |
+
graph_builder_doc_processor.add_conditional_edges(
|
171 |
+
"compress",
|
172 |
+
next_processor_step,
|
173 |
+
{"summarize" : "summarize", "final": "final", "custom" : "custom"}
|
174 |
+
)
|
175 |
+
graph_builder_doc_processor.add_conditional_edges(
|
176 |
+
"custom",
|
177 |
+
next_processor_step,
|
178 |
+
{"summarize" : "summarize", "final": "final", "compress" : "compress", "custom" : "custom"}
|
179 |
+
)
|
180 |
+
graph_builder_doc_processor.add_edge("final", "__end__")
|
181 |
+
|
182 |
+
graph_doc_processor = graph_builder_doc_processor.compile(checkpointer=memory)
|
183 |
+
return graph_doc_processor
|
ki_gen/data_retriever.py
ADDED
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding: utf-8
|
3 |
+
|
4 |
+
import re
|
5 |
+
from random import shuffle, sample
|
6 |
+
|
7 |
+
from langchain_groq import ChatGroq
|
8 |
+
from langchain_openai import ChatOpenAI
|
9 |
+
from langchain_core.messages import HumanMessage
|
10 |
+
from langchain_community.graphs import Neo4jGraph
|
11 |
+
from langchain_community.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema
|
12 |
+
from langchain_core.output_parsers import StrOutputParser
|
13 |
+
from langchain_core.prompts import ChatPromptTemplate
|
14 |
+
from langchain_core.pydantic_v1 import BaseModel, Field
|
15 |
+
from langchain_groq import ChatGroq
|
16 |
+
|
17 |
+
from langgraph.graph import StateGraph
|
18 |
+
|
19 |
+
from llmlingua import PromptCompressor
|
20 |
+
|
21 |
+
from ki_gen.prompts import (
|
22 |
+
CYPHER_GENERATION_PROMPT,
|
23 |
+
CONCEPT_SELECTION_PROMPT,
|
24 |
+
BINARY_GRADER_PROMPT,
|
25 |
+
SCORE_GRADER_PROMPT,
|
26 |
+
RELEVANT_CONCEPTS_PROMPT,
|
27 |
+
)
|
28 |
+
from ki_gen.utils import ConfigSchema, DocRetrieverState, get_model, format_doc
|
29 |
+
|
30 |
+
|
31 |
+
|
32 |
+
|
33 |
+
def extract_cypher(text: str) -> str:
|
34 |
+
"""Extract Cypher code from a text.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
text: Text to extract Cypher code from.
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
Cypher code extracted from the text.
|
41 |
+
"""
|
42 |
+
# The pattern to find Cypher code enclosed in triple backticks
|
43 |
+
pattern_1 = r"```cypher\n(.*?)```"
|
44 |
+
pattern_2 = r"```\n(.*?)```"
|
45 |
+
|
46 |
+
# Find all matches in the input text
|
47 |
+
matches_1 = re.findall(pattern_1, text, re.DOTALL)
|
48 |
+
matches_2 = re.findall(pattern_2, text, re.DOTALL)
|
49 |
+
return [
|
50 |
+
matches_1[0] if matches_1 else text,
|
51 |
+
matches_2[0] if matches_2 else text,
|
52 |
+
text
|
53 |
+
]
|
54 |
+
|
55 |
+
def get_cypher_gen_chain(model: str = "openai"):
|
56 |
+
"""
|
57 |
+
Returns cypher gen chain using specified model for generation
|
58 |
+
This is used when the 'auto' cypher generation method has been configured
|
59 |
+
"""
|
60 |
+
|
61 |
+
if model=="openai":
|
62 |
+
llm_cypher_gen = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
|
63 |
+
else:
|
64 |
+
llm_cypher_gen = ChatGroq(model = "mixtral-8x7b-32768")
|
65 |
+
cypher_gen_chain = CYPHER_GENERATION_PROMPT | llm_cypher_gen | StrOutputParser() | extract_cypher
|
66 |
+
return cypher_gen_chain
|
67 |
+
|
68 |
+
def get_concept_selection_chain(model: str = "openai"):
|
69 |
+
"""
|
70 |
+
Returns a chain to select the most relevant topic using specified model for generation.
|
71 |
+
This is used when the 'guided' cypher generation method has been configured
|
72 |
+
"""
|
73 |
+
|
74 |
+
if model == "openai":
|
75 |
+
llm_topic_selection = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
|
76 |
+
else:
|
77 |
+
llm_topic_selection = ChatGroq(model="llama3-70b-8192", max_tokens=8192)
|
78 |
+
print(f"FOUND LLM TOPIC SELECTION FOR THE CONCEPT SELECTION PROMPT : {llm_topic_selection}")
|
79 |
+
topic_selection_chain = CONCEPT_SELECTION_PROMPT | llm_topic_selection | StrOutputParser()
|
80 |
+
return topic_selection_chain
|
81 |
+
|
82 |
+
def get_concepts(graph: Neo4jGraph):
|
83 |
+
concept_cypher = "MATCH (c:Concept) return c"
|
84 |
+
if isinstance(graph, Neo4jGraph):
|
85 |
+
concepts = graph.query(concept_cypher)
|
86 |
+
else:
|
87 |
+
user_input = input("Topics : ")
|
88 |
+
concepts = eval(user_input)
|
89 |
+
|
90 |
+
concepts_name = [concept['c']['name'] for concept in concepts]
|
91 |
+
return concepts_name
|
92 |
+
|
93 |
+
def get_related_concepts(graph: Neo4jGraph, question: str):
|
94 |
+
concepts = get_concepts(graph)
|
95 |
+
llm = get_model(model='gpt-4o')
|
96 |
+
print(f"this is the llm variable : {llm}")
|
97 |
+
def parse_answer(llm_answer : str):
|
98 |
+
print(f"This the llm_answer : {llm_answer}")
|
99 |
+
return re.split("\n(?:\d)+\.\s", llm_answer.split("Concepts:")[1])[1:]
|
100 |
+
related_concepts_chain = RELEVANT_CONCEPTS_PROMPT | llm | StrOutputParser() | parse_answer
|
101 |
+
|
102 |
+
related_concepts_raw = related_concepts_chain.invoke({"user_query" : question, "concepts" : '\n'.join(concepts)})
|
103 |
+
|
104 |
+
# We clean up the list we received from the LLM in case there were some hallucinations
|
105 |
+
related_concepts_cleaned = []
|
106 |
+
for related_concept in related_concepts_raw:
|
107 |
+
# If the concept returned from the LLM is in the list we keep it
|
108 |
+
if related_concept in concepts:
|
109 |
+
related_concepts_cleaned.append(related_concept)
|
110 |
+
else:
|
111 |
+
# The LLM sometimes only forgets a few words from the concept name
|
112 |
+
# We check if the generated concept is a substring of an existing one and if it is the case add it to the list
|
113 |
+
for concept in concepts:
|
114 |
+
if related_concept in concept:
|
115 |
+
related_concepts_cleaned.append(concept)
|
116 |
+
break
|
117 |
+
|
118 |
+
# TODO : Add concepts found via similarity search
|
119 |
+
return related_concepts_cleaned
|
120 |
+
|
121 |
+
def build_concept_string(graph: Neo4jGraph, concept_list: list[str]):
|
122 |
+
concept_string = ""
|
123 |
+
for concept in concept_list:
|
124 |
+
concept_description_query = f"""
|
125 |
+
MATCH (c:Concept {{name: "{concept}" }}) RETURN c.description
|
126 |
+
"""
|
127 |
+
concept_description = graph.query(concept_description_query)[0]['c.description']
|
128 |
+
concept_string += f"name: {concept}\ndescription: {concept_description}\n\n"
|
129 |
+
return concept_string
|
130 |
+
|
131 |
+
def get_global_concepts(graph: Neo4jGraph):
|
132 |
+
concept_cypher = "MATCH (gc:GlobalConcept) return gc"
|
133 |
+
if isinstance(graph, Neo4jGraph):
|
134 |
+
concepts = graph.query(concept_cypher)
|
135 |
+
else:
|
136 |
+
user_input = input("Topics : ")
|
137 |
+
concepts = eval(user_input)
|
138 |
+
|
139 |
+
concepts_name = [concept['gc']['name'] for concept in concepts]
|
140 |
+
return concepts_name
|
141 |
+
|
142 |
+
def generate_cypher(state: DocRetrieverState, config: ConfigSchema):
|
143 |
+
"""
|
144 |
+
The node where the cypher is generated
|
145 |
+
"""
|
146 |
+
|
147 |
+
graph = config["configurable"].get("graph")
|
148 |
+
question = state['query']
|
149 |
+
related_concepts = get_related_concepts(graph, question)
|
150 |
+
cyphers = []
|
151 |
+
|
152 |
+
if config["configurable"].get("cypher_gen_method") == 'auto':
|
153 |
+
cypher_gen_chain = get_cypher_gen_chain()
|
154 |
+
cyphers = cypher_gen_chain.invoke({
|
155 |
+
"schema": graph.schema,
|
156 |
+
"question": question,
|
157 |
+
"concepts": related_concepts
|
158 |
+
})
|
159 |
+
|
160 |
+
|
161 |
+
if config["configurable"].get("cypher_gen_method") == 'guided':
|
162 |
+
concept_selection_chain = get_concept_selection_chain()
|
163 |
+
print(f"Concept selection chain is : {concept_selection_chain}")
|
164 |
+
selected_topic = concept_selection_chain.invoke({"question" : question, "concepts": get_concepts(graph)})
|
165 |
+
print(f"Selected topic are : {selected_topic}")
|
166 |
+
cyphers = [generate_cypher_from_topic(selected_topic, state['current_plan_step'])]
|
167 |
+
print(f"Cyphers are : {cyphers}")
|
168 |
+
|
169 |
+
|
170 |
+
if config["configurable"].get("validate_cypher"):
|
171 |
+
corrector_schema = [Schema(el["start"], el["type"], el["end"]) for el in graph.structured_schema.get("relationships")]
|
172 |
+
cypher_corrector = CypherQueryCorrector(corrector_schema)
|
173 |
+
cyphers = [cypher_corrector(cypher) for cypher in cyphers]
|
174 |
+
|
175 |
+
return {"cyphers" : cyphers}
|
176 |
+
|
177 |
+
def generate_cypher_from_topic(selected_concept: str, plan_step: int):
|
178 |
+
"""
|
179 |
+
Helper function used when the 'guided' cypher generation method has been configured
|
180 |
+
"""
|
181 |
+
|
182 |
+
print(f"L.176 PLAN STEP : {plan_step}")
|
183 |
+
cypher_el = "(n) return n.title, n.description"
|
184 |
+
match plan_step:
|
185 |
+
case 0:
|
186 |
+
cypher_el = "(ts:TechnicalSpecification) RETURN ts.title, ts.scope, ts.description"
|
187 |
+
case 1:
|
188 |
+
cypher_el = "(rp:ResearchPaper) RETURN rp.title, rp.abstract"
|
189 |
+
case 2:
|
190 |
+
cypher_el = "(ki:KeyIssue) RETURN ki.description"
|
191 |
+
return f"MATCH (c:Concept {{name:'{selected_concept}'}})-[:RELATED_TO]-{cypher_el}"
|
192 |
+
|
193 |
+
def get_docs(state:DocRetrieverState, config:ConfigSchema):
|
194 |
+
"""
|
195 |
+
This node retrieves docs from the graph using the generated cypher
|
196 |
+
"""
|
197 |
+
graph = config["configurable"].get("graph")
|
198 |
+
output = []
|
199 |
+
if graph is not None:
|
200 |
+
for cypher in state["cyphers"]:
|
201 |
+
try:
|
202 |
+
output = graph.query(cypher)
|
203 |
+
break
|
204 |
+
except Exception as e:
|
205 |
+
print("Failed to retrieve docs : {e}")
|
206 |
+
|
207 |
+
# Clean up the docs we received as there may be duplicates depending on the cypher query
|
208 |
+
all_docs = []
|
209 |
+
for doc in output:
|
210 |
+
unwinded_doc = {}
|
211 |
+
for key in doc:
|
212 |
+
if isinstance(doc[key], dict):
|
213 |
+
all_docs.append(doc[key])
|
214 |
+
else:
|
215 |
+
unwinded_doc.update({key: doc[key]})
|
216 |
+
if unwinded_doc:
|
217 |
+
all_docs.append(unwinded_doc)
|
218 |
+
|
219 |
+
|
220 |
+
filtered_docs = []
|
221 |
+
for doc in all_docs:
|
222 |
+
if doc not in filtered_docs:
|
223 |
+
filtered_docs.append(doc)
|
224 |
+
|
225 |
+
return {"docs": filtered_docs}
|
226 |
+
|
227 |
+
|
228 |
+
|
229 |
+
|
230 |
+
|
231 |
+
# Data model
|
232 |
+
class GradeDocumentsBinary(BaseModel):
|
233 |
+
"""Binary score for relevance check on retrieved documents."""
|
234 |
+
|
235 |
+
binary_score: str = Field(
|
236 |
+
description="Documents are relevant to the question, 'yes' or 'no'"
|
237 |
+
)
|
238 |
+
|
239 |
+
# LLM with function call
|
240 |
+
# llm_grader_binary = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
|
241 |
+
|
242 |
+
def get_binary_grader(model="mixtral-8x7b-32768"):
|
243 |
+
"""
|
244 |
+
Returns a binary grader to evaluate relevance of documents using specified model for generation
|
245 |
+
This is used when the 'binary' evaluation method has been configured
|
246 |
+
"""
|
247 |
+
|
248 |
+
|
249 |
+
if model == "gpt-4o":
|
250 |
+
llm_grader_binary = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/", temperature=0)
|
251 |
+
else:
|
252 |
+
llm_grader_binary = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
|
253 |
+
structured_llm_grader_binary = llm_grader_binary.with_structured_output(GradeDocumentsBinary)
|
254 |
+
retrieval_grader_binary = BINARY_GRADER_PROMPT | structured_llm_grader_binary
|
255 |
+
return retrieval_grader_binary
|
256 |
+
|
257 |
+
|
258 |
+
class GradeDocumentsScore(BaseModel):
|
259 |
+
"""Score for relevance check on retrieved documents."""
|
260 |
+
|
261 |
+
score: float = Field(
|
262 |
+
description="Documents are relevant to the question, score between 0 (completely irrelevant) and 1 (perfectly relevant)"
|
263 |
+
)
|
264 |
+
|
265 |
+
def get_score_grader(model="mixtral-8x7b-32768"):
|
266 |
+
"""
|
267 |
+
Returns a score grader to evaluate relevance of documents using specified model for generation
|
268 |
+
This is used when the 'score' evaluation method has been configured
|
269 |
+
"""
|
270 |
+
if model == "gpt-4o":
|
271 |
+
llm_grader_score = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/", temperature=0)
|
272 |
+
else:
|
273 |
+
llm_grader_score = ChatGroq(model="mixtral-8x7b-32768", temperature = 0)
|
274 |
+
structured_llm_grader_score = llm_grader_score.with_structured_output(GradeDocumentsScore)
|
275 |
+
retrieval_grader_score = SCORE_GRADER_PROMPT | structured_llm_grader_score
|
276 |
+
return retrieval_grader_score
|
277 |
+
|
278 |
+
|
279 |
+
def eval_doc(doc, query, method="binary", threshold=0.7, eval_model="mixtral-8x7b-32768"):
|
280 |
+
'''
|
281 |
+
doc : the document to evaluate
|
282 |
+
query : the query to which to doc shoud be relevant
|
283 |
+
method : "binary" or "score"
|
284 |
+
threshold : for "score" method, score above which a doc is considered relevant
|
285 |
+
'''
|
286 |
+
if method == "binary":
|
287 |
+
retrieval_grader_binary = get_binary_grader(model=eval_model)
|
288 |
+
return 1 if (retrieval_grader_binary.invoke({"question": query, "document":doc}).binary_score == 'yes') else 0
|
289 |
+
elif method == "score":
|
290 |
+
retrieval_grader_score = get_score_grader(model=eval_model)
|
291 |
+
score = retrieval_grader_score.invoke({"query": query, "document":doc}).score or None
|
292 |
+
if score is not None:
|
293 |
+
return score if score >= threshold else 0
|
294 |
+
else:
|
295 |
+
# Couldn't parse score, marking document as relevant by default
|
296 |
+
return 1
|
297 |
+
else:
|
298 |
+
raise ValueError("Invalid method")
|
299 |
+
|
300 |
+
def eval_docs(state: DocRetrieverState, config: ConfigSchema):
|
301 |
+
"""
|
302 |
+
This node performs evaluation of the retrieved docs and
|
303 |
+
"""
|
304 |
+
|
305 |
+
eval_method = config["configurable"].get("eval_method") or "binary"
|
306 |
+
MAX_DOCS = config["configurable"].get("max_docs") or 15
|
307 |
+
valid_doc_scores = []
|
308 |
+
|
309 |
+
for doc in sample(state["docs"], min(25, len(state["docs"]))):
|
310 |
+
score = eval_doc(
|
311 |
+
doc=format_doc(doc),
|
312 |
+
query=state["query"],
|
313 |
+
method=eval_method,
|
314 |
+
threshold=config["configurable"].get("eval_threshold") or 0.7,
|
315 |
+
eval_model = config["configurable"].get("eval_model") or "mixtral-8x7b-32768"
|
316 |
+
)
|
317 |
+
if score:
|
318 |
+
valid_doc_scores.append((doc, score))
|
319 |
+
|
320 |
+
if eval_method == 'score':
|
321 |
+
# Get at most MAX_DOCS items with the highest score if score method was used
|
322 |
+
valid_docs = sorted(valid_doc_scores, key=lambda x: x[1])
|
323 |
+
valid_docs = [valid_doc[0] for valid_doc in valid_docs[:MAX_DOCS]]
|
324 |
+
else:
|
325 |
+
# Get at mots MAX_DOCS items at random if binary method was used
|
326 |
+
shuffle(valid_doc_scores)
|
327 |
+
valid_docs = [valid_doc[0] for valid_doc in valid_doc_scores[:MAX_DOCS]]
|
328 |
+
|
329 |
+
return {"valid_docs": valid_docs + (state["valid_docs"] or [])}
|
330 |
+
|
331 |
+
|
332 |
+
|
333 |
+
def build_data_retriever_graph(memory):
|
334 |
+
"""
|
335 |
+
Builds the data_retriever graph
|
336 |
+
"""
|
337 |
+
graph_builder_doc_retriever = StateGraph(DocRetrieverState)
|
338 |
+
|
339 |
+
graph_builder_doc_retriever.add_node("generate_cypher", generate_cypher)
|
340 |
+
graph_builder_doc_retriever.add_node("get_docs", get_docs)
|
341 |
+
graph_builder_doc_retriever.add_node("eval_docs", eval_docs)
|
342 |
+
|
343 |
+
|
344 |
+
graph_builder_doc_retriever.add_edge("__start__", "generate_cypher")
|
345 |
+
graph_builder_doc_retriever.add_edge("generate_cypher", "get_docs")
|
346 |
+
graph_builder_doc_retriever.add_edge("get_docs", "eval_docs")
|
347 |
+
graph_builder_doc_retriever.add_edge("eval_docs", "__end__")
|
348 |
+
|
349 |
+
graph_doc_retriever = graph_builder_doc_retriever.compile(checkpointer=memory)
|
350 |
+
|
351 |
+
return graph_doc_retriever
|
ki_gen/planner.py
ADDED
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
|
4 |
+
from typing import Annotated
|
5 |
+
from typing_extensions import TypedDict
|
6 |
+
|
7 |
+
from langchain_groq import ChatGroq
|
8 |
+
from langchain_openai import ChatOpenAI
|
9 |
+
from langchain_core.messages import SystemMessage, HumanMessage
|
10 |
+
from langchain_community.graphs import Neo4jGraph
|
11 |
+
|
12 |
+
from langgraph.graph import StateGraph
|
13 |
+
from langgraph.graph import add_messages
|
14 |
+
|
15 |
+
from ki_gen.prompts import PLAN_GEN_PROMPT, PLAN_MODIFICATION_PROMPT
|
16 |
+
from ki_gen.data_retriever import build_data_retriever_graph
|
17 |
+
from ki_gen.data_processor import build_data_processor_graph
|
18 |
+
from ki_gen.utils import ConfigSchema, State, HumanValidationState, DocProcessorState, DocRetrieverState
|
19 |
+
|
20 |
+
|
21 |
+
##########################################################################
|
22 |
+
###### NODES DEFINITION ######
|
23 |
+
##########################################################################
|
24 |
+
|
25 |
+
def validate_node(state: State):
|
26 |
+
"""
|
27 |
+
This node inserts the plan validation prompt.
|
28 |
+
"""
|
29 |
+
prompt = """System : You only need to focus on Key Issues, no need to focus on solutions or stakeholders yet and your plan should be concise.
|
30 |
+
If needed, give me an updated plan to follow this instruction. If your plan already follows the instruction just say "My plan is correct"."""
|
31 |
+
output = HumanMessage(content=prompt)
|
32 |
+
return {"messages" : [output]}
|
33 |
+
|
34 |
+
|
35 |
+
# Wrappers to call LLMs on the state messsages field
|
36 |
+
def chatbot_llama(state: State):
|
37 |
+
llm_llama = ChatGroq(model="llama3-70b-8192")
|
38 |
+
return {"messages" : [llm_llama.invoke(state["messages"])]}
|
39 |
+
|
40 |
+
def chatbot_mixtral(state: State):
|
41 |
+
llm_mixtral = ChatGroq(model="mixtral-8x7b-32768")
|
42 |
+
return {"messages" : [llm_mixtral.invoke(state["messages"])]}
|
43 |
+
|
44 |
+
def chatbot_openai(state: State):
|
45 |
+
llm_openai = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
|
46 |
+
return {"messages" : [llm_openai.invoke(state["messages"])]}
|
47 |
+
|
48 |
+
chatbots = {"gpt-4o" : chatbot_openai,
|
49 |
+
"mixtral-8x7b-32768" : chatbot_mixtral,
|
50 |
+
"llama3-70b-8192" : chatbot_llama
|
51 |
+
}
|
52 |
+
|
53 |
+
|
54 |
+
def parse_plan(state: State):
|
55 |
+
"""
|
56 |
+
This node parses the generated plan and writes in the 'store_plan' field of the state
|
57 |
+
"""
|
58 |
+
plan = state["messages"][-3].content
|
59 |
+
store_plan = re.split("\d\.", plan.split("Plan:\n")[1])[1:]
|
60 |
+
try:
|
61 |
+
store_plan[len(store_plan) - 1] = store_plan[len(store_plan) - 1].split("<END_OF_PLAN>")[0]
|
62 |
+
except Exception as e:
|
63 |
+
print(f"Error while removing <END_OF_PLAN> : {e}")
|
64 |
+
|
65 |
+
return {"store_plan" : store_plan}
|
66 |
+
|
67 |
+
def detail_step(state: State, config: ConfigSchema):
|
68 |
+
"""
|
69 |
+
This node updates the value of the 'current_plan_step' field and defines the query to be used for the data_retriever.
|
70 |
+
"""
|
71 |
+
print("test")
|
72 |
+
print(state)
|
73 |
+
|
74 |
+
if 'current_plan_step' in state.keys():
|
75 |
+
print("all good chief")
|
76 |
+
else:
|
77 |
+
state["current_plan_step"] = None
|
78 |
+
|
79 |
+
current_plan_step = state["current_plan_step"] + 1 if state["current_plan_step"] is not None else 0 # We just began a new step so we will increase current_plan_step at the end
|
80 |
+
if config["configurable"].get("use_detailed_query"):
|
81 |
+
prompt = HumanMessage(f"""Specify what additional information you need to proceed with the next step of your plan :
|
82 |
+
Step {current_plan_step + 1} : {state['store_plan'][current_plan_step]}""")
|
83 |
+
query = get_detailed_query(context = state["messages"] + [prompt], model=config["configurable"].get("main_llm"))
|
84 |
+
return {"messages" : [prompt, query], "current_plan_step": current_plan_step, 'query' : query}
|
85 |
+
|
86 |
+
return {"current_plan_step": current_plan_step, 'query' : state["store_plan"][current_plan_step], "valid_docs" : []}
|
87 |
+
|
88 |
+
def get_detailed_query(context : list, model : str = "mixtral-8x7b-32768"):
|
89 |
+
"""
|
90 |
+
Simple helper function for the detail_step node
|
91 |
+
"""
|
92 |
+
if model == 'gpt-4o':
|
93 |
+
llm = ChatOpenAI(model=model, base_url="https://llm.synapse.thalescloud.io/")
|
94 |
+
else:
|
95 |
+
llm = ChatGroq(model=model)
|
96 |
+
return llm.invoke(context)
|
97 |
+
|
98 |
+
def concatenate_data(state: State):
|
99 |
+
"""
|
100 |
+
This node concatenates all the data that was processed by the data_processor and inserts it in the state's messages
|
101 |
+
"""
|
102 |
+
prompt = f"""#########TECHNICAL INFORMATION ############
|
103 |
+
{str(state["valid_docs"])}
|
104 |
+
|
105 |
+
########END OF TECHNICAL INFORMATION#######
|
106 |
+
|
107 |
+
Using the information provided above, proceed with step {state['current_plan_step'] + 1} of your plan :
|
108 |
+
{state['store_plan'][state['current_plan_step']]}
|
109 |
+
"""
|
110 |
+
|
111 |
+
return {"messages": [HumanMessage(content=prompt)]}
|
112 |
+
|
113 |
+
|
114 |
+
def human_validation(state: HumanValidationState) -> HumanValidationState:
|
115 |
+
"""
|
116 |
+
Dummy node to interrupt before
|
117 |
+
"""
|
118 |
+
return {'process_steps' : []}
|
119 |
+
|
120 |
+
def generate_ki(state: State):
|
121 |
+
"""
|
122 |
+
This node inserts the prompt to begin Key Issues generation
|
123 |
+
"""
|
124 |
+
print(f"THIS IS THE STATE FOR CURRENT PLAN STEP IN GENERATE_KI : {state}")
|
125 |
+
|
126 |
+
prompt = f"""Using the information provided above, proceed with step 4 of your plan to provide the user with NEW and INNOVATIVE Key Issues :
|
127 |
+
{state['store_plan'][state['current_plan_step'] + 1]}"""
|
128 |
+
|
129 |
+
return {"messages" : [HumanMessage(content=prompt)]}
|
130 |
+
|
131 |
+
def detail_ki(state: State):
|
132 |
+
"""
|
133 |
+
This node inserts the last prompt to detail the generated Key Issues
|
134 |
+
"""
|
135 |
+
prompt = f"""Using the information provided above, proceed with step 5 of your plan to provide the user with NEW and INNOVATIVE Key Issues :
|
136 |
+
{state['store_plan'][state['current_plan_step'] + 2]}"""
|
137 |
+
|
138 |
+
return {"messages" : [HumanMessage(content=prompt)]}
|
139 |
+
|
140 |
+
##########################################################################
|
141 |
+
###### CONDITIONAL EDGE FUNCTIONS ######
|
142 |
+
##########################################################################
|
143 |
+
|
144 |
+
def validate_plan(state: State):
|
145 |
+
"""
|
146 |
+
Whether to regenerate the plan or to parse it
|
147 |
+
"""
|
148 |
+
if "messages" in state and state["messages"][-1].content in ["My plan is correct.","My plan is correct"]:
|
149 |
+
return "parse"
|
150 |
+
return "validate"
|
151 |
+
|
152 |
+
def next_plan_step(state: State, config: ConfigSchema):
|
153 |
+
"""
|
154 |
+
Proceed to next plan step (either generate KI or retrieve more data)
|
155 |
+
"""
|
156 |
+
if (state["current_plan_step"] == 2) and (config["configurable"].get('plan_method') == "modification"):
|
157 |
+
return "generate_key_issues"
|
158 |
+
if state["current_plan_step"] == len(state["store_plan"]) - 1:
|
159 |
+
return "generate_key_issues"
|
160 |
+
else:
|
161 |
+
return "detail_step"
|
162 |
+
|
163 |
+
def detail_or_data_retriever(state: State, config: ConfigSchema):
|
164 |
+
"""
|
165 |
+
Detail the query to use for data retrieval or not
|
166 |
+
"""
|
167 |
+
if config["configurable"].get("use_detailed_query"):
|
168 |
+
return "chatbot_detail"
|
169 |
+
else:
|
170 |
+
return "data_retriever"
|
171 |
+
|
172 |
+
def retrieve_or_process(state: State):
|
173 |
+
"""
|
174 |
+
Process the retrieved docs or keep retrieving
|
175 |
+
"""
|
176 |
+
if state['human_validated']:
|
177 |
+
return "process"
|
178 |
+
return "retrieve"
|
179 |
+
# while True:
|
180 |
+
# user_input = input(f"{len(state['valid_docs'])} were retreived. Do you want more documents (y/[n]) : ")
|
181 |
+
# if user_input.lower() == "y":
|
182 |
+
# return "retrieve"
|
183 |
+
# if not user_input or user_input.lower() == "n":
|
184 |
+
# return "process"
|
185 |
+
# print("Please answer with 'y' or 'n'.\n")
|
186 |
+
|
187 |
+
|
188 |
+
def build_planner_graph(memory, config):
|
189 |
+
"""
|
190 |
+
Builds the planner graph
|
191 |
+
"""
|
192 |
+
graph_builder = StateGraph(State)
|
193 |
+
|
194 |
+
graph_doc_retriever = build_data_retriever_graph(memory)
|
195 |
+
graph_doc_processor = build_data_processor_graph(memory)
|
196 |
+
graph_builder.add_node("chatbot_planner", chatbots[config["main_llm"]])
|
197 |
+
graph_builder.add_node("validate", validate_node)
|
198 |
+
graph_builder.add_node("chatbot_detail", chatbot_llama)
|
199 |
+
graph_builder.add_node("parse", parse_plan)
|
200 |
+
graph_builder.add_node("detail_step", detail_step)
|
201 |
+
graph_builder.add_node("data_retriever", graph_doc_retriever, input=DocRetrieverState)
|
202 |
+
graph_builder.add_node("human_validation", human_validation)
|
203 |
+
graph_builder.add_node("data_processor", graph_doc_processor, input=DocProcessorState)
|
204 |
+
graph_builder.add_node("concatenate_data", concatenate_data)
|
205 |
+
graph_builder.add_node("chatbot_exec_step", chatbots[config["main_llm"]])
|
206 |
+
graph_builder.add_node("generate_ki", generate_ki)
|
207 |
+
graph_builder.add_node("chatbot_ki", chatbots[config["main_llm"]])
|
208 |
+
graph_builder.add_node("detail_ki", detail_ki)
|
209 |
+
graph_builder.add_node("chatbot_final", chatbots[config["main_llm"]])
|
210 |
+
|
211 |
+
graph_builder.add_edge("validate", "chatbot_planner")
|
212 |
+
graph_builder.add_edge("parse", "detail_step")
|
213 |
+
|
214 |
+
|
215 |
+
# graph_builder.add_edge("detail_step", "chatbot2")
|
216 |
+
graph_builder.add_edge("chatbot_detail", "data_retriever")
|
217 |
+
graph_builder.add_edge("data_retriever", "human_validation")
|
218 |
+
|
219 |
+
|
220 |
+
graph_builder.add_edge("data_processor", "concatenate_data")
|
221 |
+
graph_builder.add_edge("concatenate_data", "chatbot_exec_step")
|
222 |
+
graph_builder.add_edge("generate_ki", "chatbot_ki")
|
223 |
+
graph_builder.add_edge("chatbot_ki", "detail_ki")
|
224 |
+
graph_builder.add_edge("detail_ki", "chatbot_final")
|
225 |
+
graph_builder.add_edge("chatbot_final", "__end__")
|
226 |
+
|
227 |
+
graph_builder.add_conditional_edges(
|
228 |
+
"detail_step",
|
229 |
+
detail_or_data_retriever,
|
230 |
+
{"chatbot_detail": "chatbot_detail", "data_retriever": "data_retriever"}
|
231 |
+
)
|
232 |
+
graph_builder.add_conditional_edges(
|
233 |
+
"human_validation",
|
234 |
+
retrieve_or_process,
|
235 |
+
{"retrieve" : "data_retriever", "process" : "data_processor"}
|
236 |
+
)
|
237 |
+
graph_builder.add_conditional_edges(
|
238 |
+
"chatbot_planner",
|
239 |
+
validate_plan,
|
240 |
+
{"parse" : "parse", "validate": "validate"}
|
241 |
+
)
|
242 |
+
graph_builder.add_conditional_edges(
|
243 |
+
"chatbot_exec_step",
|
244 |
+
next_plan_step,
|
245 |
+
{"generate_key_issues" : "generate_ki", "detail_step": "detail_step"}
|
246 |
+
)
|
247 |
+
|
248 |
+
graph_builder.set_entry_point("chatbot_planner")
|
249 |
+
graph = graph_builder.compile(
|
250 |
+
checkpointer=memory,
|
251 |
+
interrupt_after=["parse", "chatbot_exec_step", "chatbot_final", "data_retriever"],
|
252 |
+
)
|
253 |
+
return graph
|
ki_gen/prompts.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_core.prompts.prompt import PromptTemplate
|
2 |
+
from langchain_core.prompts import ChatPromptTemplate
|
3 |
+
from langchain_core.messages import SystemMessage, HumanMessage
|
4 |
+
from ki_gen.utils import ConfigSchema
|
5 |
+
|
6 |
+
CYPHER_GENERATION_TEMPLATE = """Task:Generate Cypher statement to query a graph database.
|
7 |
+
Instructions:
|
8 |
+
Use only the provided relationship types and properties in the schema.
|
9 |
+
Do not use any other relationship types or properties that are not provided.
|
10 |
+
Schema:
|
11 |
+
{schema}
|
12 |
+
|
13 |
+
|
14 |
+
Concepts:
|
15 |
+
{concepts}
|
16 |
+
|
17 |
+
|
18 |
+
Concept names can ONLY be selected from the above list
|
19 |
+
|
20 |
+
Note: Do not include any explanations or apologies in your responses.
|
21 |
+
Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.
|
22 |
+
Do not include any text except the generated Cypher statement.
|
23 |
+
|
24 |
+
The question is:
|
25 |
+
{question}"""
|
26 |
+
CYPHER_GENERATION_PROMPT = PromptTemplate(
|
27 |
+
input_variables=["schema", "question", "concepts"], template=CYPHER_GENERATION_TEMPLATE
|
28 |
+
)
|
29 |
+
|
30 |
+
CYPHER_QA_TEMPLATE = """You are an assistant that helps to form nice and human understandable answers.
|
31 |
+
The information part contains the provided information that you must use to construct an answer.
|
32 |
+
The provided information is authoritative, you must never doubt it or try to use your internal knowledge to correct it.
|
33 |
+
Make the answer sound as a response to the question. Do not mention that you based the result on the given information.
|
34 |
+
Here is an example:
|
35 |
+
|
36 |
+
Question: Which managers own Neo4j stocks?
|
37 |
+
Context:[manager:CTL LLC, manager:JANE STREET GROUP LLC]
|
38 |
+
Helpful Answer: CTL LLC, JANE STREET GROUP LLC owns Neo4j stocks.
|
39 |
+
|
40 |
+
Follow this example when generating answers.
|
41 |
+
If the provided information is empty, say that you don't know the answer.
|
42 |
+
Information:
|
43 |
+
{context}
|
44 |
+
|
45 |
+
Question: {question}
|
46 |
+
Helpful Answer:"""
|
47 |
+
CYPHER_QA_PROMPT = PromptTemplate(
|
48 |
+
input_variables=["context", "question"], template=CYPHER_QA_TEMPLATE
|
49 |
+
)
|
50 |
+
|
51 |
+
PLAN_GEN_PROMPT = """System : You are a standardization expert working for 3GPP. You are given a specific technical requirement regarding the deployment of 5G services. Your goal is to specify NEW and INNOVATIVE Key Issues that could occur while trying to fulfill this requirement
|
52 |
+
|
53 |
+
System : Let's first understand the problem and devise a plan to solve the problem.
|
54 |
+
Output the plan starting with the header 'Plan:' and then followed by a numbered list of steps.
|
55 |
+
Make the plan the minimum number of steps required to accurately provide the user with NEW and INNOVATIVE Key Issues related to the technical requirement.
|
56 |
+
At the end of your plan, say '<END_OF_PLAN>'"""
|
57 |
+
|
58 |
+
PLAN_MODIFICATION_PROMPT = """You are a standardization expert working for 3GPP. You are given a specific technical requirement regarding the deployment of 5G services. Your goal is to specify NEW and INNOVATIVE Key Issues that could occur while trying to fulfill this requirement.
|
59 |
+
To achieve this goal we are going to follow this generic plan :
|
60 |
+
|
61 |
+
###PLAN TEMPLATE###
|
62 |
+
|
63 |
+
Plan:
|
64 |
+
|
65 |
+
1. **Understanding the Problem**: Gather information from existing specifications and standards to thoroughly understand the technical requirement. This should help you understand the key aspects of the problem.
|
66 |
+
2. **Gather information about latest innovations** : Gather information about the latest innovations related to the problem by looking at the most relevant research papers.
|
67 |
+
3. **Researching current challenges** : Research the current challenges in this area by looking at the existing similar key issues that have been identified by 3GPP.
|
68 |
+
4. **Identifying NEW and INNOVATIVE Key Issues**: Based on the understanding of the problem and the current challenges, identify new and innovative key issues that could occur while trying to fulfill this requirement. These key issues should be relevant, significant, and not yet addressed by existing solutions.
|
69 |
+
5. **Develop Detailed Descriptions for Each Key Issue**: For each identified key issue, provide a detailed description, including the specific challenges and areas requiring further study.
|
70 |
+
|
71 |
+
<END_OF_PLAN>
|
72 |
+
|
73 |
+
###END OF PLAN TEMPLATE###
|
74 |
+
|
75 |
+
Let's and devise a plan to solve the problem by adapting the PLAN TEMPLATE.
|
76 |
+
Output the plan starting with the header 'Plan:' and then followed by a numbered list of steps.
|
77 |
+
Make the plan the minimum number of steps required to accurately provide the user with NEW and INNOVATIVE Key Issues related to the technical requirement.
|
78 |
+
At the end of your plan, say '<END_OF_PLAN>' """
|
79 |
+
|
80 |
+
CONCEPT_SELECTION_TEMPLATE = """Task: Select the most relevant topic to the user question
|
81 |
+
Instructions:
|
82 |
+
Select the most relevant Concept to the user's question.
|
83 |
+
Concepts can ONLY be selected from the list below.
|
84 |
+
|
85 |
+
Concepts:
|
86 |
+
{concepts}
|
87 |
+
|
88 |
+
Note: Do not include any explanations or apologies in your responses.
|
89 |
+
Do not include any text except the selected concept.
|
90 |
+
|
91 |
+
The question is:
|
92 |
+
{question}"""
|
93 |
+
CONCEPT_SELECTION_PROMPT = PromptTemplate(
|
94 |
+
input_variables=["concepts", "question"], template=CONCEPT_SELECTION_TEMPLATE
|
95 |
+
)
|
96 |
+
|
97 |
+
RELEVANT_CONCEPTS_TEMPLATE = """
|
98 |
+
## CONCEPTS ##
|
99 |
+
{concepts}
|
100 |
+
## END OF CONCEPTS ##
|
101 |
+
|
102 |
+
Select the 20 most relevant concepts to the user query.
|
103 |
+
Output your answer as a numbered list preceeded with the header 'Concepts:'.
|
104 |
+
|
105 |
+
User query :
|
106 |
+
{user_query}
|
107 |
+
"""
|
108 |
+
RELEVANT_CONCEPTS_PROMPT = ChatPromptTemplate.from_messages([
|
109 |
+
("human", RELEVANT_CONCEPTS_TEMPLATE)
|
110 |
+
])
|
111 |
+
|
112 |
+
SUMMARIZER_TEMPLATE = """You are a 3GPP standardization expert.
|
113 |
+
Summarize the provided document in simple technical English for other experts in the field.
|
114 |
+
|
115 |
+
Document:
|
116 |
+
{document}"""
|
117 |
+
SUMMARIZER_PROMPT = ChatPromptTemplate.from_messages([
|
118 |
+
("system", SUMMARIZER_TEMPLATE)
|
119 |
+
])
|
120 |
+
|
121 |
+
|
122 |
+
BINARY_GRADER_TEMPLATE = """You are a grader assessing relevance of a retrieved document to a user question. \n
|
123 |
+
It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
|
124 |
+
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
|
125 |
+
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""
|
126 |
+
BINARY_GRADER_PROMPT = ChatPromptTemplate.from_messages(
|
127 |
+
[
|
128 |
+
("system", BINARY_GRADER_TEMPLATE),
|
129 |
+
("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
|
130 |
+
]
|
131 |
+
)
|
132 |
+
|
133 |
+
|
134 |
+
SCORE_GRADER_TEMPLATE = """Grasp and understand both the query and the document before score generation.
|
135 |
+
Then, based on your understanding and analysis quantify the relevance between the document and the query.
|
136 |
+
Give the rationale before answering.
|
137 |
+
Ouput your answer as a score ranging between 0 (irrelevant document) and 1 (completely relevant document)"""
|
138 |
+
|
139 |
+
SCORE_GRADER_PROMPT = ChatPromptTemplate.from_messages(
|
140 |
+
[
|
141 |
+
("system", SCORE_GRADER_TEMPLATE),
|
142 |
+
("human", "Passage: \n\n {document} \n\n User query: {query}")
|
143 |
+
]
|
144 |
+
)
|
145 |
+
|
146 |
+
def get_initial_prompt(config: ConfigSchema, user_query : str):
|
147 |
+
if config["configurable"].get("plan_method") == "generation":
|
148 |
+
prompt = PLAN_GEN_PROMPT
|
149 |
+
elif config["configurable"].get("plan_method") == "modification":
|
150 |
+
prompt = PLAN_MODIFICATION_PROMPT
|
151 |
+
else:
|
152 |
+
raise ValueError("Incorrect plan_method, should be 'generation' or 'modification'")
|
153 |
+
|
154 |
+
user_input = user_query or input("User :")
|
155 |
+
return {"messages" : [SystemMessage(content=prompt), HumanMessage(content=user_input)]}
|
ki_gen/utils.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import getpass
|
3 |
+
import html
|
4 |
+
|
5 |
+
|
6 |
+
from typing import Annotated, Union
|
7 |
+
from typing_extensions import TypedDict
|
8 |
+
|
9 |
+
from langchain_community.graphs import Neo4jGraph
|
10 |
+
from langchain_groq import ChatGroq
|
11 |
+
from langchain_openai import ChatOpenAI
|
12 |
+
|
13 |
+
from langgraph.checkpoint.sqlite import SqliteSaver
|
14 |
+
from langgraph.checkpoint import base
|
15 |
+
from langgraph.graph import add_messages
|
16 |
+
|
17 |
+
with SqliteSaver.from_conn_string(":memory:") as mem :
|
18 |
+
memory = mem
|
19 |
+
|
20 |
+
|
21 |
+
def format_df(df):
|
22 |
+
"""
|
23 |
+
Used to display the generated plan in a nice format
|
24 |
+
Returns html code in a string
|
25 |
+
"""
|
26 |
+
def format_cell(cell):
|
27 |
+
if isinstance(cell, str):
|
28 |
+
# Encode special characters, but preserve line breaks
|
29 |
+
return html.escape(cell).replace('\n', '<br>')
|
30 |
+
return cell
|
31 |
+
# Convert the DataFrame to HTML with custom CSS
|
32 |
+
formatted_df = df.map(format_cell)
|
33 |
+
html_table = formatted_df.to_html(escape=False, index=False)
|
34 |
+
|
35 |
+
# Add custom CSS to allow multiple lines and scrolling in cells
|
36 |
+
css = """
|
37 |
+
<style>
|
38 |
+
table {
|
39 |
+
border-collapse: collapse;
|
40 |
+
width: 100%;
|
41 |
+
}
|
42 |
+
th, td {
|
43 |
+
border: 1px solid black;
|
44 |
+
padding: 8px;
|
45 |
+
text-align: left;
|
46 |
+
vertical-align: top;
|
47 |
+
white-space: pre-wrap;
|
48 |
+
max-width: 300px;
|
49 |
+
max-height: 100px;
|
50 |
+
overflow-y: auto;
|
51 |
+
}
|
52 |
+
th {
|
53 |
+
background-color: #f2f2f2;
|
54 |
+
}
|
55 |
+
</style>
|
56 |
+
"""
|
57 |
+
|
58 |
+
return css + html_table
|
59 |
+
|
60 |
+
def format_doc(doc: dict) -> str :
|
61 |
+
formatted_string = ""
|
62 |
+
for key in doc:
|
63 |
+
formatted_string += f"**{key}**: {doc[key]}\n"
|
64 |
+
return formatted_string
|
65 |
+
|
66 |
+
|
67 |
+
|
68 |
+
def _set_env(var: str, value: str = None):
|
69 |
+
if not os.environ.get(var):
|
70 |
+
if value:
|
71 |
+
os.environ[var] = value
|
72 |
+
else:
|
73 |
+
os.environ[var] = getpass.getpass(f"{var}: ")
|
74 |
+
|
75 |
+
|
76 |
+
def init_app(openai_key : str = None, groq_key : str = None, langsmith_key : str = None):
|
77 |
+
"""
|
78 |
+
Initialize app with user api keys and sets up proxy settings
|
79 |
+
"""
|
80 |
+
_set_env("GROQ_API_KEY", value=groq_key)
|
81 |
+
_set_env("LANGSMITH_API_KEY", value=langsmith_key)
|
82 |
+
_set_env("OPENAI_API_KEY", value=openai_key)
|
83 |
+
os.environ["LANGSMITH_TRACING_V2"] = "true"
|
84 |
+
os.environ["LANGCHAIN_PROJECT"] = "3GPP Test"
|
85 |
+
os.environ["http_proxy"] = "185.46.212.98:80"
|
86 |
+
os.environ["https_proxy"] = "185.46.212.98:80"
|
87 |
+
os.environ["NO_PROXY"] = "thalescloud.io"
|
88 |
+
|
89 |
+
def clear_memory(memory, thread_id: str) -> None:
|
90 |
+
"""
|
91 |
+
Clears checkpointer state for a given thread_id, broken for now
|
92 |
+
TODO : fix this
|
93 |
+
"""
|
94 |
+
with SqliteSaver.from_conn_string(":memory:") as mem :
|
95 |
+
memory = mem
|
96 |
+
checkpoint = base.empty_checkpoint()
|
97 |
+
memory.put(config={"configurable": {"thread_id": thread_id}}, checkpoint=checkpoint, metadata={})
|
98 |
+
|
99 |
+
def get_model(model : str = "mixtral-8x7b-32768"):
|
100 |
+
"""
|
101 |
+
Wrapper to return the correct llm object depending on the 'model' param
|
102 |
+
"""
|
103 |
+
if model == "gpt-4o":
|
104 |
+
llm = ChatOpenAI(model=model, base_url="https://llm.synapse.thalescloud.io/")
|
105 |
+
else:
|
106 |
+
llm = ChatGroq(model=model)
|
107 |
+
return llm
|
108 |
+
|
109 |
+
|
110 |
+
class ConfigSchema(TypedDict):
|
111 |
+
graph: Neo4jGraph
|
112 |
+
plan_method: str
|
113 |
+
use_detailed_query: bool
|
114 |
+
|
115 |
+
class State(TypedDict):
|
116 |
+
messages : Annotated[list, add_messages]
|
117 |
+
store_plan : list[str]
|
118 |
+
current_plan_step : int
|
119 |
+
valid_docs : list[str]
|
120 |
+
|
121 |
+
class DocRetrieverState(TypedDict):
|
122 |
+
messages: Annotated[list, add_messages]
|
123 |
+
query: str
|
124 |
+
docs: list[dict]
|
125 |
+
cyphers: list[str]
|
126 |
+
current_plan_step : int
|
127 |
+
valid_docs: list[Union[str, dict]]
|
128 |
+
|
129 |
+
class HumanValidationState(TypedDict):
|
130 |
+
human_validated : bool
|
131 |
+
process_steps : list[str]
|
132 |
+
|
133 |
+
def update_doc_history(left : list | None, right : list | None) -> list:
|
134 |
+
"""
|
135 |
+
Reducer for the 'docs_in_processing' field.
|
136 |
+
Doesn't work currently because of bad handlinf of duplicates
|
137 |
+
TODO : make this work (reference : https://langchain-ai.github.io/langgraph/how-tos/subgraph/#custom-reducer-functions-to-manage-state)
|
138 |
+
"""
|
139 |
+
if not left:
|
140 |
+
# This shouldn't happen
|
141 |
+
left = [[]]
|
142 |
+
if not right:
|
143 |
+
right = []
|
144 |
+
|
145 |
+
for i in range(len(right)):
|
146 |
+
left[i].append(right[i])
|
147 |
+
return left
|
148 |
+
|
149 |
+
|
150 |
+
class DocProcessorState(TypedDict):
|
151 |
+
valid_docs : list[Union[str, dict]]
|
152 |
+
docs_in_processing : list
|
153 |
+
process_steps : list[Union[str,dict]]
|
154 |
+
current_process_step : int
|
155 |
+
|