# bjhk / app.py — Gradio front-end for the ki_gen planner.
# Hugging Face Space file (last update commit 619469d, author heymenn).
import json
import os

import gradio as gr
import pandas as pd
from langchain_community.graphs import Neo4jGraph
from neo4j import GraphDatabase

from ki_gen.planner import build_planner_graph
from ki_gen.prompts import get_initial_prompt
from ki_gen.utils import clear_memory, init_app, format_df, memory

# Maximum number of document-processing steps a user may configure.
MAX_PROCESSING_STEPS = 10
# Neo4j connection settings.
# SECURITY: the database password was previously hardcoded here and committed
# to source control. Read the credentials from the environment instead; the
# fallbacks keep the app working until the env vars are set, but the exposed
# password should be rotated and the fallback removed.
NEO4J_URI = os.environ.get("NEO4J_URI", "neo4j+s://4985272f.databases.neo4j.io")
AUTH = (
    os.environ.get("NEO4J_USER", "neo4j"),
    os.environ.get("NEO4J_PASSWORD", "P6zQScbmyWivYeVZ84BniNjOCxu1D5Akw1IRC1SLKx8"),
)

# Fail fast at startup if the Neo4j instance is unreachable.
# verify_connectivity() raises on failure and returns None on success, so it
# only needs to be called once (the original called it twice).
with GraphDatabase.driver(NEO4J_URI, auth=AUTH) as driver:
    result = driver.verify_connectivity()
    print(result)
    print("i guess its gut")
def start_inference(data):
    """Generate the initial plan for the user's query.

    Initializes the API clients from the keys entered in the Config tab,
    builds the planner graph, streams it until the plan is produced, and
    returns the plan as a DataFrame together with the graph object.

    Args:
        data: Gradio input mapping {component: value} containing
            config_state, the three API-key textboxes and user_query.

    Returns:
        [df, graph] — a DataFrame with columns 'Plan steps' / 'Description'
        (one row per step in the graph state's 'store_plan'), and the
        compiled planner graph (stored in graph_state for later steps).
    """
    config = data[config_state]
    init_app(
        openai_key=data[openai_api_key],
        groq_key=data[groq_api_key],
        langsmith_key=data[langsmith_api_key],
    )
    # TODO: clear previous session state once clear_memory is implemented:
    # clear_memory(memory, config["configurable"].get("thread_id"))
    graph = build_planner_graph(memory, config["configurable"])

    # Render the planner graph for debugging. Ensure the target directory
    # exists so the write cannot fail on a fresh checkout (fixes a crash
    # when images/ is missing).
    os.makedirs("images", exist_ok=True)
    with open("images/graph_png.png", "wb") as f:
        f.write(graph.get_graph(xray=1).draw_mermaid_png())
    print("here !")

    # Stream until the graph pauses (or finishes), echoing messages to stdout.
    for event in graph.stream(get_initial_prompt(config, data[user_query]), config, stream_mode="values"):
        if "messages" in event:
            event["messages"][-1].pretty_print()

    state = graph.get_state(config)
    steps = list(range(1, len(state.values['store_plan']) + 1))
    df = pd.DataFrame({'Plan steps': steps, 'Description': state.values['store_plan']})
    return [df, graph]
def update_display(df):
    """Render the generated plan as HTML and reveal the step-editing controls.

    Called after start_inference; takes the plan DataFrame, formats it via
    format_df, and makes the display plus the modify/submit/continue
    widgets visible.
    """
    rendered = format_df(df)
    updates = {
        plan_display: gr.update(visible=True, value=rendered),
        select_step_to_modify: gr.update(visible=True, value=0),
    }
    for widget in (enter_new_step, submit_new_step, continue_inference_btn):
        updates[widget] = gr.update(visible=True)
    return updates
def format_docs(docs: list[dict]) -> str:
    """Render a list of document dicts as a Markdown string.

    Each document becomes a "### Document i" heading followed by one
    bolded "**key**: value" line per field.
    """
    sections = []
    for idx, document in enumerate(docs):
        lines = [f"\n### Document {idx}"]
        lines.extend(f"**{field}**: {document[field]}" for field in document)
        sections.append("\n".join(lines) + "\n")
    return "".join(sections)
