|
import gradio as gr |
|
from langchain_community.graphs import Neo4jGraph |
|
import pandas as pd |
|
import json
import os
|
|
|
from ki_gen.planner import build_planner_graph |
|
from ki_gen.utils import clear_memory, init_app, format_df, memory |
|
from ki_gen.prompts import get_initial_prompt |
|
|
|
MAX_PROCESSING_STEPS = 10  # Cap on doc-processing steps (currently unused in this file)
|
|
|
from neo4j import GraphDatabase |
|
|
|
# Neo4j connection settings. Credentials are read from the environment rather
# than being hardcoded in the source.
NEO4J_URI = os.environ.get("NEO4J_URI", "neo4j+s://4985272f.databases.neo4j.io")
AUTH = (os.environ.get("NEO4J_USERNAME", "neo4j"), os.environ["NEO4J_PASSWORD"])

# Fail fast if the database is unreachable.
with GraphDatabase.driver(NEO4J_URI, auth=AUTH) as driver:
    driver.verify_connectivity()
    print("Neo4j connectivity verified.")
|
|
|
def start_inference(data):
    """
    Start plan generation from the user query; the generated plan is displayed afterwards.
    """
|
config = data[config_state] |
|
init_app( |
|
openai_key=data[openai_api_key], |
|
groq_key=data[groq_api_key], |
|
langsmith_key=data[langsmith_api_key] |
|
) |
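
    # `memory` is the shared checkpointer, so the planner graph can be paused
    # at interrupts and resumed from later callbacks.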
|
graph = build_planner_graph(memory, config["configurable"]) |
|
    # Render the compiled graph to a PNG for debugging; make sure the target
    # directory exists first.
    os.makedirs("images", exist_ok=True)
    with open("images/graph_png.png", "wb") as f:
        f.write(graph.get_graph(xray=1).draw_mermaid_png())
|
|
|
    # Stream the planner graph until it reaches its first interrupt,
    # pretty-printing each new message along the way.
|
for event in graph.stream(get_initial_prompt(config, data[user_query]), config, stream_mode="values"): |
|
if "messages" in event: |
|
event["messages"][-1].pretty_print() |
|
|
|
state = graph.get_state(config) |
|
    steps = list(range(1, len(state.values['store_plan']) + 1))
|
df = pd.DataFrame({'Plan steps': steps, 'Description': state.values['store_plan']}) |
|
return [df, graph] |
|
|
|
def update_display(df):
    """
    Display the plan dataframe once it has been generated and reveal the
    step-editing controls.
    """
|
formatted_html = format_df(df) |
|
return { |
|
plan_display : gr.update(visible=True, value = formatted_html), |
|
select_step_to_modify : gr.update(visible=True, value=0), |
|
enter_new_step : gr.update(visible=True), |
|
submit_new_step : gr.update(visible=True), |
|
continue_inference_btn : gr.update(visible=True) |
|
} |
|
|
|
def format_docs(docs: list[dict]) -> str:
    """Format a list of doc dicts into a Markdown string for display."""
|
formatted_results = "" |
|
for i, doc in enumerate(docs): |
|
formatted_results += f"\n### Document {i}\n" |
|
for key in doc: |
|
formatted_results += f"**{key}**: {doc[key]}\n" |
|
return formatted_results |
|
|
|
def continue_inference(data):
    """
    Resume the graph to execute the next plan step.
    """
|
graph = data[graph_state] |
|
config = data[config_state] |
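
    # Passing None as the input resumes the graph from its last checkpoint
    # rather than starting a new run.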
|
|
|
for event in graph.stream(None, config, stream_mode="values"): |
|
if "messages" in event: |
|
event["messages"][-1].pretty_print() |
|
|
|
snapshot = graph.get_state(config) |
|
    print(f"DEBUG INFO: next nodes: {snapshot.next}")
    print(f"DEBUG INFO: snapshot: {snapshot}")
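
    # The graph pauses before its human_validation node (human-in-the-loop);
    # if that is the next node, ask the user to validate the retrieved docs.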
|
|
|
if snapshot.next and snapshot.next[0] == "human_validation": |
|
return { |
|
continue_inference_btn : gr.update(visible=False), |
|
graph_state : graph, |
|
retrieve_more_docs_btn : gr.update(visible=True), |
|
continue_to_processing_btn : gr.update(visible=True), |
|
        human_validation_title : gr.update(visible=True, value=f"**{len(snapshot.values['valid_docs'])} documents retrieved.** Retrieve more or continue?"),
|
retrieved_docs_state : snapshot.values['valid_docs'] |
|
} |
|
|
|
return { |
|
plan_result : snapshot.values["messages"][-1].content, |
|
graph_state : graph, |
|
continue_inference_btn : gr.update(visible=False) |
|
} |
|
|
|
def continue_to_processing():
    """
    Continue to the doc processing configuration.
    """
|
return { |
|
retrieve_more_docs_btn : gr.update(visible=False), |
|
continue_to_processing_btn : gr.update(visible=False), |
|
human_validation_title : gr.update(visible=False), |
|
process_data_btn : gr.update(visible=True), |
|
process_steps_nb : gr.update(visible=True), |
|
process_steps_title : gr.update(visible=True) |
|
} |
|
|
|
def retrieve_more_docs(data):
    """
    Restart doc retrieval.
    For now we simply regenerate the Cypher query; the result may differ
    between runs because temperature != 0.
    """
|
graph = data[graph_state] |
|
config = data[config_state] |
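
    # Record the rejection as if the human_validation node itself had produced
    # the update, so resuming the stream routes back into retrieval.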
|
graph.update_state(config, {'human_validated' : False}, as_node="human_validation") |
|
|
|
for event in graph.stream(None, config, stream_mode="values"): |
|
if "messages" in event: |
|
event["messages"][-1].pretty_print() |
|
|
|
snapshot = graph.get_state(config) |
|
    print(f"DEBUG INFO: next nodes: {snapshot.next}")
    print(f"DEBUG INFO: snapshot: {snapshot}")
|
|
|
return { |
|
graph_state : graph, |
|
        human_validation_title : gr.update(visible=True, value=f"**{len(snapshot.values['valid_docs'])} documents retrieved.** Retrieve more or continue?"),
|
retrieved_docs_display : format_docs(snapshot.values['valid_docs']) |
|
} |
|
|
|
def execute_processing(*args):
    """
    Execute doc processing.
    Args are passed as a flat list rather than a dict for syntax convenience,
    since the inputs are created dynamically by gr.render.
    """
|
graph = args[-2] |
|
config = args[-1] |
|
nb_process_steps = args[-3] |
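
    # Positional layout of *args (mirrors the inputs wired in process_data_btn.click):
    #   args[0:nb]        selected step names (one dropdown per step)
    #   args[nb:2*nb]     custom prompts (textboxes)
    #   args[2*nb:3*nb]   context elements passed to the LLM
    #   args[3*nb:4*nb]   processing models
    #   args[-3:]         process_steps_nb, graph_state, config_state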
|
|
|
process_steps = [] |
|
    for i in range(nb_process_steps):
        if args[i] == "custom":
            process_steps.append({
                "prompt": args[nb_process_steps + i],
                "context": args[2 * nb_process_steps + i],
                "processing_model": args[3 * nb_process_steps + i],
            })
        else:
            process_steps.append(args[i])
|
|
|
    # Mark the docs as validated and attach the processing steps, writing the
    # update as if the human_validation node had produced it, then resume.
    graph.update_state(config, {'human_validated' : True, 'process_steps' : process_steps}, as_node="human_validation")
|
|
|
for event in graph.stream(None, config, stream_mode="values"): |
|
if "messages" in event: |
|
event["messages"][-1].pretty_print() |
|
|
|
snapshot = graph.get_state(config) |
|
    print(f"DEBUG INFO: next nodes: {snapshot.next}")
    print(f"DEBUG INFO: snapshot: {snapshot}")
|
|
|
return { |
|
plan_result : snapshot.values["messages"][-1].content, |
|
processed_docs_state : snapshot.values["valid_docs"], |
|
graph_state : graph, |
|
continue_inference_btn : gr.update(visible=True), |
|
process_steps_nb : gr.update(value=0, visible=False), |
|
process_steps_title : gr.update(visible=False), |
|
process_data_btn : gr.update(visible=False), |
|
} |
|
|
|
|
|
|
|
def update_config_display():
    """
    Load config.json and populate the config fields in the UI.
    TODO: allow the user to specify a path to the config file.
    """
|
with open("config.json", "r") as config_file: |
|
config = json.load(config_file) |
|
|
|
return { |
|
main_llm : config["main_llm"], |
|
plan_method : config["plan_method"], |
|
use_detailed_query : config["use_detailed_query"], |
|
cypher_gen_method : config["cypher_gen_method"], |
|
validate_cypher : config["validate_cypher"], |
|
summarization_model : config["summarize_model"], |
|
eval_method : config["eval_method"], |
|
eval_threshold : config["eval_threshold"], |
|
max_docs : config["max_docs"], |
|
compression_method : config["compression_method"], |
|
compress_rate : config["compress_rate"], |
|
force_tokens : config["force_tokens"], |
|
eval_model : config["eval_model"], |
|
srv_addr : config["graph"]["address"], |
|
srv_usr : config["graph"]["username"], |
|
srv_pwd : config["graph"]["password"], |
|
openai_api_key : config["openai_api_key"], |
|
groq_api_key : config["groq_api_key"], |
|
langsmith_api_key : config["langsmith_api_key"] |
|
} |
|
|
|
|
|
def build_config(data):
    """
    Build the config dict from the values entered by the user.
    """
|
config = {} |
|
config["main_llm"] = data[main_llm] |
|
config["plan_method"] = data[plan_method] |
|
config["use_detailed_query"] = data[use_detailed_query] |
|
config["cypher_gen_method"] = data[cypher_gen_method] |
|
config["validate_cypher"] = data[validate_cypher] |
|
config["summarize_model"] = data[summarization_model] |
|
config["eval_method"] = data[eval_method] |
|
config["eval_threshold"] = data[eval_threshold] |
|
config["max_docs"] = data[max_docs] |
|
config["compression_method"] = data[compression_method] |
|
config["compress_rate"] = data[compress_rate] |
|
config["force_tokens"] = data[force_tokens] |
|
config["eval_model"] = data[eval_model] |
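
    # thread_id identifies the LangGraph checkpoint thread; hardcoded here for
    # this single-session demo.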
|
config["thread_id"] = "3" |
|
try: |
|
neograph = Neo4jGraph(url=data[srv_addr], username=data[srv_usr], password=data[srv_pwd]) |
|
config["graph"] = neograph |
|
except Exception as e: |
|
        raise gr.Error(f"Error while configuring the Neo4j connection: {e}", duration=5)
|
    gr.Info("Successfully updated configuration!", duration=5)
|
    # Wrap in the {"configurable": ...} shape expected by LangGraph run configs.
    return {"configurable" : config}
|
|
|
with gr.Blocks() as demo: |
|
with gr.Tab("Config"): |
|
gr.Markdown("## Config options setup") |
|
|
|
gr.Markdown("### API Keys") |
|
|
|
with gr.Row(): |
|
openai_api_key = gr.Textbox( |
|
label="OpenAI API Key", |
|
type="password" |
|
) |
|
|
|
groq_api_key = gr.Textbox( |
|
label="Groq API Key", |
|
type='password' |
|
) |
|
|
|
langsmith_api_key = gr.Textbox( |
|
label="LangSmith API Key", |
|
type="password" |
|
) |
|
|
|
gr.Markdown('### Planner options') |
|
with gr.Row(): |
|
main_llm = gr.Dropdown( |
|
choices=["gpt-4o", "claude-3-5-sonnet", "mixtral-8x7b-32768"], |
|
label="Main LLM", |
|
info="Choose the LLM which will perform the generation", |
|
value="gpt-4o" |
|
) |
|
with gr.Column(scale=1, min_width=600): |
|
plan_method = gr.Dropdown( |
|
choices=["generation", "modification"], |
|
label="Planning method", |
|
info="Choose how the main LLM will generate its plan", |
|
value="modification" |
|
) |
|
use_detailed_query = gr.Checkbox( |
|
label="Detail each plan step", |
|
info="Detail each plan step before passing it for data query" |
|
) |
|
|
|
gr.Markdown("### Data query options") |
|
with gr.Row(): |
|
with gr.Column(scale=1, min_width=300): |
|
srv_addr = gr.Textbox( |
|
label="Neo4j server address", |
|
placeholder="localhost:7687" |
|
) |
|
srv_usr = gr.Textbox( |
|
label="Neo4j username", |
|
placeholder="neo4j" |
|
) |
|
srv_pwd = gr.Textbox( |
|
label="Neo4j password", |
|
placeholder="<Password>" |
|
) |
|
|
|
with gr.Column(scale=1, min_width=300): |
|
cypher_gen_method = gr.Dropdown( |
|
choices=["auto", "guided"], |
|
label="Cypher generation method", |
|
) |
|
validate_cypher = gr.Checkbox( |
|
label="Validate cypher using graph Schema" |
|
) |
|
|
|
summarization_model = gr.Dropdown( |
|
choices=["gpt-4o", "claude-3-5-sonnet", "mixtral-8x7b-32768", "llama3-70b-8192"], |
|
label="Summarization LLM", |
|
info="Choose the LLM which will perform the summaries" |
|
) |
|
|
|
with gr.Column(scale=1, min_width=300): |
|
eval_method = gr.Dropdown( |
|
choices=["binary", "score"], |
|
label="Retrieved docs evaluation method", |
|
info="Evaluation method of retrieved docs" |
|
) |
|
|
|
                eval_model = gr.Dropdown(
                    choices=["gpt-4o", "mixtral-8x7b-32768"],
                    label="Evaluation model",
                    info="The LLM used to evaluate the relevance of retrieved docs",
                    value="mixtral-8x7b-32768"
                )
|
|
|
eval_threshold = gr.Slider( |
|
minimum=0, |
|
maximum=1, |
|
value=0.7, |
|
label="Eval threshold", |
|
info="Score above which a doc is considered relevant", |
|
step=0.01, |
|
visible=False |
|
) |
|
|
|
def eval_method_changed(selection): |
|
if selection == "score": |
|
return gr.update(visible=True) |
|
return gr.update(visible=False) |
|
eval_method.change(eval_method_changed, inputs=eval_method, outputs=eval_threshold) |
|
|
|
                max_docs = gr.Slider(
                    minimum=0,
                    maximum=30,
                    value=15,
                    label="Max docs",
                    info="Maximum number of docs to be retrieved at each query",
                    step=1
                )
|
|
|
with gr.Column(scale=1, min_width=300): |
|
compression_method = gr.Dropdown( |
|
choices=["llm_lingua2", "llm_lingua"], |
|
label="Compression method", |
|
value="llm_lingua2" |
|
) |
|
|
|
with gr.Row(): |
|
|
|
|
|
                    compress_rate = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=0.33,
                        label="Compression rate",
                        info="Target ratio of compressed to original length",
                        step=0.01
                    )
|
|
|
|
|
force_tokens = gr.CheckboxGroup( |
|
choices=['\n', '?', '.', '!', ','], |
|
value=[], |
|
label="Force tokens", |
|
info="Tokens to keep during compression", |
|
) |
|
|
|
with gr.Row(): |
|
btn_update_config = gr.Button(value="Update config") |
|
load_config_json = gr.Button(value="Load config from JSON") |
|
|
|
with gr.Row(): |
|
debug_info = gr.Button(value="Print debug info") |
|
|
|
config_state = gr.State(value={}) |
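
    # Shared state holding the config dict built by build_config; read by the
    # inference callbacks.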
|
btn_update_config.click( |
|
build_config, |
|
inputs={main_llm, plan_method, use_detailed_query, srv_addr, srv_pwd, srv_usr, compression_method, eval_model, \ |
|
compress_rate, force_tokens, cypher_gen_method, validate_cypher, summarization_model, eval_method, eval_threshold, max_docs}, |
|
outputs=config_state |
|
) |
|
load_config_json.click( |
|
update_config_display, |
|
outputs={main_llm, plan_method, use_detailed_query, cypher_gen_method, validate_cypher, summarization_model, eval_method, eval_threshold, \ |
|
max_docs, compress_rate, compression_method, force_tokens, eval_model, srv_addr, srv_usr, srv_pwd, openai_api_key, langsmith_api_key, groq_api_key} |
|
).then( |
|
build_config, |
|
inputs={main_llm, plan_method, use_detailed_query, srv_addr, srv_pwd, srv_usr, compression_method, eval_model, \ |
|
compress_rate, force_tokens, cypher_gen_method, validate_cypher, summarization_model, eval_method, eval_threshold, max_docs}, |
|
outputs=config_state |
|
) |
|
debug_info.click(lambda x : print(x), inputs=config_state) |
|
|
|
with gr.Tab("Inference"): |
|
graph_state = gr.State() |
|
user_query = gr.Textbox(label = "Your query") |
|
launch_inference = gr.Button(value="Generate plan") |
|
|
|
with gr.Row(): |
|
dataframe_plan = gr.Dataframe(visible = False) |
|
plan_display = gr.HTML(visible = False, label="Generated plan") |
|
|
|
with gr.Column(): |
|
with gr.Row(): |
|
select_step_to_modify = gr.Number(visible= False, label="Select a plan step to modify", value=0) |
|
submit_new_step = gr.Button(visible = False, value="Submit new step") |
|
enter_new_step = gr.Textbox(visible=False, label="Modify the plan step") |
|
|
|
with gr.Row(): |
|
human_validation_title = gr.Markdown(visible=False) |
|
retrieve_more_docs_btn = gr.Button(value="Retrieve more docs", visible=False) |
|
continue_to_processing_btn = gr.Button(value="Proceed to data processing", visible=False) |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
|
|
process_steps_title = gr.Markdown("#### Data processing steps", visible=False) |
|
process_steps_nb = gr.Number(label="Number of processing steps", value = 0, precision=0, step = 1, visible=False) |
|
|
|
def get_process_step_names(): |
|
return ["summarize", "compress", "custom"] |
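
                # gr.render re-runs `processing` whenever process_steps_nb changes,
                # rebuilding one column of controls per processing step.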
|
@gr.render(inputs=process_steps_nb) |
|
def processing(nb): |
|
with gr.Row(): |
|
process_step_names = get_process_step_names() |
|
dropdowns = [] |
|
textboxes = [] |
|
usable_elements = [] |
|
processing_models = [] |
|
for i in range(nb): |
|
with gr.Column(): |
|
dropdown = gr.Dropdown(key = f"d{i}", choices=process_step_names, label=f"Data processing step {i+1}") |
|
dropdowns.append(dropdown) |
|
|
|
textbox = gr.Textbox( |
|
key = f"t{i}", |
|
value="", |
|
placeholder="Your custom prompt", |
|
visible=True, min_width=300) |
|
textboxes.append(textbox) |
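
                                # Each step may consume the outputs of earlier
                                # steps, hence choices 0..i for step i+1.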
|
|
|
usable_element = gr.Dropdown( |
|
key = f"u{i}", |
|
                                    choices=list(range(i + 1)),
|
label="Elements passed to the LLM for this process step", |
|
multiselect=True, |
|
) |
|
usable_elements.append(usable_element) |
|
|
|
processing_model = gr.Dropdown( |
|
key = f"m{i}", |
|
label="The LLM that will execute this step", |
|
visible=True, |
|
                                    choices=["gpt-4o", "mixtral-8x7b-32768", "llama3-70b-8192"]
|
) |
|
processing_models.append(processing_model) |
|
|
|
dropdown.change( |
|
                                    fn=lambda process_name: [gr.update(visible=(process_name == "custom")) for _ in range(3)],
|
inputs=dropdown, |
|
outputs=[textbox, usable_element, processing_model] |
|
) |
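
                        # Wire the dynamically created components to
                        # execute_processing in the positional order described
                        # in that function's *args layout comment.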
|
|
|
process_data_btn.click( |
|
execute_processing, |
|
inputs= dropdowns + textboxes + usable_elements + processing_models + [process_steps_nb, graph_state, config_state], |
|
outputs={plan_result, processed_docs_state, graph_state, continue_inference_btn, process_steps_nb, process_steps_title, process_data_btn} |
|
) |
|
|
|
process_data_btn = gr.Button(value="Process retrieved docs", visible=False) |
|
|
|
continue_inference_btn = gr.Button(value="Proceed to next plan step", visible=False) |
|
plan_result = gr.Markdown(visible = True, label="Result of last plan step") |
|
|
|
with gr.Tab("Retrieved Docs"): |
|
retrieved_docs_state = gr.State([]) |
|
with gr.Row(): |
|
gr.Markdown("# Retrieved Docs") |
|
retrieved_docs_btn = gr.Button("Display retrieved docs") |
|
retrieved_docs_display = gr.Markdown() |
|
|
|
processed_docs_state = gr.State([]) |
|
with gr.Row(): |
|
gr.Markdown("# Processed Docs") |
|
processed_docs_btn = gr.Button("Display processed docs") |
|
processed_docs_display = gr.Markdown() |
|
|
|
continue_inference_btn.click( |
|
continue_inference, |
|
inputs={graph_state, config_state}, |
|
outputs={continue_inference_btn, graph_state, retrieve_more_docs_btn, continue_to_processing_btn, human_validation_title, plan_result, retrieved_docs_state} |
|
) |
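
    # Generate the plan first, then reveal the plan display and the
    # step-editing controls once the dataframe is ready.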
|
|
|
launch_inference.click( |
|
start_inference, |
|
inputs={config_state, user_query, openai_api_key, groq_api_key, langsmith_api_key}, |
|
outputs=[dataframe_plan, graph_state] |
|
).then( |
|
update_display, |
|
inputs=dataframe_plan, |
|
outputs={plan_display, select_step_to_modify, enter_new_step, submit_new_step, continue_inference_btn} |
|
) |
|
|
|
retrieve_more_docs_btn.click( |
|
retrieve_more_docs, |
|
inputs={graph_state, config_state}, |
|
outputs={graph_state, human_validation_title, retrieved_docs_display} |
|
) |
|
continue_to_processing_btn.click( |
|
continue_to_processing, |
|
outputs={retrieve_more_docs_btn, continue_to_processing_btn, human_validation_title, process_data_btn, process_steps_nb, process_steps_title} |
|
) |
|
retrieved_docs_btn.click( |
|
        fn=format_docs,
|
inputs=retrieved_docs_state, |
|
outputs=retrieved_docs_display |
|
) |
|
processed_docs_btn.click( |
|
        fn=format_docs,
|
inputs=processed_docs_state, |
|
outputs=processed_docs_display |
|
) |
|
test_process_steps = gr.Button(value="Test process steps") |
|
test_process_steps.click( |
|
lambda : [gr.update(visible = True), gr.update(visible=True)], |
|
outputs=[process_steps_nb, process_steps_title] |
|
) |
|
demo.launch() |
|
|