heymenn committed on
Commit 6aaddef · verified · 1 Parent(s): 228c70f

Upload 15 files

app.py ADDED
@@ -0,0 +1,555 @@
1
+ import gradio as gr
2
+ from langchain_community.graphs import Neo4jGraph
3
+ import pandas as pd
4
+ import json
5
+
6
+ from ki_gen.planner import build_planner_graph
7
+ from ki_gen.utils import clear_memory, init_app, format_df, memory
8
+ from ki_gen.prompts import get_initial_prompt
9
+
10
+ MAX_PROCESSING_STEPS = 10
11
+
12
+ print(f"WHEREVER YOU ARE THIS IS THE MEMORY INSTANCE !!!! : {type(memory)} !!!!")
13
+
14
+ def start_inference(data):
15
+ """
16
+ Starts plan generation using the user query as input; the generated plan is displayed afterwards
17
+ """
18
+ config = data[config_state]
19
+ init_app(
20
+ openai_key=data[openai_api_key],
21
+ groq_key=data[groq_api_key],
22
+ langsmith_key=data[langsmith_api_key]
23
+ )
24
+
25
+ clear_memory(memory, config["configurable"].get("thread_id"))
26
+
27
+ graph = build_planner_graph(memory, config["configurable"])
28
+ with open("images/graph_png.png", "wb") as f:
29
+ f.write(graph.get_graph(xray=1).draw_mermaid_png())
30
+
31
+ print("here !")
32
+ for event in graph.stream(get_initial_prompt(config, data[user_query]), config, stream_mode="values"):
33
+ if "messages" in event:
34
+ event["messages"][-1].pretty_print()
35
+
36
+ state = graph.get_state(config)
37
+ steps = [i for i in range(1,len(state.values['store_plan'])+1)]
38
+ df = pd.DataFrame({'Plan steps': steps, 'Description': state.values['store_plan']})
39
+ return [df, graph]
40
+
41
+ def update_display(df):
42
+ """
43
+ Displays the df after it has been generated
44
+ """
45
+ formatted_html = format_df(df)
46
+ return {
47
+ plan_display : gr.update(visible=True, value = formatted_html),
48
+ select_step_to_modify : gr.update(visible=True, value=0),
49
+ enter_new_step : gr.update(visible=True),
50
+ submit_new_step : gr.update(visible=True),
51
+ continue_inference_btn : gr.update(visible=True)
52
+ }
53
+
54
+ def format_docs(docs: list[dict]):
55
+ formatted_results = ""
56
+ for i, doc in enumerate(docs):
57
+ formatted_results += f"\n### Document {i}\n"
58
+ for key in doc:
59
+ formatted_results += f"**{key}**: {doc[key]}\n"
60
+ return formatted_results
61
+
62
+ def continue_inference(data):
63
+ """
64
+ Proceeds to next plan step
65
+ """
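+ # Streams the graph until the current plan step finishes or the graph pauses at the
+ # "human_validation" node; in the latter case the doc-validation buttons are shown instead.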
66
+ graph = data[graph_state]
67
+ config = data[config_state]
68
+
69
+ for event in graph.stream(None, config, stream_mode="values"):
70
+ if "messages" in event:
71
+ event["messages"][-1].pretty_print()
72
+
73
+ snapshot = graph.get_state(config)
74
+ print(f"DEBUG INFO : next : {snapshot.next}")
75
+ print(f"DEBUG INFO ++ L.75: {snapshot}")
76
+
77
+ if snapshot.next and snapshot.next[0] == "human_validation":
78
+ return {
79
+ continue_inference_btn : gr.update(visible=False),
80
+ graph_state : graph,
81
+ retrieve_more_docs_btn : gr.update(visible=True),
82
+ continue_to_processing_btn : gr.update(visible=True),
83
+ human_validation_title : gr.update(visible=True, value=f"**{len(snapshot.values['valid_docs'])} documents retrieved.** Retrieve more or continue ?"),
84
+ retrieved_docs_state : snapshot.values['valid_docs']
85
+ }
86
+
87
+ return {
88
+ plan_result : snapshot.values["messages"][-1].content,
89
+ graph_state : graph,
90
+ continue_inference_btn : gr.update(visible=False)
91
+ }
92
+
93
+ def continue_to_processing():
94
+ """
95
+ Continue to doc processing configuration
96
+ """
97
+ return {
98
+ retrieve_more_docs_btn : gr.update(visible=False),
99
+ continue_to_processing_btn : gr.update(visible=False),
100
+ human_validation_title : gr.update(visible=False),
101
+ process_data_btn : gr.update(visible=True),
102
+ process_steps_nb : gr.update(visible=True),
103
+ process_steps_title : gr.update(visible=True)
104
+ }
105
+
106
+ def retrieve_more_docs(data):
107
+ """
108
+ Restart doc retrieval
109
+ For now we simply regenerate the Cypher query; the result may differ because temperature != 0
110
+ """
111
+ graph = data[graph_state]
112
+ config = data[config_state]
113
+ graph.update_state(config, {'human_validated' : False}, as_node="human_validation")
114
+
115
+ for event in graph.stream(None, config, stream_mode="values"):
116
+ if "messages" in event:
117
+ event["messages"][-1].pretty_print()
118
+
119
+ snapshot = graph.get_state(config)
120
+ print(f"DEBUG INFO : next : {snapshot.next}")
121
+ print(f"DEBUG INFO ++ L.121: {snapshot}")
122
+
123
+ return {
124
+ graph_state : graph,
125
+ human_validation_title : gr.update(visible=True, value=f"**{len(snapshot.values['valid_docs'])} documents retrieved.** Retrieve more or continue ?"),
126
+ retrieved_docs_display : format_docs(snapshot.values['valid_docs'])
127
+ }
128
+
129
+ def execute_processing(*args):
130
+ """
131
+ Execute doc processing
132
+ Args are passed as a list and not a dict for syntax convenience
133
+ """
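+ # Expected layout (matching the process_data_btn.click wiring below):
+ #   args = dropdowns + textboxes + usable_elements + processing_models + [process_steps_nb, graph_state, config_state]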
134
+ graph = args[-2]
135
+ config = args[-1]
136
+ nb_process_steps = args[-3]
137
+
138
+ process_steps = []
139
+ for i in range (nb_process_steps):
140
+ if args[i] == "custom":
141
+ process_steps.append({"prompt" : args[nb_process_steps + i], "context" : args[2*nb_process_steps + i], "processing_model" : args[3*nb_process_steps + i]})
142
+ else:
143
+ process_steps.append(args[i])
144
+
145
+ graph.update_state(config, {'human_validated' : True, 'process_steps' : process_steps}, as_node="human_validation")
146
+
147
+ for event in graph.stream(None, config, stream_mode="values"):
148
+ if "messages" in event:
149
+ event["messages"][-1].pretty_print()
150
+
151
+ snapshot = graph.get_state(config)
152
+ print(f"DEBUG INFO : next : {snapshot.next}")
153
+ print(f"DEBUG INFO ++ L.153: {snapshot}")
154
+
155
+ return {
156
+ plan_result : snapshot.values["messages"][-1].content,
157
+ processed_docs_state : snapshot.values["valid_docs"],
158
+ graph_state : graph,
159
+ continue_inference_btn : gr.update(visible=True),
160
+ process_steps_nb : gr.update(value=0, visible=False),
161
+ process_steps_title : gr.update(visible=False),
162
+ process_data_btn : gr.update(visible=False),
163
+ }
164
+
165
+
166
+
167
+ def update_config_display():
168
+ """
169
+ Called after loading the config.json file
170
+ TODO : allow the user to specify a path to the config file
171
+ """
172
+ with open("config.json", "r") as config_file:
173
+ config = json.load(config_file)
174
+
175
+ return {
176
+ main_llm : config["main_llm"],
177
+ plan_method : config["plan_method"],
178
+ use_detailed_query : config["use_detailed_query"],
179
+ cypher_gen_method : config["cypher_gen_method"],
180
+ validate_cypher : config["validate_cypher"],
181
+ summarization_model : config["summarize_model"],
182
+ eval_method : config["eval_method"],
183
+ eval_threshold : config["eval_threshold"],
184
+ max_docs : config["max_docs"],
185
+ compression_method : config["compression_method"],
186
+ compress_rate : config["compress_rate"],
187
+ force_tokens : config["force_tokens"],
188
+ eval_model : config["eval_model"],
189
+ srv_addr : config["graph"]["address"],
190
+ srv_usr : config["graph"]["username"],
191
+ srv_pwd : config["graph"]["password"],
192
+ openai_api_key : config["openai_api_key"],
193
+ groq_api_key : config["groq_api_key"],
194
+ langsmith_api_key : config["langsmith_api_key"]
195
+ }
196
+
197
+
198
+ def build_config(data):
199
+ """
200
+ Build the config variable using the values inputted by the user
201
+ """
202
+ config = {}
203
+ config["main_llm"] = data[main_llm]
204
+ config["plan_method"] = data[plan_method]
205
+ config["use_detailed_query"] = data[use_detailed_query]
206
+ config["cypher_gen_method"] = data[cypher_gen_method]
207
+ config["validate_cypher"] = data[validate_cypher]
208
+ config["summarize_model"] = data[summarization_model]
209
+ config["eval_method"] = data[eval_method]
210
+ config["eval_threshold"] = data[eval_threshold]
211
+ config["max_docs"] = data[max_docs]
212
+ config["compression_method"] = data[compression_method]
213
+ config["compress_rate"] = data[compress_rate]
214
+ config["force_tokens"] = data[force_tokens]
215
+ config["eval_model"] = data[eval_model]
216
+ config["thread_id"] = "3"
217
+ try:
218
+ neograph = Neo4jGraph(url=data[srv_addr], username=data[srv_usr], password=data[srv_pwd])
219
+ config["graph"] = neograph
220
+ except Exception as e:
221
+ raise gr.Error(f"Error when configuring the neograph server : {e}", duration=5)
222
+ gr.Info("Successfully updated configuration!", duration=5)
223
+ return {"configurable" : config}
224
+
225
+ with gr.Blocks() as demo:
226
+ with gr.Tab("Config"):
227
+
228
+ ### The config tab
229
+
230
+ gr.Markdown("## Config options setup")
231
+
232
+ gr.Markdown("### API Keys")
233
+
234
+ with gr.Row():
235
+ openai_api_key = gr.Textbox(
236
+ label="OpenAI API Key",
237
+ type="password"
238
+ )
239
+
240
+ groq_api_key = gr.Textbox(
241
+ label="Groq API Key",
242
+ type='password'
243
+ )
244
+
245
+ langsmith_api_key = gr.Textbox(
246
+ label="LangSmith API Key",
247
+ type="password"
248
+ )
249
+
250
+ gr.Markdown('### Planner options')
251
+ with gr.Row():
252
+ main_llm = gr.Dropdown(
253
+ choices=["gpt-4o", "claude-3-5-sonnet", "mixtral-8x7b-32768"],
254
+ label="Main LLM",
255
+ info="Choose the LLM which will perform the generation",
256
+ value="gpt-4o"
257
+ )
258
+ with gr.Column(scale=1, min_width=600):
259
+ plan_method = gr.Dropdown(
260
+ choices=["generation", "modification"],
261
+ label="Planning method",
262
+ info="Choose how the main LLM will generate its plan",
263
+ value="modification"
264
+ )
265
+ use_detailed_query = gr.Checkbox(
266
+ label="Detail each plan step",
267
+ info="Detail each plan step before passing it for data query"
268
+ )
269
+
270
+ gr.Markdown("### Data query options")
271
+
272
+ # The options for the data processor
273
+ # TODO : remove the options for summarize and compress and let the user choose them when specifying processing steps
274
+ # (similarly to what is done for custom processing step)
275
+
276
+ with gr.Row():
277
+ with gr.Column(scale=1, min_width=300):
278
+ # Neo4j Server parameters
279
+
280
+ srv_addr = gr.Textbox(
281
+ label="Neo4j server address",
282
+ placeholder="localhost:7687"
283
+ )
284
+ srv_usr = gr.Textbox(
285
+ label="Neo4j username",
286
+ placeholder="neo4j"
287
+ )
288
+ srv_pwd = gr.Textbox(
289
+ label="Neo4j password",
290
+ placeholder="<Password>"
291
+ )
292
+
293
+ with gr.Column(scale=1, min_width=300):
294
+ cypher_gen_method = gr.Dropdown(
295
+ choices=["auto", "guided"],
296
+ label="Cypher generation method",
297
+ )
298
+ validate_cypher = gr.Checkbox(
299
+ label="Validate cypher using graph Schema"
300
+ )
301
+
302
+ summarization_model = gr.Dropdown(
303
+ choices=["gpt-4o", "claude-3-5-sonnet", "mixtral-8x7b-32768", "llama3-70b-8192"],
304
+ label="Summarization LLM",
305
+ info="Choose the LLM which will perform the summaries"
306
+ )
307
+
308
+ with gr.Column(scale=1, min_width=300):
309
+ eval_method = gr.Dropdown(
310
+ choices=["binary", "score"],
311
+ label="Retrieved docs evaluation method",
312
+ info="Evaluation method of retrieved docs"
313
+ )
314
+
315
+ eval_model = gr.Dropdown(
316
+ choices = ["gpt-4o", "mixtral-8x7b-32768"],
317
+ label = "Evaluation model",
318
+ info = "The LLM to use to evaluate the relevance of retrieved docs",
319
+ value = "mixtral-8x7b-32768"
320
+ )
321
+
322
+ eval_threshold = gr.Slider(
323
+ minimum=0,
324
+ maximum=1,
325
+ value=0.7,
326
+ label="Eval threshold",
327
+ info="Score above which a doc is considered relevant",
328
+ step=0.01,
329
+ visible=False
330
+ )
331
+
332
+ def eval_method_changed(selection):
333
+ if selection == "score":
334
+ return gr.update(visible=True)
335
+ return gr.update(visible=False)
336
+ eval_method.change(eval_method_changed, inputs=eval_method, outputs=eval_threshold)
337
+
338
+ max_docs= gr.Slider(
339
+ minimum=0,
340
+ maximum = 30,
341
+ value = 15,
342
+ label="Max docs",
343
+ info="Maximum number of docs to be retrieved at each query",
344
+ step=1
345
+ )
346
+
347
+ with gr.Column(scale=1, min_width=300):
348
+ compression_method = gr.Dropdown(
349
+ choices=["llm_lingua2", "llm_lingua"],
350
+ label="Compression method",
351
+ value="llm_lingua2"
352
+ )
353
+
354
+ with gr.Row():
355
+
356
+ # Add compression rate configuration with a gr.slider
357
+ compress_rate = gr.Slider(
358
+ minimum = 0,
359
+ maximum = 1,
360
+ value = 0.33,
361
+ label="Compression rate",
362
+ info="Target compression rate (fraction of tokens to keep)",
363
+ step = 0.01
364
+ )
365
+
366
+ # Add gr.CheckboxGroup to choose force_tokens
367
+ force_tokens = gr.CheckboxGroup(
368
+ choices=['\n', '?', '.', '!', ','],
369
+ value=[],
370
+ label="Force tokens",
371
+ info="Tokens to keep during compression",
372
+ )
373
+
374
+ with gr.Row():
375
+ btn_update_config = gr.Button(value="Update config")
376
+ load_config_json = gr.Button(value="Load config from JSON")
377
+
378
+ with gr.Row():
379
+ debug_info = gr.Button(value="Print debug info")
380
+
381
+ config_state = gr.State(value={})
382
+
383
+
384
+ btn_update_config.click(
385
+ build_config,
386
+ inputs={main_llm, plan_method, use_detailed_query, srv_addr, srv_pwd, srv_usr, compression_method, eval_model, \
387
+ compress_rate, force_tokens, cypher_gen_method, validate_cypher, summarization_model, eval_method, eval_threshold, max_docs},
388
+ outputs=config_state
389
+ )
390
+ load_config_json.click(
391
+ update_config_display,
392
+ outputs={main_llm, plan_method, use_detailed_query, cypher_gen_method, validate_cypher, summarization_model, eval_method, eval_threshold, \
393
+ max_docs, compress_rate, compression_method, force_tokens, eval_model, srv_addr, srv_usr, srv_pwd, openai_api_key, langsmith_api_key, groq_api_key}
394
+ ).then(
395
+ build_config,
396
+ inputs={main_llm, plan_method, use_detailed_query, srv_addr, srv_pwd, srv_usr, compression_method, eval_model, \
397
+ compress_rate, force_tokens, cypher_gen_method, validate_cypher, summarization_model, eval_method, eval_threshold, max_docs},
398
+ outputs=config_state
399
+ )
400
+
401
+ # Print config variable in the terminal
402
+ debug_info.click(lambda x : print(x), inputs=config_state)
403
+
404
+ with gr.Tab("Inference"):
405
+ ### Inference tab
406
+
407
+ graph_state = gr.State()
408
+ user_query = gr.Textbox(label = "Your query")
409
+ launch_inference = gr.Button(value="Generate plan")
410
+
411
+ with gr.Row():
412
+ dataframe_plan = gr.Dataframe(visible = False)
413
+ plan_display = gr.HTML(visible = False, label="Generated plan")
414
+
415
+ with gr.Column():
416
+
417
+ # Lets the user modify steps of the plan. Underlying logic not implemented yet
418
+ # TODO : implement this
419
+ with gr.Row():
420
+ select_step_to_modify = gr.Number(visible= False, label="Select a plan step to modify", value=0)
421
+ submit_new_step = gr.Button(visible = False, value="Submit new step")
422
+ enter_new_step = gr.Textbox(visible=False, label="Modify the plan step")
423
+
424
+ with gr.Row():
425
+ human_validation_title = gr.Markdown(visible=False)
426
+ retrieve_more_docs_btn = gr.Button(value="Retrieve more docs", visible=False)
427
+ continue_to_processing_btn = gr.Button(value="Proceed to data processing", visible=False)
428
+
429
+ with gr.Row():
430
+ with gr.Column():
431
+
432
+ process_steps_title = gr.Markdown("#### Data processing steps", visible=False)
433
+ process_steps_nb = gr.Number(label="Number of processing steps", value = 0, precision=0, step = 1, visible=False)
434
+
435
+ def get_process_step_names():
436
+ return ["summarize", "compress", "custom"]
437
+
438
+ # The gr.render decorator allows the code inside the following function to be rerun every time the 'inputs' variable is modified
439
+ # /!\ All event listeners that use variables defined inside a gr.render function must be defined inside that same function
440
+ # ref : https://www.gradio.app/docs/gradio/render
441
+ @gr.render(inputs=process_steps_nb)
442
+ def processing(nb):
443
+ with gr.Row():
444
+ process_step_names = get_process_step_names()
445
+ dropdowns = []
446
+ textboxes = []
447
+ usable_elements = []
448
+ processing_models = []
449
+ for i in range(nb):
450
+ with gr.Column():
451
+ dropdown = gr.Dropdown(key = f"d{i}", choices=process_step_names, label=f"Data processing step {i+1}")
452
+ dropdowns.append(dropdown)
453
+
454
+ textbox = gr.Textbox(
455
+ key = f"t{i}",
456
+ value="",
457
+ placeholder="Your custom prompt",
458
+ visible=True, min_width=300)
459
+ textboxes.append(textbox)
460
+
461
+ usable_element = gr.Dropdown(
462
+ key = f"u{i}",
463
+ choices = [(j) for j in range(i+1)],
464
+ label="Elements passed to the LLM for this process step",
465
+ multiselect=True,
466
+ )
467
+ usable_elements.append(usable_element)
468
+
469
+ processing_model = gr.Dropdown(
470
+ key = f"m{i}",
471
+ label="The LLM that will execute this step",
472
+ visible=True,
473
+ choices=["gpt-4o", "mixtral-8x7b-32768", "llama3-70b-8192"]
474
+ )
475
+ processing_models.append(processing_model)
476
+
477
+ dropdown.change(
478
+ fn=lambda process_name : [gr.update(visible=(process_name=="custom")), gr.update(visible=(process_name=='custom')), gr.update(visible=(process_name=='custom'))],
479
+ inputs=dropdown,
480
+ outputs=[textbox, usable_element, processing_model]
481
+ )
482
+
483
+ process_data_btn.click(
484
+ execute_processing,
485
+ inputs= dropdowns + textboxes + usable_elements + processing_models + [process_steps_nb, graph_state, config_state],
486
+ outputs={plan_result, processed_docs_state, graph_state, continue_inference_btn, process_steps_nb, process_steps_title, process_data_btn}
487
+ )
488
+
489
+ process_data_btn = gr.Button(value="Process retrieved docs", visible=False)
490
+
491
+ continue_inference_btn = gr.Button(value="Proceed to next plan step", visible=False)
492
+ plan_result = gr.Markdown(visible = True, label="Result of last plan step")
493
+
494
+ with gr.Tab("Retrieved Docs"):
495
+ retrieved_docs_state = gr.State([])
496
+ with gr.Row():
497
+ gr.Markdown("# Retrieved Docs")
498
+ retrieved_docs_btn = gr.Button("Display retrieved docs")
499
+ retrieved_docs_display = gr.Markdown()
500
+
501
+ processed_docs_state = gr.State([])
502
+ with gr.Row():
503
+ gr.Markdown("# Processed Docs")
504
+ processed_docs_btn = gr.Button("Display processed docs")
505
+ processed_docs_display = gr.Markdown()
506
+
507
+ continue_inference_btn.click(
508
+ continue_inference,
509
+ inputs={graph_state, config_state},
510
+ outputs={continue_inference_btn, graph_state, retrieve_more_docs_btn, continue_to_processing_btn, human_validation_title, plan_result, retrieved_docs_state}
511
+ )
512
+
513
+ launch_inference.click(
514
+ start_inference,
515
+ inputs={config_state, user_query, openai_api_key, groq_api_key, langsmith_api_key},
516
+ outputs=[dataframe_plan, graph_state]
517
+ ).then(
518
+ update_display,
519
+ inputs=dataframe_plan,
520
+ outputs={plan_display, select_step_to_modify, enter_new_step, submit_new_step, continue_inference_btn}
521
+ )
522
+
523
+ retrieve_more_docs_btn.click(
524
+ retrieve_more_docs,
525
+ inputs={graph_state, config_state},
526
+ outputs={graph_state, human_validation_title, retrieved_docs_display}
527
+ )
528
+ continue_to_processing_btn.click(
529
+ continue_to_processing,
530
+ outputs={retrieve_more_docs_btn, continue_to_processing_btn, human_validation_title, process_data_btn, process_steps_nb, process_steps_title}
531
+ )
532
+ retrieved_docs_btn.click(
533
+ fn=lambda docs : format_docs(docs),
534
+ inputs=retrieved_docs_state,
535
+ outputs=retrieved_docs_display
536
+ )
537
+ processed_docs_btn.click(
538
+ fn=lambda docs : format_docs(docs),
539
+ inputs=processed_docs_state,
540
+ outputs=processed_docs_display
541
+ )
542
+
543
+
544
+ test_process_steps = gr.Button(value="Test process steps")
545
+ test_process_steps.click(
546
+ lambda : [gr.update(visible = True), gr.update(visible=True)],
547
+ outputs=[process_steps_nb, process_steps_title]
548
+ )
549
+
550
+
551
+
552
+
553
+
554
+ demo.launch()
555
+
doc_explorer/embeddings_full.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cf8ae23f82d734adab5810858bb55c2f13edb06795637b6fd85ada823d722527
3
+ size 55693440
doc_explorer/explorer.py ADDED
@@ -0,0 +1,335 @@
1
+ import gradio as gr
2
+ from vectorstore import FAISSVectorStore
3
+ from langchain_community.graphs import Neo4jGraph
4
+ import os
5
+ import json
6
+ import html
7
+ import pandas as pd
8
+ import time
9
+
10
+ time.sleep(30)
11
+
12
+ os.environ["http_proxy"] = "185.46.212.98:80"
13
+ os.environ["https_proxy"] = "185.46.212.98:80"
14
+ os.environ["NO_PROXY"] = "localhost"
15
+
16
+ neo4j_graph = Neo4jGraph(
17
+ url=os.getenv("NEO4J_URI", "bolt://localhost:7999"),
18
+ username=os.getenv("NEO4J_USERNAME", "neo4j"),
19
+ password=os.getenv("NEO4J_PASSWORD", "graph_test")
20
+ )
21
+
22
+ # Requires ~1GB RAM
23
+ vector_store = FAISSVectorStore(model_name='Alibaba-NLP/gte-large-en-v1.5', dimension=1024, trust_remote_code=True, embedding_file="/usr/src/app/doc_explorer/embeddings_full.npy")
24
+
25
+ # Get document types from Neo4j database
26
+ def get_document_types():
27
+ query = """
28
+ MATCH (n)
29
+ RETURN DISTINCT labels(n) AS document_type
30
+ """
31
+ result = neo4j_graph.query(query)
32
+ return [row["document_type"][0] for row in result]
33
+
34
+ def search(query, doc_types, use_mmr, lambda_param, top_k):
35
+ results, node_ids = vector_store.similarity_search(
36
+ query,
37
+ k=top_k,
38
+ use_mmr=use_mmr,
39
+ lambda_param=lambda_param if use_mmr else None,
40
+ doc_types=doc_types,
41
+ neo4j_graph=neo4j_graph
42
+ )
43
+
44
+ formatted_results = []
45
+ formatted_choices = []
46
+ for i, result in enumerate(results):
47
+ formatted_results.append(f"{i+1}. {result['document']} (Score: {result['score']:.4f})")
48
+ formatted_choices.append(f"{i+1}. {str(result['document'])[:100]} (Score: {result['score']:.4f})")
49
+ return formatted_results, gr.update(choices=formatted_choices, value=[]), node_ids
50
+
51
+ def get_docs_from_ids(graph_data : dict):
52
+ node_ids = [node["id"] for node in graph_data["nodes"]]
53
+ print(node_ids)
54
+ query = """
55
+ MATCH (n)
56
+ WHERE n.id IN $node_ids
57
+ RETURN n.id AS id, n AS doc, labels(n) AS category
58
+ """
59
+
60
+ return neo4j_graph.query(query, {"node_ids" : node_ids}), graph_data["edges"]
61
+
62
+ def get_neighbors_and_graph_data(selected_documents, node_ids, graph_data):
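+ # Expands the vis.js graph: for each selected document, its Neo4j neighbours are fetched and
+ # appended to graph_data ({"nodes": [...], "edges": [...]}), which is then rendered by js_code.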
63
+ if not selected_documents:
64
+ return "No documents selected.", json.dumps(graph_data), graph_data, gr.update(), node_ids
65
+
66
+ selected_indices = [int(doc.split('.')[0]) - 1 for doc in selected_documents]
67
+ selected_node_ids = [node_ids[i] for i in selected_indices]
68
+
69
+ query = """
70
+ MATCH (n)-[r]-(neighbor)
71
+ WHERE n.id IN $node_ids
72
+ RETURN n.id AS source_id, n AS source_text, labels(n) AS source_type,
73
+ neighbor.id AS neighbor_id, neighbor AS neighbor_text,
74
+ labels(neighbor) AS neighbor_type, type(r) AS relationship_type
75
+ """
76
+ results = neo4j_graph.query(query, {"node_ids": selected_node_ids})
77
+
78
+ if not results:
79
+ return "No neighbors found for the selected documents.", json.dumps(graph_data), graph_data, gr.update(), node_ids
80
+
81
+ neighbor_info = {}
82
+ node_set = set([node["id"] for node in graph_data["nodes"]])
83
+
84
+ for row in results:
85
+ source_id = row['source_id']
86
+ if source_id not in neighbor_info:
87
+ neighbor_info[source_id] = {
88
+ 'source_type': row["source_type"][0],
89
+ 'source_text': row['source_text'],
90
+ 'neighbors': []
91
+ }
92
+ if source_id not in node_set:
93
+ graph_data["nodes"].append({
94
+ "id": source_id,
95
+ "label": str(row['source_text'])[:30] + "...",
96
+ "group": row['source_type'][0],
97
+ "title": f"<div class='node-tooltip'><h3>{row['source_type'][0]}</h3><p>{row['source_text']}</p></div>",
98
+ })
99
+ node_set.add(source_id)
100
+
101
+ neighbor_info[source_id]['neighbors'].append(
102
+ f"[{row['relationship_type']}] [{row['neighbor_type'][0]}] {str(row['neighbor_text'])[:200]}"
103
+ )
104
+
105
+ if row['neighbor_id'] not in node_set:
106
+ graph_data["nodes"].append({
107
+ "id": row['neighbor_id'],
108
+ "label": str(row['neighbor_text'])[:30] + "...",
109
+ "group": row['neighbor_type'][0],
110
+ "title": f"<div class='node-tooltip'><h3>{row['neighbor_type'][0]}</h3><p>{html.escape(str(row['neighbor_text']))}</p></div>",
111
+ })
112
+ node_set.add(row['neighbor_id'])
113
+
114
+ edge = {
115
+ "from": source_id,
116
+ "to" : row['neighbor_id'],
117
+ "label": row['relationship_type']
118
+ }
119
+ if edge not in graph_data['edges']:
120
+ graph_data['edges'].append(edge)
121
+
122
+ output = []
123
+ for source_id, info in neighbor_info.items():
124
+ output.append(f"Neighbors for: [{info['source_type']}] {str(info['source_text'])[:100]}")
125
+ output.extend(info['neighbors'])
126
+ output.append("\n\n") # Empty line for separation
127
+
128
+ formatted_choices = []
129
+ node_ids = []
130
+ for i, node in enumerate(graph_data['nodes']):
131
+ formatted_choices.append(f"{i+1}. {str(node['label'])}")
132
+ node_ids.append(node['id'])
133
+
134
+ return "\n".join(output), json.dumps(graph_data), graph_data, gr.update(choices=formatted_choices, value=[]), node_ids
135
+
136
+ def save_docs_to_excel(exported_docs : list[dict], exported_relationships : list[dict]):
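+ # Flattens each exported Neo4j node into one spreadsheet row (node properties + id + first label),
+ # appends one "[REL_TYPE] target_id" line per relationship, and writes the result to an Excel file.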
137
+ cleaned_docs = [dict(doc['doc'], **{'id': doc['id'], 'category': doc['category'][0], "relationships" : ""}) for doc in exported_docs]
138
+ for relationship in exported_relationships:
139
+ for doc in cleaned_docs:
140
+ if doc['id'] == relationship['from']:
141
+ doc["relationships"] += f"[{relationship['label']}] {relationship['to']}\n"
142
+
143
+ df = pd.DataFrame(cleaned_docs)
144
+ df.to_excel("doc_explorer/exported_docs/docs.xlsx")
145
+ return gr.update(value="doc_explorer/exported_docs/docs.xlsx", visible=True)
146
+
147
+ # JavaScript code for graph visualization
148
+ js_code = """
149
+ function(graph_data_str) {
150
+ if (!graph_data_str) return;
151
+ const container = document.getElementById('graph-container');
152
+ container.innerHTML = '';
153
+ let data;
154
+ try {
155
+ data = JSON.parse(graph_data_str);
156
+ } catch (error) {
157
+ console.error("Failed to parse graph data:", error);
158
+ container.innerHTML = "Error: Failed to load graph data.";
159
+ return;
160
+ }
161
+
162
+ data.nodes.forEach(node => {
163
+ const div = document.createElement('div');
164
+ div.innerHTML = node.title;
165
+ node.title = div.firstChild;
166
+ });
167
+
168
+ const nodes = new vis.DataSet(data.nodes);
169
+ const edges = new vis.DataSet(data.edges);
170
+ const options = {
171
+ nodes: {
172
+ shape: 'dot',
173
+ size: 16,
174
+ font: {
175
+ size: 12,
176
+ color: '#000000'
177
+ },
178
+ borderWidth: 2
179
+ },
180
+ edges: {
181
+ width: 1,
182
+ font: {
183
+ size: 10,
184
+ align: 'middle'
185
+ },
186
+ color: { color: '#7A7A7A', hover: '#2B7CE9' }
187
+ },
188
+ physics: {
189
+ forceAtlas2Based: {
190
+ gravitationalConstant: -26,
191
+ centralGravity: 0.005,
192
+ springLength: 230,
193
+ springConstant: 0.18
194
+ },
195
+ maxVelocity: 146,
196
+ solver: 'forceAtlas2Based',
197
+ timestep: 0.35,
198
+ stabilization: { iterations: 150 }
199
+ },
200
+ interaction: {
201
+ hover: true,
202
+ tooltipDelay: 200
203
+ }
204
+ };
205
+ const network = new vis.Network(container, { nodes: nodes, edges: edges }, options);
206
+ }
207
+ """
208
+
209
+ head = """
210
+ <script type="text/javascript" src="https://unpkg.com/vis-network/standalone/umd/vis-network.min.js"></script>
211
+ <link href="https://unpkg.com/vis-network/styles/vis-network.min.css" rel="stylesheet" type="text/css" />
212
+ """
213
+
214
+ custom_css = """
215
+ #graph-container {
216
+ border: 1px solid #ddd;
217
+ border-radius: 4px;
218
+ }
219
+ .vis-tooltip {
220
+ font-family: Arial, sans-serif;
221
+ padding: 10px;
222
+ border-radius: 4px;
223
+ background-color: rgba(255, 255, 255, 0.9);
224
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
225
+ max-width: 300px;
226
+ color: #333;
227
+ word-wrap: break-word;
228
+ overflow-wrap: break-word;
229
+ }
230
+ .node-tooltip {
231
+ width: 100%;
232
+ }
233
+ .node-tooltip h3 {
234
+ margin: 0 0 5px 0;
235
+ font-size: 14px;
236
+ color: #333;
237
+ }
238
+ .node-tooltip p {
239
+ margin: 0;
240
+ font-size: 12px;
241
+ color: #666;
242
+ white-space: normal;
243
+ }
244
+ """
245
+
246
+
247
+ with gr.Blocks(head=head, css=custom_css) as demo:
248
+ with gr.Tab("Search"):
249
+
250
+ gr.Markdown("# Document Search Engine")
251
+ gr.Markdown("Enter a query to search for similar documents. You can filter by document type and use MMR for diverse results.")
252
+
253
+ with gr.Row():
254
+ with gr.Column(scale=3):
255
+ query_input = gr.Textbox(label="Enter your query")
256
+ doc_type_input = gr.Dropdown(choices=get_document_types(), label="Select document type", multiselect=True)
257
+ with gr.Column(scale=2):
258
+ mmr_input = gr.Checkbox(label="Use MMR for diverse results")
259
+ lambda_input = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5, label="Lambda parameter (MMR diversity)", visible=False)
260
+ top_k_input = gr.Slider(minimum=1, maximum=20, step=1, value=5, label="Number of results")
261
+
262
+ search_button = gr.Button("Search")
263
+ results_output = gr.Textbox(label="Search Results")
264
+
265
+ selected_documents = gr.Dropdown(label="Select documents to view their neighbors", choices=[], multiselect=True, interactive=True)
266
+
267
+ with gr.Row():
268
+ neighbor_search_button = gr.Button("Find Neighbors")
269
+ send_to_export = gr.Button("Send docs to export Tab")
270
+
271
+ neighbors_output = gr.Textbox(label="Document Neighbors")
272
+
273
+ graph_data_state = gr.State({"nodes": [], "edges": []})
274
+ graph_data_str = gr.Textbox(visible=False)
275
+ graph_container = gr.HTML('<div id="graph-container" style="height: 600px;"> Hey ! </div>')
276
+
277
+ node_ids = gr.State([])
278
+ exported_docs = gr.State([])
279
+ exported_relationships = gr.State([])
280
+
281
+ def update_lambda_visibility(use_mmr):
282
+ return gr.update(visible=use_mmr)
283
+
284
+ mmr_input.change(fn=update_lambda_visibility, inputs=mmr_input, outputs=lambda_input)
285
+
286
+ search_button.click(
287
+ fn=search,
288
+ inputs=[query_input, doc_type_input, mmr_input, lambda_input, top_k_input],
289
+ outputs=[results_output, selected_documents, node_ids]
290
+ )
291
+
292
+ neighbor_search_button.click(
293
+ fn=get_neighbors_and_graph_data,
294
+ inputs=[selected_documents, node_ids, graph_data_state],
295
+ outputs=[neighbors_output, graph_data_str, graph_data_state, selected_documents, node_ids]
296
+ ).then(
297
+ fn=None,
298
+ inputs=graph_data_str,
299
+ outputs=None,
300
+ js=js_code,
301
+ )
302
+
303
+ send_to_export.click(
304
+ fn=get_docs_from_ids,
305
+ inputs=graph_data_state,
306
+ outputs=[exported_docs, exported_relationships]
307
+ )
308
+ # gr.Examples(
309
+ # examples=[
310
+ # ["What is machine learning?", "Article", True, 0.5, 5],
311
+ # ["How to implement a neural network?", "Tutorial", False, 0.5, 3],
312
+ # ["Latest advancements in NLP", "Research Paper", True, 0.7, 10]
313
+ # ],
314
+ # inputs=[query_input, doc_type_input, mmr_input, lambda_input, top_k_input]
315
+ # )
316
+ with gr.Tab("Export"):
317
+ with gr.Row():
318
+ exported_docs_btn = gr.Button("Display exported docs")
319
+ exported_excel_btn = gr.Button("Export to excel")
320
+ exported_excel = gr.File(visible=False)
321
+
322
+ exported_docs_display = gr.Markdown(visible=False)
323
+
324
+ exported_docs_btn.click(
325
+ fn= lambda docs: gr.update(value='\n\n'.join([f"[{doc['category'][0]}]\n{doc['doc']}\n\n" for doc in docs]), visible=True),
326
+ inputs=exported_docs,
327
+ outputs=exported_docs_display
328
+ )
329
+ exported_excel_btn.click(
330
+ fn=save_docs_to_excel,
331
+ inputs=[exported_docs, exported_relationships],
332
+ outputs=exported_excel
333
+ )
334
+
335
+ demo.launch()
doc_explorer/exported_docs/.gitkeep ADDED
@@ -0,0 +1 @@
1
+ a
doc_explorer/vectorstore.py ADDED
@@ -0,0 +1,151 @@
1
+ import faiss
2
+ import numpy as np
3
+ from sentence_transformers import SentenceTransformer
4
+ from typing import List, Optional, Tuple
5
+ from langchain_community.graphs import Neo4jGraph
6
+ import pickle
7
+
8
+ class FAISSVectorStore:
9
+ def __init__(self, model_name: str = None, dimension: int = 384, embedding_file: str = None, trust_remote_code = False):
10
+ self.model = SentenceTransformer(model_name, trust_remote_code=trust_remote_code) if model_name is not None else None
11
+ self.index = faiss.IndexFlatIP(dimension)
12
+ self.dimension = dimension
13
+ if embedding_file:
14
+ self.load_embeddings(embedding_file)
15
+
16
+ def load_embeddings(self, file_path: str):
17
+ if file_path.endswith('.pkl'):
18
+ with open(file_path, 'rb') as f:
19
+ embeddings = pickle.load(f)
20
+ elif file_path.endswith('.npy'):
21
+ embeddings = np.load(file_path)
22
+ else:
23
+ raise ValueError("Unsupported file format. Use .pkl or .npy")
24
+
25
+ self.add_embeddings(embeddings)
26
+
27
+ def add_embeddings(self, embeddings: np.ndarray):
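+ # L2-normalising the embeddings before adding them to an inner-product index (IndexFlatIP)
+ # makes the returned scores equivalent to cosine similarity.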
28
+ faiss.normalize_L2(embeddings)
29
+ self.index.add(embeddings)
30
+
31
+ def similarity_search(self, query: str, k: int = 5, use_mmr: bool = False, lambda_param: float = 0.5, doc_types: list[str] = None, neo4j_graph: Neo4jGraph = None):
32
+ query_vector = self.model.encode([query])
33
+ faiss.normalize_L2(query_vector)
34
+
35
+ if use_mmr:
36
+ return self._mmr_search(query_vector, k, lambda_param, neo4j_graph, doc_types)
37
+ else:
38
+ return self._simple_search(query_vector, k, neo4j_graph, doc_types)
39
+
40
+ def _simple_search(self, query_vector: np.ndarray, k: int, neo4j_graph: Neo4jGraph, doc_types : list[str] = None) -> List[dict]:
41
+ distances, indices = self.index.search(query_vector, k)
42
+
43
+ results = []
44
+ results_idx = []
45
+ for i, idx in enumerate(indices[0]):
46
+ document = self._get_text_by_index(neo4j_graph, idx, doc_types)
47
+ if document is not None:
48
+ results.append({
49
+ 'document': document,
50
+ 'score': distances[0][i]
51
+ })
52
+ results_idx.append(idx)
53
+
54
+ return results, results_idx
55
+
56
+ def _mmr_search(self, query_vector: np.ndarray, k: int, lambda_param: float, neo4j_graph: Neo4jGraph, doc_types: list[str] = None) -> Tuple[List[dict], List[int]]:
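+ # Maximal Marginal Relevance: fetch 2*k candidates, then greedily select k of them by maximising
+ # lambda * similarity(query, doc) - (1 - lambda) * max_similarity(doc, already_selected).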
57
+ initial_k = min(k * 2, self.index.ntotal)
58
+ distances, indices = self.index.search(query_vector, initial_k)
59
+
60
+ # Reconstruct embeddings for the initial results
61
+ initial_embeddings = self._reconstruct_embeddings(indices[0])
62
+
63
+ selected_indices = []
64
+ unselected_indices = list(range(len(indices[0])))
65
+
66
+ for _ in range(min(k, len(indices[0]))):
67
+ mmr_scores = []
68
+ for i in unselected_indices:
69
+ if not selected_indices:
70
+ mmr_scores.append((i, distances[0][i]))
71
+ else:
72
+ embedding_i = initial_embeddings[i]
73
+ redundancy = max(self._cosine_similarity(embedding_i, initial_embeddings[j]) for j in selected_indices)
74
+ mmr_scores.append((i, lambda_param * distances[0][i] - (1 - lambda_param) * redundancy))
75
+
76
+ selected_idx = max(mmr_scores, key=lambda x: x[1])[0]
77
+ selected_indices.append(selected_idx)
78
+ unselected_indices.remove(selected_idx)
79
+
80
+ results = []
81
+ results_idx = []
82
+ for idx in selected_indices:
83
+ document = self._get_text_by_index(neo4j_graph, indices[0][idx], doc_types)
84
+ if document is not None:
85
+ results.append({
86
+ 'document': document,
87
+ 'score': distances[0][idx]
88
+ })
89
+ results_idx.append(idx)
90
+
91
+ return results, results_idx
92
+
93
+ def _reconstruct_embeddings(self, indices: np.ndarray) -> np.ndarray:
94
+ return self.index.reconstruct_batch(indices)
95
+
96
+ def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
97
+ return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
98
+
99
+ def _get_text_by_index(self, neo4j_graph, index, doc_types):
100
+ if not doc_types:
101
+ query = f"""
102
+ MATCH (n)
103
+ WHERE n.id = $index
104
+ RETURN n AS document, labels(n) AS document_type, n.id AS node_id
105
+ """
106
+ result = neo4j_graph.query(query, {"index": index})
107
+ else:
108
+ for doc_type in doc_types:
109
+ query = f"""
110
+ MATCH (n:{doc_type})
111
+ WHERE n.id = $index
112
+ RETURN n AS document, labels(n) AS document_type, n.id AS node_id
113
+ """
114
+ result = neo4j_graph.query(query, {"index": index})
115
+ if result:
116
+ break
117
+
118
+ if result:
119
+ return f"[{result[0]['document_type'][0]}] {result[0]['document']}"
120
+ return None
121
+
122
+
123
+ # Example usage
124
+ if __name__ == "__main__":
125
+ # Initialize the vector store with embedding file
126
+ vector_store = FAISSVectorStore(dimension=384, embedding_file="path/to/your/embeddings.pkl") # or .npy file
127
+
128
+ # Initialize Neo4jGraph (replace with your actual Neo4j connection details)
129
+ neo4j_graph = Neo4jGraph(
130
+ url="bolt://localhost:7687",
131
+ username="neo4j",
132
+ password="password"
133
+ )
134
+
135
+ # Perform a similarity search with and without MMR
136
+ query = "How to start a long journey"
137
+ results_simple, _ = vector_store.similarity_search(query, k=5, use_mmr=False, neo4j_graph=neo4j_graph)
138
+ results_mmr, _ = vector_store.similarity_search(query, k=5, use_mmr=True, lambda_param=0.5, neo4j_graph=neo4j_graph)
139
+
140
+ # Print the results
141
+ print(f"Top 5 similar texts for query: '{query}' (without MMR)")
142
+ for i, result in enumerate(results_simple, 1):
143
+ print(f"{i}. Text: {result['document']}")
144
+ print(f" Score: {result['score']}")
145
+ print()
146
+
147
+ print(f"Top 5 similar texts for query: '{query}' (with MMR)")
148
+ for i, result in enumerate(results_mmr, 1):
149
+ print(f"{i}. Text: {result['document']}")
150
+ print(f" Score: {result['score']}")
151
+ print()
flagged/log.csv ADDED
@@ -0,0 +1,39 @@
1
+ output,flag,username,timestamp
2
+ "'
3
+ <style>
4
+ table {
5
+ border-collapse: collapse;
6
+ width: 100%;
7
+ }
8
+ th, td {
9
+ border: 1px solid black;
10
+ padding: 8px;
11
+ text-align: left;
12
+ vertical-align: top;
13
+ white-space: pre-wrap;
14
+ max-width: 300px;
15
+ max-height: 100px;
16
+ overflow-y: auto;
17
+ }
18
+ th {
19
+ background-color: #f2f2f2;
20
+ }
21
+ </style>
22
+ <table border=""1"" class=""dataframe"">
23
+ <thead>
24
+ <tr style=""text-align: right;"">
25
+ <th>Column1</th>
26
+ <th>Column2</th>
27
+ </tr>
28
+ </thead>
29
+ <tbody>
30
+ <tr>
31
+ <td>Line 1\nLine 2\nLine 3</td>
32
+ <td>Short text</td>
33
+ </tr>
34
+ <tr>
35
+ <td>Single line</td>
36
+ <td>Very long text that goes on and on and might need scrolling in the cell</td>
37
+ </tr>
38
+ </tbody>
39
+ </table>",,,2024-07-29 15:18:50.387842
images/flowchart_graphrag.png ADDED
images/flowchart_graphrag_dark.png ADDED
images/flowchart_graphrag_final.png ADDED
images/graph_png.png ADDED
ki_gen/data_processor.py ADDED
@@ -0,0 +1,183 @@
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ from langchain_openai import ChatOpenAI
5
+ from langchain_core.output_parsers import StrOutputParser
6
+ from langchain_core.prompts import ChatPromptTemplate
7
+ from langchain_groq import ChatGroq
8
+ from langgraph.graph import StateGraph
9
+ from llmlingua import PromptCompressor
10
+
11
+ from ki_gen.utils import ConfigSchema, DocProcessorState, get_model, format_doc
12
+
13
+
14
+
15
+
16
+ # compressed_prompt = llm_lingua.compress_prompt(prompt, instruction="", question="", target_token=200)
17
+
18
+ ## Or use the quantation model, like TheBloke/Llama-2-7b-Chat-GPTQ, only need <8GB GPU memory.
19
+ ## Before that, you need to pip install optimum auto-gptq
20
+ # llm_lingua = PromptCompressor("TheBloke/Llama-2-7b-Chat-GPTQ", model_config={"revision": "main"})
21
+
22
+
23
+
24
+ # Requires ~2GB of RAM
25
+ def get_llm_lingua(compress_method:str = "llm_lingua2"):
26
+
27
+ # Requires ~2GB memory
28
+ if compress_method == "llm_lingua2":
29
+ llm_lingua2 = PromptCompressor(
30
+ model_name="microsoft/llmlingua-2-xlm-roberta-large-meetingbank",
31
+ use_llmlingua2=True,
32
+ device_map="cpu"
33
+ )
34
+ return llm_lingua2
35
+
36
+ # Requires ~8GB memory
37
+ elif compress_method == "llm_lingua":
38
+ llm_lingua = PromptCompressor(
39
+ model_name="microsoft/phi-2",
40
+ device_map="cpu"
41
+ )
42
+ return llm_lingua
43
+ raise ValueError("Incorrect compression method, should be 'llm_lingua' or 'llm_lingua2'")
44
+
45
+
46
+
47
+ def compress(state: DocProcessorState, config: ConfigSchema):
48
+ """
49
+ This node compresses last processing result for each doc using llm_lingua
50
+ """
51
+ doc_process_histories = state["docs_in_processing"]
52
+ llm_lingua = get_llm_lingua(config["configurable"].get("compression_method") or "llm_lingua2")
53
+ for doc_process_history in doc_process_histories:
54
+ doc_process_history.append(llm_lingua.compress_prompt(
55
+ str(doc_process_history[-1]),
56
+ rate=config["configurable"].get("compress_rate") or 0.33,
57
+ force_tokens=config["configurable"].get("force_tokens") or ['\n', '?', '.', '!', ',']
58
+ )["compressed_prompt"]
59
+ )
60
+
61
+ return {"docs_in_processing": doc_process_histories, "current_process_step" : state["current_process_step"] + 1}
62
+
63
+ def summarize_docs(state: DocProcessorState, config: ConfigSchema):
64
+ """
65
+ This node summarizes all docs in state["valid_docs"]
66
+ """
67
+
68
+ prompt = """You are a 3GPP standardization expert.
69
+ Summarize the provided document in simple technical English for other experts in the field.
70
+
71
+ Document:
72
+ {document}"""
73
+ sysmsg = ChatPromptTemplate.from_messages([
74
+ ("system", prompt)
75
+ ])
76
+ model = config["configurable"].get("summarize_model") or "mixtral-8x7b-32768"
77
+ doc_process_histories = state["docs_in_processing"]
78
+ if model == "gpt-4o":
79
+ llm_summarize = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
80
+ else:
81
+ llm_summarize = ChatGroq(model=model)
82
+ summarize_chain = sysmsg | llm_summarize | StrOutputParser()
83
+
84
+ for doc_process_history in doc_process_histories:
85
+ doc_process_history.append(summarize_chain.invoke({"document" : str(doc_process_history[-1])}))
86
+
87
+ return {"docs_in_processing": doc_process_histories, "current_process_step": state["current_process_step"] + 1}
88
+
89
+ def custom_process(state: DocProcessorState):
90
+ """
91
+ Custom processing step, params are stored in a dict in state["process_steps"][state["current_process_step"]]
92
+ processing_model : the LLM which will perform the processing
93
+ context : the previous processing results to send as context to the LLM
94
+ prompt : the prompt/task which will be appended to the context before sending to the LLM
95
+ """
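+ # Illustrative step dict (as built in app.py): {"prompt": "...", "context": [0, 1], "processing_model": "gpt-4o"}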
96
+
97
+ processing_params = state["process_steps"][state["current_process_step"]]
98
+ model = processing_params.get("processing_model") or "mixtral-8x7b-32768"
99
+ user_prompt = processing_params["prompt"]
100
+ context = processing_params.get("context") or [0]
101
+ doc_process_histories = state["docs_in_processing"]
102
+ if not isinstance(context, list):
103
+ context = [context]
104
+
105
+ processing_chain = get_model(model=model) | StrOutputParser()
106
+
107
+ for doc_process_history in doc_process_histories:
108
+ context_str = ""
109
+ for i, context_element in enumerate(context):
110
+ context_str += f"### TECHNICAL INFORMATION {i+1} \n {doc_process_history[context_element]}\n\n"
111
+ doc_process_history.append(processing_chain.invoke(context_str + user_prompt))
112
+
113
+ return {"docs_in_processing" : doc_process_histories, "current_process_step" : state["current_process_step"] + 1}
114
+
115
+ def final(state: DocProcessorState):
116
+ """
117
+ A node to store the final results of processing in the 'valid_docs' field
118
+ """
119
+ return {"valid_docs" : [doc_process_history[-1] for doc_process_history in state["docs_in_processing"]]}
120
+
121
+ # TODO : remove this node and use conditional entry point instead
122
+ def get_process_steps(state: DocProcessorState, config: ConfigSchema):
123
+ """
124
+ Dummy node
125
+ """
126
+ # if not process_steps:
127
+ # process_steps = eval(input("Enter processing steps: "))
128
+ return {"current_process_step": 0, "docs_in_processing" : [[format_doc(doc)] for doc in state["valid_docs"]]}
129
+
130
+
131
+ def next_processor_step(state: DocProcessorState):
132
+ """
133
+ Conditional edge function to go to next processing step
134
+ """
135
+ process_steps = state["process_steps"]
136
+ if state["current_process_step"] < len(process_steps):
137
+ step = process_steps[state["current_process_step"]]
138
+ if isinstance(step, dict):
139
+ step = "custom"
140
+ else:
141
+ step = "final"
142
+
143
+ return step
144
+
145
+
146
+ def build_data_processor_graph(memory):
147
+ """
148
+ Builds the data processor graph
149
+ """
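+ # Flow: __start__ -> get_process_steps -> (summarize | compress | custom)* -> final -> __end__,
+ # with the next node chosen at each step by next_processor_step from state["process_steps"].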
150
+
151
+ graph_builder_doc_processor = StateGraph(DocProcessorState)
152
+
153
+ graph_builder_doc_processor.add_node("get_process_steps", get_process_steps)
154
+ graph_builder_doc_processor.add_node("summarize", summarize_docs)
155
+ graph_builder_doc_processor.add_node("compress", compress)
156
+ graph_builder_doc_processor.add_node("custom", custom_process)
157
+ graph_builder_doc_processor.add_node("final", final)
158
+
159
+ graph_builder_doc_processor.add_edge("__start__", "get_process_steps")
160
+ graph_builder_doc_processor.add_conditional_edges(
161
+ "get_process_steps",
162
+ next_processor_step,
163
+ {"compress" : "compress", "final": "final", "summarize": "summarize", "custom" : "custom"}
164
+ )
165
+ graph_builder_doc_processor.add_conditional_edges(
166
+ "summarize",
167
+ next_processor_step,
168
+ {"compress" : "compress", "final": "final", "custom" : "custom"}
169
+ )
170
+ graph_builder_doc_processor.add_conditional_edges(
171
+ "compress",
172
+ next_processor_step,
173
+ {"summarize" : "summarize", "final": "final", "custom" : "custom"}
174
+ )
175
+ graph_builder_doc_processor.add_conditional_edges(
176
+ "custom",
177
+ next_processor_step,
178
+ {"summarize" : "summarize", "final": "final", "compress" : "compress", "custom" : "custom"}
179
+ )
180
+ graph_builder_doc_processor.add_edge("final", "__end__")
181
+
182
+ graph_doc_processor = graph_builder_doc_processor.compile(checkpointer=memory)
183
+ return graph_doc_processor
ki_gen/data_retriever.py ADDED
@@ -0,0 +1,351 @@
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ import re
5
+ from random import shuffle, sample
6
+
7
+ from langchain_groq import ChatGroq
8
+ from langchain_openai import ChatOpenAI
9
+ from langchain_core.messages import HumanMessage
10
+ from langchain_community.graphs import Neo4jGraph
11
+ from langchain_community.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema
12
+ from langchain_core.output_parsers import StrOutputParser
13
+ from langchain_core.prompts import ChatPromptTemplate
14
+ from langchain_core.pydantic_v1 import BaseModel, Field
15
+ from langchain_groq import ChatGroq
16
+
17
+ from langgraph.graph import StateGraph
18
+
19
+ from llmlingua import PromptCompressor
20
+
21
+ from ki_gen.prompts import (
22
+ CYPHER_GENERATION_PROMPT,
23
+ CONCEPT_SELECTION_PROMPT,
24
+ BINARY_GRADER_PROMPT,
25
+ SCORE_GRADER_PROMPT,
26
+ RELEVANT_CONCEPTS_PROMPT,
27
+ )
28
+ from ki_gen.utils import ConfigSchema, DocRetrieverState, get_model, format_doc
29
+
30
+
31
+
32
+
33
+ def extract_cypher(text: str) -> str:
34
+ """Extract Cypher code from a text.
35
+
36
+ Args:
37
+ text: Text to extract Cypher code from.
38
+
39
+ Returns:
40
+ A list of candidate Cypher snippets extracted from the text (falling back to the raw text).
41
+ """
42
+ # The pattern to find Cypher code enclosed in triple backticks
43
+ pattern_1 = r"```cypher\n(.*?)```"
44
+ pattern_2 = r"```\n(.*?)```"
45
+
46
+ # Find all matches in the input text
47
+ matches_1 = re.findall(pattern_1, text, re.DOTALL)
48
+ matches_2 = re.findall(pattern_2, text, re.DOTALL)
49
+ return [
50
+ matches_1[0] if matches_1 else text,
51
+ matches_2[0] if matches_2 else text,
52
+ text
53
+ ]
54
+
55
+ def get_cypher_gen_chain(model: str = "openai"):
56
+ """
57
+ Returns cypher gen chain using specified model for generation
58
+ This is used when the 'auto' cypher generation method has been configured
59
+ """
60
+
61
+ if model=="openai":
62
+ llm_cypher_gen = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
63
+ else:
64
+ llm_cypher_gen = ChatGroq(model = "mixtral-8x7b-32768")
65
+ cypher_gen_chain = CYPHER_GENERATION_PROMPT | llm_cypher_gen | StrOutputParser() | extract_cypher
66
+ return cypher_gen_chain
67
+
68
+ def get_concept_selection_chain(model: str = "openai"):
69
+ """
70
+ Returns a chain to select the most relevant topic using specified model for generation.
71
+ This is used when the 'guided' cypher generation method has been configured
72
+ """
73
+
74
+ if model == "openai":
75
+ llm_topic_selection = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
76
+ else:
77
+ llm_topic_selection = ChatGroq(model="llama3-70b-8192", max_tokens=8192)
78
+ print(f"FOUND LLM TOPIC SELECTION FOR THE CONCEPT SELECTION PROMPT : {llm_topic_selection}")
79
+ topic_selection_chain = CONCEPT_SELECTION_PROMPT | llm_topic_selection | StrOutputParser()
80
+ return topic_selection_chain
81
+
82
+ def get_concepts(graph: Neo4jGraph):
83
+ concept_cypher = "MATCH (c:Concept) return c"
84
+ if isinstance(graph, Neo4jGraph):
85
+ concepts = graph.query(concept_cypher)
86
+ else:
87
+ user_input = input("Topics : ")
88
+ concepts = eval(user_input)
89
+
90
+ concepts_name = [concept['c']['name'] for concept in concepts]
91
+ return concepts_name
92
+
93
+ def get_related_concepts(graph: Neo4jGraph, question: str):
94
+ concepts = get_concepts(graph)
95
+ llm = get_model(model='gpt-4o')
96
+ print(f"this is the llm variable : {llm}")
97
+ def parse_answer(llm_answer : str):
98
+ print(f"This the llm_answer : {llm_answer}")
99
+ return re.split(r"\n(?:\d)+\.\s", llm_answer.split("Concepts:")[1])[1:]
100
+ related_concepts_chain = RELEVANT_CONCEPTS_PROMPT | llm | StrOutputParser() | parse_answer
101
+
102
+ related_concepts_raw = related_concepts_chain.invoke({"user_query" : question, "concepts" : '\n'.join(concepts)})
103
+
104
+ # We clean up the list we received from the LLM in case there were some hallucinations
105
+ related_concepts_cleaned = []
106
+ for related_concept in related_concepts_raw:
107
+ # If the concept returned from the LLM is in the list we keep it
108
+ if related_concept in concepts:
109
+ related_concepts_cleaned.append(related_concept)
110
+ else:
111
+ # The LLM sometimes only forgets a few words from the concept name
112
+ # We check if the generated concept is a substring of an existing one and if it is the case add it to the list
113
+ for concept in concepts:
114
+ if related_concept in concept:
115
+ related_concepts_cleaned.append(concept)
116
+ break
117
+
118
+ # TODO : Add concepts found via similarity search
119
+ return related_concepts_cleaned
120
+
121
+ def build_concept_string(graph: Neo4jGraph, concept_list: list[str]):
122
+ concept_string = ""
123
+ for concept in concept_list:
124
+ concept_description_query = f"""
125
+ MATCH (c:Concept {{name: "{concept}" }}) RETURN c.description
126
+ """
127
+ concept_description = graph.query(concept_description_query)[0]['c.description']
128
+ concept_string += f"name: {concept}\ndescription: {concept_description}\n\n"
129
+ return concept_string
130
+
131
+ def get_global_concepts(graph: Neo4jGraph):
132
+ concept_cypher = "MATCH (gc:GlobalConcept) return gc"
133
+ if isinstance(graph, Neo4jGraph):
134
+ concepts = graph.query(concept_cypher)
135
+ else:
136
+ user_input = input("Topics : ")
137
+ concepts = eval(user_input)
138
+
139
+ concepts_name = [concept['gc']['name'] for concept in concepts]
140
+ return concepts_name
141
+
142
+ def generate_cypher(state: DocRetrieverState, config: ConfigSchema):
143
+ """
144
+ The node where the cypher is generated
145
+ """
146
+
147
+ graph = config["configurable"].get("graph")
148
+ question = state['query']
149
+ related_concepts = get_related_concepts(graph, question)
150
+ cyphers = []
151
+
152
+ if config["configurable"].get("cypher_gen_method") == 'auto':
153
+ cypher_gen_chain = get_cypher_gen_chain()
154
+ cyphers = cypher_gen_chain.invoke({
155
+ "schema": graph.schema,
156
+ "question": question,
157
+ "concepts": related_concepts
158
+ })
159
+
160
+
161
+ if config["configurable"].get("cypher_gen_method") == 'guided':
162
+ concept_selection_chain = get_concept_selection_chain()
163
+ print(f"Concept selection chain is : {concept_selection_chain}")
164
+ selected_topic = concept_selection_chain.invoke({"question" : question, "concepts": get_concepts(graph)})
165
+ print(f"Selected topic are : {selected_topic}")
166
+ cyphers = [generate_cypher_from_topic(selected_topic, state['current_plan_step'])]
167
+ print(f"Cyphers are : {cyphers}")
168
+
169
+
170
+ if config["configurable"].get("validate_cypher"):
171
+ corrector_schema = [Schema(el["start"], el["type"], el["end"]) for el in graph.structured_schema.get("relationships")]
172
+ cypher_corrector = CypherQueryCorrector(corrector_schema)
173
+ cyphers = [cypher_corrector(cypher) for cypher in cyphers]
174
+
175
+ return {"cyphers" : cyphers}
176
+
177
+ def generate_cypher_from_topic(selected_concept: str, plan_step: int):
178
+ """
179
+ Helper function used when the 'guided' cypher generation method has been configured
180
+ """
181
+
182
+ print(f"L.176 PLAN STEP : {plan_step}")
183
+ cypher_el = "(n) return n.title, n.description"
184
+ match plan_step:
185
+ case 0:
186
+ cypher_el = "(ts:TechnicalSpecification) RETURN ts.title, ts.scope, ts.description"
187
+ case 1:
188
+ cypher_el = "(rp:ResearchPaper) RETURN rp.title, rp.abstract"
189
+ case 2:
190
+ cypher_el = "(ki:KeyIssue) RETURN ki.description"
191
+ return f"MATCH (c:Concept {{name:'{selected_concept}'}})-[:RELATED_TO]-{cypher_el}"
192
+
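As a quick worked example (the concept name is made up), generate_cypher_from_topic("Network Slicing", 1) returns the following query, since plan step 1 targets research papers:

# MATCH (c:Concept {name:'Network Slicing'})-[:RELATED_TO]-(rp:ResearchPaper) RETURN rp.title, rp.abstract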
193
+ def get_docs(state:DocRetrieverState, config:ConfigSchema):
194
+ """
195
+ This node retrieves docs from the graph using the generated cypher
196
+ """
197
+ graph = config["configurable"].get("graph")
198
+ output = []
199
+ if graph is not None:
200
+ for cypher in state["cyphers"]:
201
+ try:
202
+ output = graph.query(cypher)
203
+ break
204
+ except Exception as e:
205
+ print("Failed to retrieve docs : {e}")
206
+
207
+ # Clean up the docs we received as there may be duplicates depending on the cypher query
208
+ all_docs = []
209
+ for doc in output:
210
+ unwinded_doc = {}
211
+ for key in doc:
212
+ if isinstance(doc[key], dict):
213
+ all_docs.append(doc[key])
214
+ else:
215
+ unwinded_doc.update({key: doc[key]})
216
+ if unwinded_doc:
217
+ all_docs.append(unwinded_doc)
218
+
219
+
220
+ filtered_docs = []
221
+ for doc in all_docs:
222
+ if doc not in filtered_docs:
223
+ filtered_docs.append(doc)
224
+
225
+ return {"docs": filtered_docs}
226
+
227
+
228
+
229
+
230
+
231
+ # Data model
232
+ class GradeDocumentsBinary(BaseModel):
233
+ """Binary score for relevance check on retrieved documents."""
234
+
235
+ binary_score: str = Field(
236
+ description="Documents are relevant to the question, 'yes' or 'no'"
237
+ )
238
+
239
+ # LLM with function call
240
+ # llm_grader_binary = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
241
+
242
+ def get_binary_grader(model="mixtral-8x7b-32768"):
243
+ """
244
+ Returns a binary grader to evaluate the relevance of documents, using the specified model for generation
245
+ This is used when the 'binary' evaluation method has been configured
246
+ """
247
+
248
+
249
+ if model == "gpt-4o":
250
+ llm_grader_binary = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/", temperature=0)
251
+ else:
252
+ llm_grader_binary = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
253
+ structured_llm_grader_binary = llm_grader_binary.with_structured_output(GradeDocumentsBinary)
254
+ retrieval_grader_binary = BINARY_GRADER_PROMPT | structured_llm_grader_binary
255
+ return retrieval_grader_binary
256
+
257
+
258
+ class GradeDocumentsScore(BaseModel):
259
+ """Score for relevance check on retrieved documents."""
260
+
261
+ score: float = Field(
262
+ description="Documents are relevant to the question, score between 0 (completely irrelevant) and 1 (perfectly relevant)"
263
+ )
264
+
265
+ def get_score_grader(model="mixtral-8x7b-32768"):
266
+ """
267
+ Returns a score grader to evaluate the relevance of documents, using the specified model for generation
268
+ This is used when the 'score' evaluation method has been configured
269
+ """
270
+ if model == "gpt-4o":
271
+ llm_grader_score = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/", temperature=0)
272
+ else:
273
+ llm_grader_score = ChatGroq(model="mixtral-8x7b-32768", temperature = 0)
274
+ structured_llm_grader_score = llm_grader_score.with_structured_output(GradeDocumentsScore)
275
+ retrieval_grader_score = SCORE_GRADER_PROMPT | structured_llm_grader_score
276
+ return retrieval_grader_score
277
+
278
+
279
+ def eval_doc(doc, query, method="binary", threshold=0.7, eval_model="mixtral-8x7b-32768"):
280
+ '''
281
+ doc : the document to evaluate
282
+ query : the query to which the doc should be relevant
283
+ method : "binary" or "score"
284
+ threshold : for "score" method, score above which a doc is considered relevant
285
+ '''
286
+ if method == "binary":
287
+ retrieval_grader_binary = get_binary_grader(model=eval_model)
288
+ return 1 if (retrieval_grader_binary.invoke({"question": query, "document":doc}).binary_score == 'yes') else 0
289
+ elif method == "score":
290
+ retrieval_grader_score = get_score_grader(model=eval_model)
291
+ score = retrieval_grader_score.invoke({"query": query, "document":doc}).score or None
292
+ if score is not None:
293
+ return score if score >= threshold else 0
294
+ else:
295
+ # Couldn't parse score, marking document as relevant by default
296
+ return 1
297
+ else:
298
+ raise ValueError("Invalid method")
299
+
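A minimal usage sketch of eval_doc; the document text and query are invented:

relevance = eval_doc(
    doc="Title: Study on positioning enhancements ...",
    query="key issues for 5G positioning accuracy",
    method="score",
    threshold=0.7,
)
# With method="score" the call returns the raw score when it is >= threshold and 0 otherwise;
# with method="binary" it returns 1 for a 'yes' grade and 0 for a 'no' grade.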
300
+ def eval_docs(state: DocRetrieverState, config: ConfigSchema):
301
+ """
302
+ This node evaluates the retrieved docs and keeps only those deemed relevant to the query
303
+ """
304
+
305
+ eval_method = config["configurable"].get("eval_method") or "binary"
306
+ MAX_DOCS = config["configurable"].get("max_docs") or 15
307
+ valid_doc_scores = []
308
+
309
+ for doc in sample(state["docs"], min(25, len(state["docs"]))):
310
+ score = eval_doc(
311
+ doc=format_doc(doc),
312
+ query=state["query"],
313
+ method=eval_method,
314
+ threshold=config["configurable"].get("eval_threshold") or 0.7,
315
+ eval_model = config["configurable"].get("eval_model") or "mixtral-8x7b-32768"
316
+ )
317
+ if score:
318
+ valid_doc_scores.append((doc, score))
319
+
320
+ if eval_method == 'score':
321
+ # Get at most MAX_DOCS items with the highest score if score method was used
322
+ valid_docs = sorted(valid_doc_scores, key=lambda x: x[1], reverse=True)
323
+ valid_docs = [valid_doc[0] for valid_doc in valid_docs[:MAX_DOCS]]
324
+ else:
325
+ # Get at most MAX_DOCS items at random if binary method was used
326
+ shuffle(valid_doc_scores)
327
+ valid_docs = [valid_doc[0] for valid_doc in valid_doc_scores[:MAX_DOCS]]
328
+
329
+ return {"valid_docs": valid_docs + (state["valid_docs"] or [])}
330
+
331
+
332
+
333
+ def build_data_retriever_graph(memory):
334
+ """
335
+ Builds the data_retriever graph
336
+ """
337
+ graph_builder_doc_retriever = StateGraph(DocRetrieverState)
338
+
339
+ graph_builder_doc_retriever.add_node("generate_cypher", generate_cypher)
340
+ graph_builder_doc_retriever.add_node("get_docs", get_docs)
341
+ graph_builder_doc_retriever.add_node("eval_docs", eval_docs)
342
+
343
+
344
+ graph_builder_doc_retriever.add_edge("__start__", "generate_cypher")
345
+ graph_builder_doc_retriever.add_edge("generate_cypher", "get_docs")
346
+ graph_builder_doc_retriever.add_edge("get_docs", "eval_docs")
347
+ graph_builder_doc_retriever.add_edge("eval_docs", "__end__")
348
+
349
+ graph_doc_retriever = graph_builder_doc_retriever.compile(checkpointer=memory)
350
+
351
+ return graph_doc_retriever
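A hedged sketch of how the compiled retriever subgraph could be exercised on its own; the checkpointer, graph object, query and config values below are assumptions, and a populated Neo4j instance plus the relevant API keys would be required:

from langgraph.checkpoint.memory import MemorySaver

retriever = build_data_retriever_graph(MemorySaver())
result = retriever.invoke(
    {"query": "network slicing isolation", "valid_docs": [], "current_plan_step": 0},
    config={"configurable": {"graph": my_neo4j_graph, "cypher_gen_method": "guided", "thread_id": "1"}},
)
print(result["valid_docs"])  # documents that passed the relevance grading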
ki_gen/planner.py ADDED
@@ -0,0 +1,253 @@
1
+ import os
2
+ import re
3
+
4
+ from typing import Annotated
5
+ from typing_extensions import TypedDict
6
+
7
+ from langchain_groq import ChatGroq
8
+ from langchain_openai import ChatOpenAI
9
+ from langchain_core.messages import SystemMessage, HumanMessage
10
+ from langchain_community.graphs import Neo4jGraph
11
+
12
+ from langgraph.graph import StateGraph
13
+ from langgraph.graph import add_messages
14
+
15
+ from ki_gen.prompts import PLAN_GEN_PROMPT, PLAN_MODIFICATION_PROMPT
16
+ from ki_gen.data_retriever import build_data_retriever_graph
17
+ from ki_gen.data_processor import build_data_processor_graph
18
+ from ki_gen.utils import ConfigSchema, State, HumanValidationState, DocProcessorState, DocRetrieverState
19
+
20
+
21
+ ##########################################################################
22
+ ###### NODES DEFINITION ######
23
+ ##########################################################################
24
+
25
+ def validate_node(state: State):
26
+ """
27
+ This node inserts the plan validation prompt.
28
+ """
29
+ prompt = """System : You only need to focus on Key Issues, no need to focus on solutions or stakeholders yet and your plan should be concise.
30
+ If needed, give me an updated plan to follow this instruction. If your plan already follows the instruction just say "My plan is correct"."""
31
+ output = HumanMessage(content=prompt)
32
+ return {"messages" : [output]}
33
+
34
+
35
+ # Wrappers to call LLMs on the state messages field
36
+ def chatbot_llama(state: State):
37
+ llm_llama = ChatGroq(model="llama3-70b-8192")
38
+ return {"messages" : [llm_llama.invoke(state["messages"])]}
39
+
40
+ def chatbot_mixtral(state: State):
41
+ llm_mixtral = ChatGroq(model="mixtral-8x7b-32768")
42
+ return {"messages" : [llm_mixtral.invoke(state["messages"])]}
43
+
44
+ def chatbot_openai(state: State):
45
+ llm_openai = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
46
+ return {"messages" : [llm_openai.invoke(state["messages"])]}
47
+
48
+ chatbots = {"gpt-4o" : chatbot_openai,
49
+ "mixtral-8x7b-32768" : chatbot_mixtral,
50
+ "llama3-70b-8192" : chatbot_llama
51
+ }
52
+
53
+
54
+ def parse_plan(state: State):
55
+ """
56
+ This node parses the generated plan and writes in the 'store_plan' field of the state
57
+ """
58
+ plan = state["messages"][-3].content
59
+ store_plan = re.split(r"\d\.", plan.split("Plan:\n")[1])[1:]
60
+ try:
61
+ store_plan[len(store_plan) - 1] = store_plan[len(store_plan) - 1].split("<END_OF_PLAN>")[0]
62
+ except Exception as e:
63
+ print(f"Error while removing <END_OF_PLAN> : {e}")
64
+
65
+ return {"store_plan" : store_plan}
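A hedged illustration of the message content parse_plan expects; the plan text is invented:

raw_plan = "Plan:\n1. Review the relevant specifications.\n2. Gather recent research papers.\n3. Identify new key issues.\n<END_OF_PLAN>"
# Splitting on "Plan:\n" and then on the step numbers yields one string per step,
# and the trailing <END_OF_PLAN> marker is stripped from the last step.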
66
+
67
+ def detail_step(state: State, config: ConfigSchema):
68
+ """
69
+ This node updates the value of the 'current_plan_step' field and defines the query to be used for the data_retriever.
70
+ """
71
+ print("test")
72
+ print(state)
73
+
74
+ if 'current_plan_step' in state.keys():
75
+ print("all good chief")
76
+ else:
77
+ state["current_plan_step"] = None
78
+
79
+ current_plan_step = state["current_plan_step"] + 1 if state["current_plan_step"] is not None else 0 # We just began a new step so we will increase current_plan_step at the end
80
+ if config["configurable"].get("use_detailed_query"):
81
+ prompt = HumanMessage(f"""Specify what additional information you need to proceed with the next step of your plan :
82
+ Step {current_plan_step + 1} : {state['store_plan'][current_plan_step]}""")
83
+ query = get_detailed_query(context = state["messages"] + [prompt], model=config["configurable"].get("main_llm"))
84
+ return {"messages" : [prompt, query], "current_plan_step": current_plan_step, 'query' : query}
85
+
86
+ return {"current_plan_step": current_plan_step, 'query' : state["store_plan"][current_plan_step], "valid_docs" : []}
87
+
88
+ def get_detailed_query(context : list, model : str = "mixtral-8x7b-32768"):
89
+ """
90
+ Simple helper function for the detail_step node
91
+ """
92
+ if model == 'gpt-4o':
93
+ llm = ChatOpenAI(model=model, base_url="https://llm.synapse.thalescloud.io/")
94
+ else:
95
+ llm = ChatGroq(model=model)
96
+ return llm.invoke(context)
97
+
98
+ def concatenate_data(state: State):
99
+ """
100
+ This node concatenates all the data that was processed by the data_processor and inserts it in the state's messages
101
+ """
102
+ prompt = f"""#########TECHNICAL INFORMATION ############
103
+ {str(state["valid_docs"])}
104
+
105
+ ########END OF TECHNICAL INFORMATION#######
106
+
107
+ Using the information provided above, proceed with step {state['current_plan_step'] + 1} of your plan :
108
+ {state['store_plan'][state['current_plan_step']]}
109
+ """
110
+
111
+ return {"messages": [HumanMessage(content=prompt)]}
112
+
113
+
114
+ def human_validation(state: HumanValidationState) -> HumanValidationState:
115
+ """
116
+ Dummy node used as an interruption point so the user can validate the retrieved documents
117
+ """
118
+ return {'process_steps' : []}
119
+
120
+ def generate_ki(state: State):
121
+ """
122
+ This node inserts the prompt to begin Key Issues generation
123
+ """
124
+ print(f"THIS IS THE STATE FOR CURRENT PLAN STEP IN GENERATE_KI : {state}")
125
+
126
+ prompt = f"""Using the information provided above, proceed with step 4 of your plan to provide the user with NEW and INNOVATIVE Key Issues :
127
+ {state['store_plan'][state['current_plan_step'] + 1]}"""
128
+
129
+ return {"messages" : [HumanMessage(content=prompt)]}
130
+
131
+ def detail_ki(state: State):
132
+ """
133
+ This node inserts the last prompt to detail the generated Key Issues
134
+ """
135
+ prompt = f"""Using the information provided above, proceed with step 5 of your plan to provide the user with NEW and INNOVATIVE Key Issues :
136
+ {state['store_plan'][state['current_plan_step'] + 2]}"""
137
+
138
+ return {"messages" : [HumanMessage(content=prompt)]}
139
+
140
+ ##########################################################################
141
+ ###### CONDITIONAL EDGE FUNCTIONS ######
142
+ ##########################################################################
143
+
144
+ def validate_plan(state: State):
145
+ """
146
+ Whether to regenerate the plan or to parse it
147
+ """
148
+ if "messages" in state and state["messages"][-1].content in ["My plan is correct.","My plan is correct"]:
149
+ return "parse"
150
+ return "validate"
151
+
152
+ def next_plan_step(state: State, config: ConfigSchema):
153
+ """
154
+ Proceed to next plan step (either generate KI or retrieve more data)
155
+ """
156
+ if (state["current_plan_step"] == 2) and (config["configurable"].get('plan_method') == "modification"):
157
+ return "generate_key_issues"
158
+ if state["current_plan_step"] == len(state["store_plan"]) - 1:
159
+ return "generate_key_issues"
160
+ else:
161
+ return "detail_step"
162
+
163
+ def detail_or_data_retriever(state: State, config: ConfigSchema):
164
+ """
165
+ Detail the query to use for data retrieval or not
166
+ """
167
+ if config["configurable"].get("use_detailed_query"):
168
+ return "chatbot_detail"
169
+ else:
170
+ return "data_retriever"
171
+
172
+ def retrieve_or_process(state: State):
173
+ """
174
+ Process the retrieved docs or keep retrieving
175
+ """
176
+ if state['human_validated']:
177
+ return "process"
178
+ return "retrieve"
179
+ # while True:
180
+ # user_input = input(f"{len(state['valid_docs'])} were retreived. Do you want more documents (y/[n]) : ")
181
+ # if user_input.lower() == "y":
182
+ # return "retrieve"
183
+ # if not user_input or user_input.lower() == "n":
184
+ # return "process"
185
+ # print("Please answer with 'y' or 'n'.\n")
186
+
187
+
188
+ def build_planner_graph(memory, config):
189
+ """
190
+ Builds the planner graph
191
+ """
192
+ graph_builder = StateGraph(State)
193
+
194
+ graph_doc_retriever = build_data_retriever_graph(memory)
195
+ graph_doc_processor = build_data_processor_graph(memory)
196
+ graph_builder.add_node("chatbot_planner", chatbots[config["main_llm"]])
197
+ graph_builder.add_node("validate", validate_node)
198
+ graph_builder.add_node("chatbot_detail", chatbot_llama)
199
+ graph_builder.add_node("parse", parse_plan)
200
+ graph_builder.add_node("detail_step", detail_step)
201
+ graph_builder.add_node("data_retriever", graph_doc_retriever, input=DocRetrieverState)
202
+ graph_builder.add_node("human_validation", human_validation)
203
+ graph_builder.add_node("data_processor", graph_doc_processor, input=DocProcessorState)
204
+ graph_builder.add_node("concatenate_data", concatenate_data)
205
+ graph_builder.add_node("chatbot_exec_step", chatbots[config["main_llm"]])
206
+ graph_builder.add_node("generate_ki", generate_ki)
207
+ graph_builder.add_node("chatbot_ki", chatbots[config["main_llm"]])
208
+ graph_builder.add_node("detail_ki", detail_ki)
209
+ graph_builder.add_node("chatbot_final", chatbots[config["main_llm"]])
210
+
211
+ graph_builder.add_edge("validate", "chatbot_planner")
212
+ graph_builder.add_edge("parse", "detail_step")
213
+
214
+
215
+ # graph_builder.add_edge("detail_step", "chatbot2")
216
+ graph_builder.add_edge("chatbot_detail", "data_retriever")
217
+ graph_builder.add_edge("data_retriever", "human_validation")
218
+
219
+
220
+ graph_builder.add_edge("data_processor", "concatenate_data")
221
+ graph_builder.add_edge("concatenate_data", "chatbot_exec_step")
222
+ graph_builder.add_edge("generate_ki", "chatbot_ki")
223
+ graph_builder.add_edge("chatbot_ki", "detail_ki")
224
+ graph_builder.add_edge("detail_ki", "chatbot_final")
225
+ graph_builder.add_edge("chatbot_final", "__end__")
226
+
227
+ graph_builder.add_conditional_edges(
228
+ "detail_step",
229
+ detail_or_data_retriever,
230
+ {"chatbot_detail": "chatbot_detail", "data_retriever": "data_retriever"}
231
+ )
232
+ graph_builder.add_conditional_edges(
233
+ "human_validation",
234
+ retrieve_or_process,
235
+ {"retrieve" : "data_retriever", "process" : "data_processor"}
236
+ )
237
+ graph_builder.add_conditional_edges(
238
+ "chatbot_planner",
239
+ validate_plan,
240
+ {"parse" : "parse", "validate": "validate"}
241
+ )
242
+ graph_builder.add_conditional_edges(
243
+ "chatbot_exec_step",
244
+ next_plan_step,
245
+ {"generate_key_issues" : "generate_ki", "detail_step": "detail_step"}
246
+ )
247
+
248
+ graph_builder.set_entry_point("chatbot_planner")
249
+ graph = graph_builder.compile(
250
+ checkpointer=memory,
251
+ interrupt_after=["parse", "chatbot_exec_step", "chatbot_final", "data_retriever"],
252
+ )
253
+ return graph
ki_gen/prompts.py ADDED
@@ -0,0 +1,155 @@
1
+ from langchain_core.prompts.prompt import PromptTemplate
2
+ from langchain_core.prompts import ChatPromptTemplate
3
+ from langchain_core.messages import SystemMessage, HumanMessage
4
+ from ki_gen.utils import ConfigSchema
5
+
6
+ CYPHER_GENERATION_TEMPLATE = """Task:Generate Cypher statement to query a graph database.
7
+ Instructions:
8
+ Use only the provided relationship types and properties in the schema.
9
+ Do not use any other relationship types or properties that are not provided.
10
+ Schema:
11
+ {schema}
12
+
13
+
14
+ Concepts:
15
+ {concepts}
16
+
17
+
18
+ Concept names can ONLY be selected from the above list
19
+
20
+ Note: Do not include any explanations or apologies in your responses.
21
+ Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.
22
+ Do not include any text except the generated Cypher statement.
23
+
24
+ The question is:
25
+ {question}"""
26
+ CYPHER_GENERATION_PROMPT = PromptTemplate(
27
+ input_variables=["schema", "question", "concepts"], template=CYPHER_GENERATION_TEMPLATE
28
+ )
29
+
30
+ CYPHER_QA_TEMPLATE = """You are an assistant that helps to form nice and human understandable answers.
31
+ The information part contains the provided information that you must use to construct an answer.
32
+ The provided information is authoritative, you must never doubt it or try to use your internal knowledge to correct it.
33
+ Make the answer sound as a response to the question. Do not mention that you based the result on the given information.
34
+ Here is an example:
35
+
36
+ Question: Which managers own Neo4j stocks?
37
+ Context:[manager:CTL LLC, manager:JANE STREET GROUP LLC]
38
+ Helpful Answer: CTL LLC, JANE STREET GROUP LLC owns Neo4j stocks.
39
+
40
+ Follow this example when generating answers.
41
+ If the provided information is empty, say that you don't know the answer.
42
+ Information:
43
+ {context}
44
+
45
+ Question: {question}
46
+ Helpful Answer:"""
47
+ CYPHER_QA_PROMPT = PromptTemplate(
48
+ input_variables=["context", "question"], template=CYPHER_QA_TEMPLATE
49
+ )
50
+
51
+ PLAN_GEN_PROMPT = """System : You are a standardization expert working for 3GPP. You are given a specific technical requirement regarding the deployment of 5G services. Your goal is to specify NEW and INNOVATIVE Key Issues that could occur while trying to fulfill this requirement
52
+
53
+ System : Let's first understand the problem and devise a plan to solve the problem.
54
+ Output the plan starting with the header 'Plan:' and then followed by a numbered list of steps.
55
+ Make the plan the minimum number of steps required to accurately provide the user with NEW and INNOVATIVE Key Issues related to the technical requirement.
56
+ At the end of your plan, say '<END_OF_PLAN>'"""
57
+
58
+ PLAN_MODIFICATION_PROMPT = """You are a standardization expert working for 3GPP. You are given a specific technical requirement regarding the deployment of 5G services. Your goal is to specify NEW and INNOVATIVE Key Issues that could occur while trying to fulfill this requirement.
59
+ To achieve this goal we are going to follow this generic plan :
60
+
61
+ ###PLAN TEMPLATE###
62
+
63
+ Plan:
64
+
65
+ 1. **Understanding the Problem**: Gather information from existing specifications and standards to thoroughly understand the technical requirement. This should help you understand the key aspects of the problem.
66
+ 2. **Gather information about latest innovations** : Gather information about the latest innovations related to the problem by looking at the most relevant research papers.
67
+ 3. **Researching current challenges** : Research the current challenges in this area by looking at the existing similar key issues that have been identified by 3GPP.
68
+ 4. **Identifying NEW and INNOVATIVE Key Issues**: Based on the understanding of the problem and the current challenges, identify new and innovative key issues that could occur while trying to fulfill this requirement. These key issues should be relevant, significant, and not yet addressed by existing solutions.
69
+ 5. **Develop Detailed Descriptions for Each Key Issue**: For each identified key issue, provide a detailed description, including the specific challenges and areas requiring further study.
70
+
71
+ <END_OF_PLAN>
72
+
73
+ ###END OF PLAN TEMPLATE###
74
+
75
+ Let's devise a plan to solve the problem by adapting the PLAN TEMPLATE.
76
+ Output the plan starting with the header 'Plan:' and then followed by a numbered list of steps.
77
+ Make the plan the minimum number of steps required to accurately provide the user with NEW and INNOVATIVE Key Issues related to the technical requirement.
78
+ At the end of your plan, say '<END_OF_PLAN>' """
79
+
80
+ CONCEPT_SELECTION_TEMPLATE = """Task: Select the most relevant topic to the user question
81
+ Instructions:
82
+ Select the most relevant Concept to the user's question.
83
+ Concepts can ONLY be selected from the list below.
84
+
85
+ Concepts:
86
+ {concepts}
87
+
88
+ Note: Do not include any explanations or apologies in your responses.
89
+ Do not include any text except the selected concept.
90
+
91
+ The question is:
92
+ {question}"""
93
+ CONCEPT_SELECTION_PROMPT = PromptTemplate(
94
+ input_variables=["concepts", "question"], template=CONCEPT_SELECTION_TEMPLATE
95
+ )
96
+
97
+ RELEVANT_CONCEPTS_TEMPLATE = """
98
+ ## CONCEPTS ##
99
+ {concepts}
100
+ ## END OF CONCEPTS ##
101
+
102
+ Select the 20 most relevant concepts to the user query.
103
+ Output your answer as a numbered list preceded by the header 'Concepts:'.
104
+
105
+ User query :
106
+ {user_query}
107
+ """
108
+ RELEVANT_CONCEPTS_PROMPT = ChatPromptTemplate.from_messages([
109
+ ("human", RELEVANT_CONCEPTS_TEMPLATE)
110
+ ])
111
+
112
+ SUMMARIZER_TEMPLATE = """You are a 3GPP standardization expert.
113
+ Summarize the provided document in simple technical English for other experts in the field.
114
+
115
+ Document:
116
+ {document}"""
117
+ SUMMARIZER_PROMPT = ChatPromptTemplate.from_messages([
118
+ ("system", SUMMARIZER_TEMPLATE)
119
+ ])
120
+
121
+
122
+ BINARY_GRADER_TEMPLATE = """You are a grader assessing relevance of a retrieved document to a user question. \n
123
+ It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
124
+ If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
125
+ Give a binary 'yes' or 'no' score to indicate whether the document is relevant to the question."""
126
+ BINARY_GRADER_PROMPT = ChatPromptTemplate.from_messages(
127
+ [
128
+ ("system", BINARY_GRADER_TEMPLATE),
129
+ ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
130
+ ]
131
+ )
132
+
133
+
134
+ SCORE_GRADER_TEMPLATE = """Grasp and understand both the query and the document before score generation.
135
+ Then, based on your understanding and analysis quantify the relevance between the document and the query.
136
+ Give the rationale before answering.
137
+ Output your answer as a score ranging between 0 (irrelevant document) and 1 (completely relevant document)"""
138
+
139
+ SCORE_GRADER_PROMPT = ChatPromptTemplate.from_messages(
140
+ [
141
+ ("system", SCORE_GRADER_TEMPLATE),
142
+ ("human", "Passage: \n\n {document} \n\n User query: {query}")
143
+ ]
144
+ )
145
+
146
+ def get_initial_prompt(config: ConfigSchema, user_query : str):
147
+ if config["configurable"].get("plan_method") == "generation":
148
+ prompt = PLAN_GEN_PROMPT
149
+ elif config["configurable"].get("plan_method") == "modification":
150
+ prompt = PLAN_MODIFICATION_PROMPT
151
+ else:
152
+ raise ValueError("Incorrect plan_method, should be 'generation' or 'modification'")
153
+
154
+ user_input = user_query or input("User :")
155
+ return {"messages" : [SystemMessage(content=prompt), HumanMessage(content=user_input)]}
ki_gen/utils.py ADDED
@@ -0,0 +1,155 @@
1
+ import os
2
+ import getpass
3
+ import html
4
+
5
+
6
+ from typing import Annotated, Union
7
+ from typing_extensions import TypedDict
8
+
9
+ from langchain_community.graphs import Neo4jGraph
10
+ from langchain_groq import ChatGroq
11
+ from langchain_openai import ChatOpenAI
12
+
13
+ from langgraph.checkpoint.sqlite import SqliteSaver
14
+ from langgraph.checkpoint import base
15
+ from langgraph.graph import add_messages
16
+
17
+ with SqliteSaver.from_conn_string(":memory:") as mem :
18
+ memory = mem
19
+
20
+
21
+ def format_df(df):
22
+ """
23
+ Used to display the generated plan in a nice format
24
+ Returns html code in a string
25
+ """
26
+ def format_cell(cell):
27
+ if isinstance(cell, str):
28
+ # Encode special characters, but preserve line breaks
29
+ return html.escape(cell).replace('\n', '<br>')
30
+ return cell
31
+ # Convert the DataFrame to HTML with custom CSS
32
+ formatted_df = df.map(format_cell)
33
+ html_table = formatted_df.to_html(escape=False, index=False)
34
+
35
+ # Add custom CSS to allow multiple lines and scrolling in cells
36
+ css = """
37
+ <style>
38
+ table {
39
+ border-collapse: collapse;
40
+ width: 100%;
41
+ }
42
+ th, td {
43
+ border: 1px solid black;
44
+ padding: 8px;
45
+ text-align: left;
46
+ vertical-align: top;
47
+ white-space: pre-wrap;
48
+ max-width: 300px;
49
+ max-height: 100px;
50
+ overflow-y: auto;
51
+ }
52
+ th {
53
+ background-color: #f2f2f2;
54
+ }
55
+ </style>
56
+ """
57
+
58
+ return css + html_table
59
+
60
+ def format_doc(doc: dict) -> str :
61
+ formatted_string = ""
62
+ for key in doc:
63
+ formatted_string += f"**{key}**: {doc[key]}\n"
64
+ return formatted_string
65
+
66
+
67
+
68
+ def _set_env(var: str, value: str = None):
69
+ if not os.environ.get(var):
70
+ if value:
71
+ os.environ[var] = value
72
+ else:
73
+ os.environ[var] = getpass.getpass(f"{var}: ")
74
+
75
+
76
+ def init_app(openai_key : str = None, groq_key : str = None, langsmith_key : str = None):
77
+ """
78
+ Initialize app with user api keys and sets up proxy settings
79
+ """
80
+ _set_env("GROQ_API_KEY", value=groq_key)
81
+ _set_env("LANGSMITH_API_KEY", value=langsmith_key)
82
+ _set_env("OPENAI_API_KEY", value=openai_key)
83
+ os.environ["LANGSMITH_TRACING_V2"] = "true"
84
+ os.environ["LANGCHAIN_PROJECT"] = "3GPP Test"
85
+ os.environ["http_proxy"] = "185.46.212.98:80"
86
+ os.environ["https_proxy"] = "185.46.212.98:80"
87
+ os.environ["NO_PROXY"] = "thalescloud.io"
88
+
89
+ def clear_memory(memory, thread_id: str) -> None:
90
+ """
91
+ Clears checkpointer state for a given thread_id, broken for now
92
+ TODO : fix this
93
+ """
94
+ with SqliteSaver.from_conn_string(":memory:") as mem :
95
+ memory = mem
96
+ checkpoint = base.empty_checkpoint()
97
+ memory.put(config={"configurable": {"thread_id": thread_id}}, checkpoint=checkpoint, metadata={})
98
+
99
+ def get_model(model : str = "mixtral-8x7b-32768"):
100
+ """
101
+ Wrapper to return the correct llm object depending on the 'model' param
102
+ """
103
+ if model == "gpt-4o":
104
+ llm = ChatOpenAI(model=model, base_url="https://llm.synapse.thalescloud.io/")
105
+ else:
106
+ llm = ChatGroq(model=model)
107
+ return llm
108
+
109
+
110
+ class ConfigSchema(TypedDict):
111
+ graph: Neo4jGraph
112
+ plan_method: str
113
+ use_detailed_query: bool
114
+
115
+ class State(TypedDict):
116
+ messages : Annotated[list, add_messages]
117
+ store_plan : list[str]
118
+ current_plan_step : int
119
+ valid_docs : list[str]
120
+
121
+ class DocRetrieverState(TypedDict):
122
+ messages: Annotated[list, add_messages]
123
+ query: str
124
+ docs: list[dict]
125
+ cyphers: list[str]
126
+ current_plan_step : int
127
+ valid_docs: list[Union[str, dict]]
128
+
129
+ class HumanValidationState(TypedDict):
130
+ human_validated : bool
131
+ process_steps : list[str]
132
+
133
+ def update_doc_history(left : list | None, right : list | None) -> list:
134
+ """
135
+ Reducer for the 'docs_in_processing' field.
136
+ Doesn't work currently because of bad handling of duplicates
137
+ TODO : make this work (reference : https://langchain-ai.github.io/langgraph/how-tos/subgraph/#custom-reducer-functions-to-manage-state)
138
+ """
139
+ if not left:
140
+ # This shouldn't happen
141
+ left = [[]]
142
+ if not right:
143
+ right = []
144
+
145
+ for i in range(len(right)):
146
+ left[i].append(right[i])
147
+ return left
148
+
149
+
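A small sketch of the intended reducer behaviour (values invented); each inner list tracks the processing history of one document:

# left  = [["docA"], ["docB"]]
# right = ["docA summary", "docB summary"]
# update_doc_history(left, right) -> [["docA", "docA summary"], ["docB", "docB summary"]]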
150
+ class DocProcessorState(TypedDict):
151
+ valid_docs : list[Union[str, dict]]
152
+ docs_in_processing : list
153
+ process_steps : list[Union[str,dict]]
154
+ current_process_step : int
155
+