def continue_inference(data):
    """Resume the planner graph and execute the next plan step.

    Streams the graph from its checkpoint. If it pauses at the
    'human_validation' node, shows the doc-validation controls; otherwise
    displays the result of the completed step.
    """
    planner = data[graph_state]
    cfg = data[config_state]

    for event in planner.stream(None, cfg, stream_mode="values"):
        if "messages" in event:
            event["messages"][-1].pretty_print()

    snapshot = planner.get_state(cfg)
    print(f"DEBUG INFO : next : {snapshot.next}")
    print(f"DEBUG INFO ++ L.75: {snapshot}")

    # Paused for human validation of retrieved docs?
    if snapshot.next and snapshot.next[0] == "human_validation":
        valid_docs = snapshot.values['valid_docs']
        return {
            continue_inference_btn: gr.update(visible=False),
            graph_state: planner,
            retrieve_more_docs_btn: gr.update(visible=True),
            continue_to_processing_btn: gr.update(visible=True),
            human_validation_title: gr.update(visible=True, value=f"**{len(valid_docs)} documents retrieved.** Retrieve more or continue ?"),
            retrieved_docs_state: valid_docs,
        }

    # Step finished: surface the last message content as the step result.
    return {
        plan_result: snapshot.values["messages"][-1].content,
        graph_state: planner,
        continue_inference_btn: gr.update(visible=False),
    }
def continue_to_processing():
    """Swap the doc-validation controls for the processing-configuration ones."""
    to_hide = (retrieve_more_docs_btn, continue_to_processing_btn, human_validation_title)
    to_show = (process_data_btn, process_steps_nb, process_steps_title)
    updates = {widget: gr.update(visible=False) for widget in to_hide}
    updates.update({widget: gr.update(visible=True) for widget in to_show})
    return updates
def retrieve_more_docs(data):
    """Re-run document retrieval for the current plan step.

    The cypher query is simply regenerated; since temperature != 0 it may
    differ from the previous attempt and retrieve different documents.
    """
    planner = data[graph_state]
    cfg = data[config_state]

    # Mark the previous retrieval as rejected so the graph loops back
    # through retrieval instead of proceeding.
    planner.update_state(cfg, {'human_validated': False}, as_node="human_validation")
    for event in planner.stream(None, cfg, stream_mode="values"):
        if "messages" in event:
            event["messages"][-1].pretty_print()

    snapshot = planner.get_state(cfg)
    print(f"DEBUG INFO : next : {snapshot.next}")
    print(f"DEBUG INFO ++ L.121: {snapshot}")

    valid_docs = snapshot.values['valid_docs']
    return {
        graph_state: planner,
        human_validation_title: gr.update(visible=True, value=f"**{len(valid_docs)} documents retrieved.** Retrieve more or continue ?"),
        retrieved_docs_display: format_docs(valid_docs),
    }
def execute_processing(*args):
    """Run the configured document-processing steps on the retrieved docs.

    Args are passed as one flat tuple (the gr.render components cannot be
    bound by name):
        args[0:n]      step type per step ("summarize" / "compress" / "custom")
        args[n:2n]     custom prompt per step
        args[2n:3n]    context elements per step
        args[3n:4n]    processing model per step
        args[-3:]      number of steps, planner graph, config
    """
    planner, cfg = args[-2], args[-1]
    nb_steps = args[-3]

    steps = []
    for idx in range(nb_steps):
        step_type = args[idx]
        if step_type == "custom":
            # Custom steps carry their prompt, context and model.
            steps.append({
                "prompt": args[nb_steps + idx],
                "context": args[2 * nb_steps + idx],
                "processing_model": args[3 * nb_steps + idx],
            })
        else:
            # Built-in steps are identified by name alone.
            steps.append(step_type)

    # Approve the retrieved docs and attach the processing plan, then resume.
    planner.update_state(cfg, {'human_validated': True, 'process_steps': steps}, as_node="human_validation")
    for event in planner.stream(None, cfg, stream_mode="values"):
        if "messages" in event:
            event["messages"][-1].pretty_print()

    snapshot = planner.get_state(cfg)
    print(f"DEBUG INFO : next : {snapshot.next}")
    print(f"DEBUG INFO ++ L.153: {snapshot}")

    return {
        plan_result: snapshot.values["messages"][-1].content,
        processed_docs_state: snapshot.values["valid_docs"],
        graph_state: planner,
        continue_inference_btn: gr.update(visible=True),
        process_steps_nb: gr.update(value=0, visible=False),
        process_steps_title: gr.update(visible=False),
        process_data_btn: gr.update(visible=False),
    }
