from time import time
from pprint import pprint
from typing import Dict, Literal

import huggingface_hub
import streamlit as st
from typing_extensions import TypedDict

from langchain.agents import AgentExecutor
from langchain_community.chat_models import ChatOllama
from langgraph.graph import StateGraph

from logger import logger
from config import config
from agents import get_agents, tools_dict


class GraphState(TypedDict):
    """Represents the state of the graph."""

    question: str
    rephrased_question: str
    function_agent_output: str
    generation: str


@st.cache_resource(show_spinner="Loading model..")
def init_agents() -> Dict[str, AgentExecutor]:
    """Log in to Hugging Face and build the agents once, cached by Streamlit."""
    huggingface_hub.login(token=config.hf_token, new_session=False)
    llm = ChatOllama(model=config.ollama_model, temperature=0.8)
    return get_agents(llm)


# Nodes -----------------------------------------------------------------------
def question_node(state: GraphState) -> Dict[str, str]:
    """Rephrase the user's question for the function agent."""
    logger.info("Generating question for function agent")
    # config.status.update(label=":question: Breaking down question")
    question = state["question"]
    logger.info(f"Original question: {question}")
    rephrased_question = agents["rephrase_agent"].invoke({"question": question})
    logger.info(f"Rephrased question: {rephrased_question}")
    return {"rephrased_question": rephrased_question}


def function_agent_node(state: GraphState) -> Dict[str, str]:
    """Call the function agent on the rephrased question."""
    logger.info("Calling function agent")
    question = state["rephrased_question"]
    response = agents["function_agent"].invoke(
        {"input": question, "tools": tools_dict}
    ).get("output")
    # config.status.update(label=":brain: Analysing data..")
    logger.info(f"Function agent output: {response}")
    return {"function_agent_output": response}


def output_node(state: GraphState) -> Dict[str, str]:
    """Generate the final answer from the function agent's output."""
    logger.info("Generating output")
    # config.status.update(label=":bulb: Preparing response..")
    generation = agents["output_agent"].invoke(
        {
            "context": state["function_agent_output"],
            "question": state["rephrased_question"],
        }
    )
    return {"generation": generation}


# Conditional Edge ------------------------------------------------------------
def route_question(state: GraphState) -> Literal["vectorstore", "websearch"]:
    """Route the question to web search or RAG."""
    logger.info("Routing question")
    # config.status.update(label=":chart_with_upwards_trend: Routing question")
    question = state["question"]
    logger.info(f"Question: {question}")
    source = agents["router_agent"].invoke({"question": question})
    logger.info(source)
    logger.info(source["datasource"])
    if source["datasource"] == "vectorstore":
        return "vectorstore"
    elif source["datasource"] == "websearch":
        return "websearch"


# Graph -----------------------------------------------------------------------
workflow = StateGraph(GraphState)
workflow.add_node("question_rephrase", question_node)
workflow.add_node("function_agent", function_agent_node)
workflow.add_node("output_node", output_node)

workflow.set_entry_point("question_rephrase")
workflow.add_edge("question_rephrase", "function_agent")
workflow.add_edge("function_agent", "output_node")
workflow.set_finish_point("output_node")
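# NOTE: `route_question` above is never wired into the graph, so every question
# flows straight through question_rephrase -> function_agent -> output_node.
# A sketch of how it could be attached with LangGraph's `add_conditional_edges`
# (the "web_search_node" target is hypothetical; no such node exists here):
#
#   workflow.add_conditional_edges(
#       "question_rephrase",
#       route_question,
#       {"vectorstore": "function_agent", "websearch": "web_search_node"},
#   )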
flow = workflow.compile()

progress_map = {
    "question_rephrase": ":mag: Collecting data",
    "function_agent": ":bulb: Preparing response",
    "output_node": ":bulb: Done!",
}


def main():
    st.title("LLM-ADE 9B Demo")
    input_text = st.text_area("Enter your text here:", value="", height=200)

    def get_response(input_text: str) -> str:
        try:
            # Stream the graph; each yielded output maps a node name to its state update.
            for output in flow.stream({"question": input_text}):
                for key, value in output.items():
                    config.status.update(label=progress_map[key])
                    pprint(f"Finished running: {key}")
            return value["generation"]
        except Exception as e:
            logger.error(e)
            logger.info("Retrying..")
            return get_response(input_text)

    if st.button("Generate"):
        if input_text:
            with st.status("Generating response...") as status:
                config.status = status
                config.status.update(label=":question: Breaking down question")
                response = get_response(input_text)
                st.write(response)
                config.status.update(label="Finished!", state="complete", expanded=True)
        else:
            st.warning("Please enter some text to generate a response.")


def main_headless(prompt: str):
    start = time()
    for output in flow.stream({"question": prompt}):
        for key, value in output.items():
            pprint(f"Finished running: {key}")
    print("\033[94m" + value["generation"] + "\033[0m")
    print(f"Time taken: {time() - start:.2f}s\n" + "-" * 20)


agents = init_agents()

if __name__ == "__main__":
    if config.headless:
        import fire

        fire.Fire(main_headless)
    else:
        main()
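# Usage sketch (the filename is an assumption):
#   streamlit run app.py            # interactive UI (config.headless falsy)
#   python app.py "your question"   # headless mode, when config.headless is set;
#                                   # fire.Fire forwards the CLI argument to `prompt`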