import json

import gradio as gr
import pandas as pd
from langchain_community.graphs import Neo4jGraph
from neo4j import GraphDatabase

from ki_gen.planner import build_planner_graph
from ki_gen.prompts import get_initial_prompt
from ki_gen.utils import clear_memory, init_app, format_df, memory

MAX_PROCESSING_STEPS = 10

# NOTE: credentials are hardcoded here; consider loading them from the
# environment instead of committing them to the repo
NEO4J_URI = "neo4j+s://4985272f.databases.neo4j.io"
AUTH = ("neo4j", "P6zQScbmyWivYeVZ84BniNjOCxu1D5Akw1IRC1SLKx8")

# Fail fast if the Neo4j instance is unreachable
with GraphDatabase.driver(NEO4J_URI, auth=AUTH) as driver:
    driver.verify_connectivity()
    print("Neo4j connectivity verified")


def start_inference(data):
    """Starts plan generation from the user query; the generated plan is displayed afterwards."""
    config = data[config_state]
    init_app(
        openai_key=data[openai_api_key],
        groq_key=data[groq_api_key],
        langsmith_key=data[langsmith_api_key]
    )

    # TODO: clear_memory
    # clear_memory(memory, config["configurable"].get("thread_id"))

    graph = build_planner_graph(memory, config["configurable"])
    with open("images/graph_png.png", "wb") as f:
        f.write(graph.get_graph(xray=1).draw_mermaid_png())

    for event in graph.stream(get_initial_prompt(config, data[user_query]), config, stream_mode="values"):
        if "messages" in event:
            event["messages"][-1].pretty_print()

    state = graph.get_state(config)
    steps = list(range(1, len(state.values["store_plan"]) + 1))
    df = pd.DataFrame({"Plan steps": steps, "Description": state.values["store_plan"]})
    return [df, graph]


def update_display(df):
    """Displays the plan dataframe after it has been generated."""
    formatted_html = format_df(df)
    return {
        plan_display: gr.update(visible=True, value=formatted_html),
        select_step_to_modify: gr.update(visible=True, value=0),
        enter_new_step: gr.update(visible=True),
        submit_new_step: gr.update(visible=True),
        continue_inference_btn: gr.update(visible=True),
    }


def format_docs(docs: list[dict]) -> str:
    """Formats a list of docs as markdown, one '### Document i' section per doc."""
    formatted_results = ""
    for i, doc in enumerate(docs):
        formatted_results += f"\n### Document {i}\n"
        for key in doc:
            formatted_results += f"**{key}**: {doc[key]}\n"
    return formatted_results


def continue_inference(data):
    """Proceeds to the next plan step."""
    graph = data[graph_state]
    config = data[config_state]

    for event in graph.stream(None, config, stream_mode="values"):
        if "messages" in event:
            event["messages"][-1].pretty_print()

    snapshot = graph.get_state(config)
    print(f"DEBUG INFO : next : {snapshot.next}")
    print(f"DEBUG INFO : snapshot : {snapshot}")

    if snapshot.next and snapshot.next[0] == "human_validation":
        return {
            continue_inference_btn: gr.update(visible=False),
            graph_state: graph,
            retrieve_more_docs_btn: gr.update(visible=True),
            continue_to_processing_btn: gr.update(visible=True),
            human_validation_title: gr.update(
                visible=True,
                value=f"**{len(snapshot.values['valid_docs'])} documents retrieved.** Retrieve more or continue ?"
            ),
            retrieved_docs_state: snapshot.values["valid_docs"],
        }

    return {
        plan_result: snapshot.values["messages"][-1].content,
        graph_state: graph,
        continue_inference_btn: gr.update(visible=False),
    }


def continue_to_processing():
    """Continues to the doc processing configuration."""
    return {
        retrieve_more_docs_btn: gr.update(visible=False),
        continue_to_processing_btn: gr.update(visible=False),
        human_validation_title: gr.update(visible=False),
        process_data_btn: gr.update(visible=True),
        process_steps_nb: gr.update(visible=True),
        process_steps_title: gr.update(visible=True),
    }
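
# The two handlers below resume a run that LangGraph paused at the
# "human_validation" interrupt node: update_state(..., as_node="human_validation")
# records the user's decision as if that node had run, and stream(None, config)
# continues from the stored checkpoint instead of starting a new run.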
def retrieve_more_docs(data):
    """
    Restarts doc retrieval.
    For now we simply regenerate the cypher; it may differ from the previous one
    because temperature != 0.
    """
    graph = data[graph_state]
    config = data[config_state]
    graph.update_state(config, {"human_validated": False}, as_node="human_validation")

    for event in graph.stream(None, config, stream_mode="values"):
        if "messages" in event:
            event["messages"][-1].pretty_print()

    snapshot = graph.get_state(config)
    print(f"DEBUG INFO : next : {snapshot.next}")
    print(f"DEBUG INFO : snapshot : {snapshot}")

    return {
        graph_state: graph,
        human_validation_title: gr.update(
            visible=True,
            value=f"**{len(snapshot.values['valid_docs'])} documents retrieved.** Retrieve more or continue ?"
        ),
        retrieved_docs_display: format_docs(snapshot.values["valid_docs"]),
    }


def execute_processing(*args):
    """
    Executes doc processing.
    Args are passed as a flat list rather than a dict for syntax convenience
    (the components are created dynamically inside a gr.render function).
    """
    graph = args[-2]
    config = args[-1]
    nb_process_steps = args[-3]

    process_steps = []
    for i in range(nb_process_steps):
        if args[i] == "custom":
            process_steps.append({
                "prompt": args[nb_process_steps + i],
                "context": args[2 * nb_process_steps + i],
                "processing_model": args[3 * nb_process_steps + i],
            })
        else:
            process_steps.append(args[i])

    graph.update_state(config, {"human_validated": True, "process_steps": process_steps}, as_node="human_validation")

    for event in graph.stream(None, config, stream_mode="values"):
        if "messages" in event:
            event["messages"][-1].pretty_print()

    snapshot = graph.get_state(config)
    print(f"DEBUG INFO : next : {snapshot.next}")
    print(f"DEBUG INFO : snapshot : {snapshot}")

    return {
        plan_result: snapshot.values["messages"][-1].content,
        processed_docs_state: snapshot.values["valid_docs"],
        graph_state: graph,
        continue_inference_btn: gr.update(visible=True),
        process_steps_nb: gr.update(value=0, visible=False),
        process_steps_title: gr.update(visible=False),
        process_data_btn: gr.update(visible=False),
    }


def update_config_display():
    """
    Called after loading the config.json file.
    TODO: allow the user to specify a path to the config file.
    """
    with open("config.json", "r") as config_file:
        config = json.load(config_file)
    print(config)
    return {
        main_llm: config["main_llm"],
        plan_method: config["plan_method"],
        use_detailed_query: config["use_detailed_query"],
        cypher_gen_method: config["cypher_gen_method"],
        validate_cypher: config["validate_cypher"],
        summarization_model: config["summarize_model"],
        eval_method: config["eval_method"],
        eval_threshold: config["eval_threshold"],
        max_docs: config["max_docs"],
        compression_method: config["compression_method"],
        compress_rate: config["compress_rate"],
        force_tokens: config["force_tokens"],
        eval_model: config["eval_model"],
        srv_addr: config["graph"]["address"],
        srv_usr: config["graph"]["username"],
        srv_pwd: config["graph"]["password"],
        openai_api_key: config["openai_api_key"],
        groq_api_key: config["groq_api_key"],
        langsmith_api_key: config["langsmith_api_key"],
    }
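
# For reference, a minimal config.json matching the keys read above could look
# like this (values are illustrative, not project defaults):
#
# {
#     "main_llm": "gpt-4o",
#     "plan_method": "modification",
#     "use_detailed_query": false,
#     "cypher_gen_method": "guided",
#     "validate_cypher": true,
#     "summarize_model": "gpt-4o",
#     "eval_method": "binary",
#     "eval_threshold": 0.7,
#     "max_docs": 15,
#     "compression_method": "llm_lingua2",
#     "compress_rate": 0.33,
#     "force_tokens": [],
#     "eval_model": "mixtral-8x7b-32768",
#     "graph": {"address": "localhost:7687", "username": "neo4j", "password": "..."},
#     "openai_api_key": "...",
#     "groq_api_key": "...",
#     "langsmith_api_key": "..."
# }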
def build_config(data):
    """Builds the config variable from the values entered by the user."""
    config = {}
    config["main_llm"] = data[main_llm]
    config["plan_method"] = data[plan_method]
    config["use_detailed_query"] = data[use_detailed_query]
    config["cypher_gen_method"] = data[cypher_gen_method]
    config["validate_cypher"] = data[validate_cypher]
    config["summarize_model"] = data[summarization_model]
    config["eval_method"] = data[eval_method]
    config["eval_threshold"] = data[eval_threshold]
    config["max_docs"] = data[max_docs]
    config["compression_method"] = data[compression_method]
    config["compress_rate"] = data[compress_rate]
    config["force_tokens"] = data[force_tokens]
    config["eval_model"] = data[eval_model]
    config["thread_id"] = "3"
    try:
        neograph = Neo4jGraph(url=data[srv_addr], username=data[srv_usr], password=data[srv_pwd])
        config["graph"] = neograph
    except Exception as e:
        raise gr.Error(f"Error when configuring the Neo4j server : {e}", duration=5)
    gr.Info("Successfully updated configuration !", duration=5)
    return {"configurable": config}


with gr.Blocks() as demo:
    with gr.Tab("Config"):
        ### The config tab
        gr.Markdown("## Config options setup")
        gr.Markdown("### API Keys")
        with gr.Row():
            openai_api_key = gr.Textbox(
                label="OpenAI API Key",
                type="password"
            )
            groq_api_key = gr.Textbox(
                label="Groq API Key",
                type="password"
            )
            langsmith_api_key = gr.Textbox(
                label="LangSmith API Key",
                type="password"
            )

        gr.Markdown("### Planner options")
        with gr.Row():
            main_llm = gr.Dropdown(
                choices=["gpt-4o", "claude-3-5-sonnet", "mixtral-8x7b-32768"],
                label="Main LLM",
                info="Choose the LLM which will perform the generation",
                value="gpt-4o"
            )
            with gr.Column(scale=1, min_width=600):
                plan_method = gr.Dropdown(
                    choices=["generation", "modification"],
                    label="Planning method",
                    info="Choose how the main LLM will generate its plan",
                    value="modification"
                )
                use_detailed_query = gr.Checkbox(
                    label="Detail each plan step",
                    info="Detail each plan step before passing it for data query"
                )

        gr.Markdown("### Data query options")
        # The options for the data processor
        # TODO: remove the options for summarize and compress and let the user
        # choose them when specifying processing steps (similarly to what is
        # done for custom processing steps)
        with gr.Row():
            with gr.Column(scale=1, min_width=300):
                # Neo4j server parameters
                srv_addr = gr.Textbox(
                    label="Neo4j server address",
                    placeholder="localhost:7687"
                )
                srv_usr = gr.Textbox(
                    label="Neo4j username",
                    placeholder="neo4j"
                )
                srv_pwd = gr.Textbox(
                    label="Neo4j password",
                    placeholder=""
                )
            with gr.Column(scale=1, min_width=300):
                cypher_gen_method = gr.Dropdown(
                    choices=["auto", "guided"],
                    label="Cypher generation method",
                )
                validate_cypher = gr.Checkbox(
                    label="Validate cypher using graph schema"
                )
                summarization_model = gr.Dropdown(
                    choices=["gpt-4o", "claude-3-5-sonnet", "mixtral-8x7b-32768", "llama3-70b-8192"],
                    label="Summarization LLM",
                    info="Choose the LLM which will perform the summaries"
                )
            with gr.Column(scale=1, min_width=300):
                eval_method = gr.Dropdown(
                    choices=["binary", "score"],
                    label="Retrieved docs evaluation method",
                    info="Evaluation method of retrieved docs"
                )
                eval_model = gr.Dropdown(
                    choices=["gpt-4o", "mixtral-8x7b-32768"],
                    label="Evaluation model",
                    info="The LLM to use to evaluate the relevance of retrieved docs",
                    value="mixtral-8x7b-32768"
                )
                eval_threshold = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.7,
                    label="Eval threshold",
                    info="Score above which a doc is considered relevant",
                    step=0.01,
                    visible=False
                )

                def eval_method_changed(selection):
                    # The threshold slider is only relevant for the "score" method
                    return gr.update(visible=(selection == "score"))

                eval_method.change(eval_method_changed, inputs=eval_method, outputs=eval_threshold)

                max_docs = gr.Slider(
                    minimum=0,
                    maximum=30,
                    value=15,
                    label="Max docs",
                    info="Maximum number of docs to be retrieved at each query",
                    step=1
                )
            with gr.Column(scale=1, min_width=300):
                compression_method = gr.Dropdown(
                    choices=["llm_lingua2", "llm_lingua"],
                    label="Compression method",
                    value="llm_lingua2"
                )
                with gr.Row():
                    compress_rate = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=0.33,
                        label="Compression rate",
                        info="Compression rate",
                        step=0.01
                    )
                    force_tokens = gr.CheckboxGroup(
                        choices=["\n", "?", ".", "!", ","],
                        value=[],
                        label="Force tokens",
                        info="Tokens to keep during compression",
                    )

        with gr.Row():
            btn_update_config = gr.Button(value="Update config")
            load_config_json = gr.Button(value="Load config from JSON")
        with gr.Row():
            debug_info = gr.Button(value="Print debug info")

        config_state = gr.State(value={})

        btn_update_config.click(
            build_config,
            inputs={main_llm, plan_method, use_detailed_query, srv_addr, srv_pwd, srv_usr, compression_method, eval_model,
                    compress_rate, force_tokens, cypher_gen_method, validate_cypher, summarization_model, eval_method, eval_threshold, max_docs},
            outputs=config_state
        )
        load_config_json.click(
            update_config_display,
            outputs={main_llm, plan_method, use_detailed_query, cypher_gen_method, validate_cypher, summarization_model, eval_method, eval_threshold,
                     max_docs, compress_rate, compression_method, force_tokens, eval_model, srv_addr, srv_usr, srv_pwd, openai_api_key, langsmith_api_key, groq_api_key}
        ).then(
            build_config,
            inputs={main_llm, plan_method, use_detailed_query, srv_addr, srv_pwd, srv_usr, compression_method, eval_model,
                    compress_rate, force_tokens, cypher_gen_method, validate_cypher, summarization_model, eval_method, eval_threshold, max_docs},
            outputs=config_state
        )

        # Print the config variable in the terminal
        debug_info.click(lambda x: print(x), inputs=config_state)

    with gr.Tab("Inference"):
        ### The inference tab
        graph_state = gr.State()
        user_query = gr.Textbox(label="Your query")
        launch_inference = gr.Button(value="Generate plan")
        with gr.Row():
            dataframe_plan = gr.Dataframe(visible=False)
            plan_display = gr.HTML(visible=False, label="Generated plan")

        with gr.Column():
            # Lets the user modify steps of the plan. Underlying logic not implemented yet
            # TODO: implement this
            with gr.Row():
                select_step_to_modify = gr.Number(visible=False, label="Select a plan step to modify", value=0)
                submit_new_step = gr.Button(visible=False, value="Submit new step")
            enter_new_step = gr.Textbox(visible=False, label="Modify the plan step")

        with gr.Row():
            human_validation_title = gr.Markdown(visible=False)
            retrieve_more_docs_btn = gr.Button(value="Retrieve more docs", visible=False)
            continue_to_processing_btn = gr.Button(value="Proceed to data processing", visible=False)

        with gr.Row():
            with gr.Column():
                process_steps_title = gr.Markdown("#### Data processing steps", visible=False)
                process_steps_nb = gr.Number(label="Number of processing steps", value=0, precision=0, step=1, visible=False)

                def get_process_step_names():
                    return ["summarize", "compress", "custom"]

                # The gr.render decorator reruns the code inside the following function
                # every time the 'inputs' variable is modified.
                # /!\ All event listeners that use variables defined inside a gr.render
                # function must be defined inside that same function.
                # ref : https://www.gradio.app/docs/gradio/render
                @gr.render(inputs=process_steps_nb)
                def processing(nb):
                    with gr.Row():
                        process_step_names = get_process_step_names()
                        dropdowns = []
                        textboxes = []
                        usable_elements = []
                        processing_models = []
                        for i in range(nb):
                            with gr.Column():
                                dropdown = gr.Dropdown(
                                    key=f"d{i}",
                                    choices=process_step_names,
                                    label=f"Data processing step {i+1}"
                                )
                                dropdowns.append(dropdown)
                                textbox = gr.Textbox(
                                    key=f"t{i}",
                                    value="",
                                    placeholder="Your custom prompt",
                                    visible=True,
                                    min_width=300
                                )
                                textboxes.append(textbox)
                                usable_element = gr.Dropdown(
                                    key=f"u{i}",
                                    choices=[j for j in range(i + 1)],
                                    label="Elements passed to the LLM for this process step",
                                    multiselect=True,
                                )
                                usable_elements.append(usable_element)
                                processing_model = gr.Dropdown(
                                    key=f"m{i}",
                                    label="The LLM that will execute this step",
                                    visible=True,
                                    choices=["gpt-4o", "mixtral-8x7b-32768", "llama3-70b-8192"]
                                )
                                processing_models.append(processing_model)
                                # The prompt, context, and model fields are only relevant for "custom" steps
                                dropdown.change(
                                    fn=lambda process_name: [
                                        gr.update(visible=(process_name == "custom")),
                                        gr.update(visible=(process_name == "custom")),
                                        gr.update(visible=(process_name == "custom")),
                                    ],
                                    inputs=dropdown,
                                    outputs=[textbox, usable_element, processing_model]
                                )

                    process_data_btn.click(
                        execute_processing,
                        inputs=dropdowns + textboxes + usable_elements + processing_models + [process_steps_nb, graph_state, config_state],
                        outputs={plan_result, processed_docs_state, graph_state, continue_inference_btn, process_steps_nb, process_steps_title, process_data_btn}
                    )
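
                # Note on the wiring above: execute_processing receives the
                # dynamically rendered components as one flat positional list,
                # laid out as [dropdowns..., textboxes..., usable_elements...,
                # processing_models..., process_steps_nb, graph_state,
                # config_state], which is why it indexes args from the end.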
                process_data_btn = gr.Button(value="Process retrieved docs", visible=False)
                continue_inference_btn = gr.Button(value="Proceed to next plan step", visible=False)

        plan_result = gr.Markdown(visible=True, label="Result of last plan step")

    with gr.Tab("Retrieved Docs"):
        retrieved_docs_state = gr.State([])
        with gr.Row():
            gr.Markdown("# Retrieved Docs")
            retrieved_docs_btn = gr.Button("Display retrieved docs")
        retrieved_docs_display = gr.Markdown()

        processed_docs_state = gr.State([])
        with gr.Row():
            gr.Markdown("# Processed Docs")
            processed_docs_btn = gr.Button("Display processed docs")
        processed_docs_display = gr.Markdown()

    continue_inference_btn.click(
        continue_inference,
        inputs={graph_state, config_state},
        outputs={continue_inference_btn, graph_state, retrieve_more_docs_btn, continue_to_processing_btn, human_validation_title, plan_result, retrieved_docs_state}
    )

    launch_inference.click(
        start_inference,
        inputs={config_state, user_query, openai_api_key, groq_api_key, langsmith_api_key},
        outputs=[dataframe_plan, graph_state]
    ).then(
        update_display,
        inputs=dataframe_plan,
        outputs={plan_display, select_step_to_modify, enter_new_step, submit_new_step, continue_inference_btn}
    )

    retrieve_more_docs_btn.click(
        retrieve_more_docs,
        inputs={graph_state, config_state},
        outputs={graph_state, human_validation_title, retrieved_docs_display}
    )

    continue_to_processing_btn.click(
        continue_to_processing,
        outputs={retrieve_more_docs_btn, continue_to_processing_btn, human_validation_title, process_data_btn, process_steps_nb, process_steps_title}
    )

    retrieved_docs_btn.click(
        fn=lambda docs: format_docs(docs),
        inputs=retrieved_docs_state,
        outputs=retrieved_docs_display
    )

    processed_docs_btn.click(
        fn=lambda docs: format_docs(docs),
        inputs=processed_docs_state,
        outputs=processed_docs_display
    )

    test_process_steps = gr.Button(value="Test process steps")
    test_process_steps.click(
        lambda: [gr.update(visible=True), gr.update(visible=True)],
        outputs=[process_steps_nb, process_steps_title]
    )

demo.launch()
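
# Note: demo.launch() blocks until the server is stopped. To expose the app
# publicly or pin the port, Gradio's launch() also accepts, for example:
#   demo.launch(share=True, server_port=7860)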