def update_config_display():
    """Populate the Config-tab widgets from config.json.

    TODO: allow the user to specify a path to the config file.
    """
    with open("config.json", "r") as config_file:
        cfg = json.load(config_file)

    neo4j_cfg = cfg["graph"]
    return {
        main_llm: cfg["main_llm"],
        plan_method: cfg["plan_method"],
        use_detailed_query: cfg["use_detailed_query"],
        cypher_gen_method: cfg["cypher_gen_method"],
        validate_cypher: cfg["validate_cypher"],
        # NOTE: the JSON key is "summarize_model" while the widget is named
        # summarization_model.
        summarization_model: cfg["summarize_model"],
        eval_method: cfg["eval_method"],
        eval_threshold: cfg["eval_threshold"],
        max_docs: cfg["max_docs"],
        compression_method: cfg["compression_method"],
        compress_rate: cfg["compress_rate"],
        force_tokens: cfg["force_tokens"],
        eval_model: cfg["eval_model"],
        srv_addr: neo4j_cfg["address"],
        srv_usr: neo4j_cfg["username"],
        srv_pwd: neo4j_cfg["password"],
        openai_api_key: cfg["openai_api_key"],
        groq_api_key: cfg["groq_api_key"],
        langsmith_api_key: cfg["langsmith_api_key"],
    }
def build_config(data):
    """Build the runnable config from the values entered in the Config tab.

    Args:
        data: Gradio input mapping {component: value} for all config widgets.

    Returns:
        A LangGraph-style config: {"configurable": {...}} including a live
        Neo4jGraph connection under "graph".

    Raises:
        gr.Error: shown in the UI when the Neo4j connection fails.
    """
    config = {
        "main_llm": data[main_llm],
        "plan_method": data[plan_method],
        "use_detailed_query": data[use_detailed_query],
        "cypher_gen_method": data[cypher_gen_method],
        "validate_cypher": data[validate_cypher],
        "summarize_model": data[summarization_model],
        "eval_method": data[eval_method],
        "eval_threshold": data[eval_threshold],
        "max_docs": data[max_docs],
        "compression_method": data[compression_method],
        "compress_rate": data[compress_rate],
        "force_tokens": data[force_tokens],
        "eval_model": data[eval_model],
        # NOTE(review): thread id is hardcoded, so concurrent sessions share
        # one checkpointer thread — confirm whether this is intended.
        "thread_id": "3",
    }
    try:
        neograph = Neo4jGraph(url=data[srv_addr], username=data[srv_usr], password=data[srv_pwd])
        config["graph"] = neograph
    except Exception as e:
        raise gr.Error(f"Error when configuring the neograph server : {e}", duration=5)
    # Fixed user-facing typo: "Succesfully" -> "Successfully".
    gr.Info("Successfully updated configuration !", duration=5)
    return {"configurable": config}
# UI definition. Component variables created here (openai_api_key, main_llm,
# config_state, graph_state, ...) are module-level globals referenced by the
# callback functions defined above, so creation order matters.
with gr.Blocks() as demo:
    with gr.Tab("Config"):
        ### The config tab
        gr.Markdown("## Config options setup")
        gr.Markdown("### API Keys")
        with gr.Row():
            openai_api_key = gr.Textbox(
                label="OpenAI API Key",
                type="password"
            )
            groq_api_key = gr.Textbox(
                label="Groq API Key",
                type='password'
            )
            langsmith_api_key = gr.Textbox(
                label="LangSmith API Key",
                type="password"
            )
        gr.Markdown('### Planner options')
        with gr.Row():
            main_llm = gr.Dropdown(
                choices=["gpt-4o", "claude-3-5-sonnet", "mixtral-8x7b-32768"],
                label="Main LLM",
                info="Choose the LLM which will perform the generation",
                value="gpt-4o"
            )
            with gr.Column(scale=1, min_width=600):
                plan_method = gr.Dropdown(
                    choices=["generation", "modification"],
                    label="Planning method",
                    info="Choose how the main LLM will generate its plan",
                    value="modification"
                )
                use_detailed_query = gr.Checkbox(
                    label="Detail each plan step",
                    info="Detail each plan step before passing it for data query"
                )
        gr.Markdown("### Data query options")
        # The options for the data processor
        # TODO : remove the options for summarize and compress and let the user choose them when specifying processing steps
        # (similarly to what is done for custom processing step)
        with gr.Row():
            with gr.Column(scale=1, min_width=300):
                # Neo4j Server parameters
                srv_addr = gr.Textbox(
                    label="Neo4j server address",
                    placeholder="localhost:7687"
                )
                srv_usr = gr.Textbox(
                    label="Neo4j username",
                    placeholder="neo4j"
                )
                srv_pwd = gr.Textbox(
                    label="Neo4j password",
                    placeholder="<Password>"
                )
            with gr.Column(scale=1, min_width=300):
                cypher_gen_method = gr.Dropdown(
                    choices=["auto", "guided"],
                    label="Cypher generation method",
                )
                validate_cypher = gr.Checkbox(
                    label="Validate cypher using graph Schema"
                )
                summarization_model = gr.Dropdown(
                    choices=["gpt-4o", "claude-3-5-sonnet", "mixtral-8x7b-32768", "llama3-70b-8192"],
                    label="Summarization LLM",
                    info="Choose the LLM which will perform the summaries"
                )
            with gr.Column(scale=1, min_width=300):
                eval_method = gr.Dropdown(
                    choices=["binary", "score"],
                    label="Retrieved docs evaluation method",
                    info="Evaluation method of retrieved docs"
                )
                eval_model = gr.Dropdown(
                    choices = ["gpt-4o", "mixtral-8x7b-32768"],
                    label = "Evaluation model",
                    info = "The LLM to use to evaluate the relevance of retrieved docs",
                    value = "mixtral-8x7b-32768"
                )
                # Hidden unless eval_method == "score" (see eval_method_changed).
                eval_threshold = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.7,
                    label="Eval threshold",
                    info="Score above which a doc is considered relevant",
                    step=0.01,
                    visible=False
                )

                # Show the threshold slider only for score-based evaluation.
                def eval_method_changed(selection):
                    if selection == "score":
                        return gr.update(visible=True)
                    return gr.update(visible=False)
                eval_method.change(eval_method_changed, inputs=eval_method, outputs=eval_threshold)

                # NOTE(review): step=0.01 on a document count allows fractional
                # values — confirm whether step=1 was intended.
                max_docs = gr.Slider(
                    minimum=0,
                    maximum = 30,
                    value = 15,
                    label="Max docs",
                    info="Maximum number of docs to be retrieved at each query",
                    step=0.01
                )
            with gr.Column(scale=1, min_width=300):
                compression_method = gr.Dropdown(
                    choices=["llm_lingua2", "llm_lingua"],
                    label="Compression method",
                    value="llm_lingua2"
                )
                with gr.Row():
                    # Add compression rate configuration with a gr.slider
                    compress_rate = gr.Slider(
                        minimum = 0,
                        maximum = 1,
                        value = 0.33,
                        label="Compression rate",
                        info="Compression rate",
                        step = 0.01
                    )
                    # Add gr.CheckboxGroup to choose force_tokens
                    force_tokens = gr.CheckboxGroup(
                        choices=['\n', '?', '.', '!', ','],
                        value=[],
                        label="Force tokens",
                        info="Tokens to keep during compression",
                    )
        with gr.Row():
            btn_update_config = gr.Button(value="Update config")
            load_config_json = gr.Button(value="Load config from JSON")
        with gr.Row():
            debug_info = gr.Button(value="Print debug info")

        # Holds the config dict produced by build_config for all callbacks.
        config_state = gr.State(value={})
        btn_update_config.click(
            build_config,
            inputs={main_llm, plan_method, use_detailed_query, srv_addr, srv_pwd, srv_usr, compression_method, eval_model, \
                    compress_rate, force_tokens, cypher_gen_method, validate_cypher, summarization_model, eval_method, eval_threshold, max_docs},
            outputs=config_state
        )
        # Load config.json into the widgets, then rebuild config_state from them.
        load_config_json.click(
            update_config_display,
            outputs={main_llm, plan_method, use_detailed_query, cypher_gen_method, validate_cypher, summarization_model, eval_method, eval_threshold, \
                     max_docs, compress_rate, compression_method, force_tokens, eval_model, srv_addr, srv_usr, srv_pwd, openai_api_key, langsmith_api_key, groq_api_key}
        ).then(
            build_config,
            inputs={main_llm, plan_method, use_detailed_query, srv_addr, srv_pwd, srv_usr, compression_method, eval_model, \
                    compress_rate, force_tokens, cypher_gen_method, validate_cypher, summarization_model, eval_method, eval_threshold, max_docs},
            outputs=config_state
        )
        # Print config variable in the terminal
        debug_info.click(lambda x : print(x), inputs=config_state)

    with gr.Tab("Inference"):
        ### Inference tab
        # graph_state holds the compiled planner graph set by start_inference.
        graph_state = gr.State()
        user_query = gr.Textbox(label = "Your query")
        launch_inference = gr.Button(value="Generate plan")
        with gr.Row():
            dataframe_plan = gr.Dataframe(visible = False)
            plan_display = gr.HTML(visible = False, label="Generated plan")
        with gr.Column():
            # Lets the user modify steps of the plan. Underlying logic not implemented yet
            # TODO : implement this
            with gr.Row():
                select_step_to_modify = gr.Number(visible= False, label="Select a plan step to modify", value=0)
                submit_new_step = gr.Button(visible = False, value="Submit new step")
            enter_new_step = gr.Textbox(visible=False, label="Modify the plan step")
        with gr.Row():
            human_validation_title = gr.Markdown(visible=False)
            retrieve_more_docs_btn = gr.Button(value="Retrieve more docs", visible=False)
            continue_to_processing_btn = gr.Button(value="Proceed to data processing", visible=False)
        with gr.Row():
            with gr.Column():
                process_steps_title = gr.Markdown("#### Data processing steps", visible=False)
                process_steps_nb = gr.Number(label="Number of processing steps", value = 0, precision=0, step = 1, visible=False)

                def get_process_step_names():
                    return ["summarize", "compress", "custom"]

                # The gr.render decorator allows the code inside the following function to be rerun everytime the 'inputs' variable is modified
                # /!\ All event listeners that use variables defined inside a gr.render function must be defined inside that same function
                # ref : https://www.gradio.app/docs/gradio/render
                @gr.render(inputs=process_steps_nb)
                def processing(nb):
                    with gr.Row():
                        process_step_names = get_process_step_names()
                        dropdowns = []
                        textboxes = []
                        usable_elements = []
                        processing_models = []
                        # One column of widgets per configured processing step.
                        for i in range(nb):
                            with gr.Column():
                                dropdown = gr.Dropdown(key = f"d{i}", choices=process_step_names, label=f"Data processing step {i+1}")
                                dropdowns.append(dropdown)
                                textbox = gr.Textbox(
                                    key = f"t{i}",
                                    value="",
                                    placeholder="Your custom prompt",
                                    visible=True, min_width=300)
                                textboxes.append(textbox)
                                usable_element = gr.Dropdown(
                                    key = f"u{i}",
                                    choices = [(j) for j in range(i+1)],
                                    label="Elements passed to the LLM for this process step",
                                    multiselect=True,
                                )
                                usable_elements.append(usable_element)
                                # NOTE(review): "llama3-70b-8182" looks like a typo
                                # for "llama3-70b-8192" — confirm the model name.
                                processing_model = gr.Dropdown(
                                    key = f"m{i}",
                                    label="The LLM that will execute this step",
                                    visible=True,
                                    choices=["gpt-4o", "mixtral-8x7b-32768", "llama3-70b-8182"]
                                )
                                processing_models.append(processing_model)
                                # The prompt/context/model widgets only apply to "custom" steps.
                                dropdown.change(
                                    fn=lambda process_name : [gr.update(visible=(process_name=="custom")), gr.update(visible=(process_name=='custom')), gr.update(visible=(process_name=='custom'))],
                                    inputs=dropdown,
                                    outputs=[textbox, usable_element, processing_model]
                                )
                        # process_data_btn is defined after this function; gr.render
                        # re-runs this body lazily so the name resolves at render time.
                        process_data_btn.click(
                            execute_processing,
                            inputs= dropdowns + textboxes + usable_elements + processing_models + [process_steps_nb, graph_state, config_state],
                            outputs={plan_result, processed_docs_state, graph_state, continue_inference_btn, process_steps_nb, process_steps_title, process_data_btn}
                        )

                process_data_btn = gr.Button(value="Process retrieved docs", visible=False)
                continue_inference_btn = gr.Button(value="Proceed to next plan step", visible=False)
        plan_result = gr.Markdown(visible = True, label="Result of last plan step")

    with gr.Tab("Retrieved Docs"):
        # Docs as retrieved from Neo4j, before processing.
        retrieved_docs_state = gr.State([])
        with gr.Row():
            gr.Markdown("# Retrieved Docs")
            retrieved_docs_btn = gr.Button("Display retrieved docs")
        retrieved_docs_display = gr.Markdown()
        # Docs after the processing steps have been applied.
        processed_docs_state = gr.State([])
        with gr.Row():
            gr.Markdown("# Processed Docs")
            processed_docs_btn = gr.Button("Display processed docs")
        processed_docs_display = gr.Markdown()

    # Cross-tab event wiring (all referenced components exist by this point).
    continue_inference_btn.click(
        continue_inference,
        inputs={graph_state, config_state},
        outputs={continue_inference_btn, graph_state, retrieve_more_docs_btn, continue_to_processing_btn, human_validation_title, plan_result, retrieved_docs_state}
    )
    launch_inference.click(
        start_inference,
        inputs={config_state, user_query, openai_api_key, groq_api_key, langsmith_api_key},
        outputs=[dataframe_plan, graph_state]
    ).then(
        update_display,
        inputs=dataframe_plan,
        outputs={plan_display, select_step_to_modify, enter_new_step, submit_new_step, continue_inference_btn}
    )
    retrieve_more_docs_btn.click(
        retrieve_more_docs,
        inputs={graph_state, config_state},
        outputs={graph_state, human_validation_title, retrieved_docs_display}
    )
    continue_to_processing_btn.click(
        continue_to_processing,
        outputs={retrieve_more_docs_btn, continue_to_processing_btn, human_validation_title, process_data_btn, process_steps_nb, process_steps_title}
    )
    retrieved_docs_btn.click(
        fn=lambda docs : format_docs(docs),
        inputs=retrieved_docs_state,
        outputs=retrieved_docs_display
    )
    processed_docs_btn.click(
        fn=lambda docs : format_docs(docs),
        inputs=processed_docs_state,
        outputs=processed_docs_display
    )
    # Debug helper: reveal the processing-step controls without running inference.
    test_process_steps = gr.Button(value="Test process steps")
    test_process_steps.click(
        lambda : [gr.update(visible = True), gr.update(visible=True)],
        outputs=[process_steps_nb, process_steps_title]
    )

demo.launch()