# LLM-ADE-dev / src / app.py
from time import time
from pprint import pprint
from typing import Dict, Literal

import huggingface_hub
import langchain
import streamlit as st
from typing_extensions import TypedDict
from langgraph.graph import END, StateGraph
from langchain_community.chat_models import ChatOllama

from logger import logger
from config import config
from agents import get_agents, tools_dict

class GraphState(TypedDict):
    """Represents the state of the graph."""

    question: str               # raw user question
    rephrased_question: str     # question reworded for the function agent
    function_agent_output: str  # data gathered by the function agent
    generation: str             # final generated response
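
# Each node below returns a partial state update (a dict containing a subset
# of the GraphState keys); LangGraph merges these updates into the running
# state as the graph executes.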
@st.cache_resource(show_spinner="Loading model..")
def init_agents() -> dict[str, langchain.agents.AgentExecutor]:
    huggingface_hub.login(token=config.hf_token, new_session=False)
    llm = ChatOllama(model=config.ollama_model, temperature=0.8)
    return get_agents(llm)
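
# st.cache_resource caches the agents for the lifetime of the Streamlit
# server process, so the model is loaded once rather than on every rerun.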
# Nodes -----------------------------------------------------------------------
def question_node(state: GraphState) -> Dict[str, str]:
    """Rephrase the user's question for the function agent."""
    logger.info("Generating question for function agent")
    # config.status.update(label=":question: Breaking down question")
    question = state["question"]
    logger.info(f"Original question: {question}")
    rephrased_question = agents["rephrase_agent"].invoke({"question": question})
    logger.info(f"Rephrased question: {rephrased_question}")
    return {"rephrased_question": rephrased_question}
def function_agent_node(state: GraphState) -> Dict[str, str]:
    """Call the function agent on the rephrased question."""
    logger.info("Calling function agent")
    question = state["rephrased_question"]
    response = agents["function_agent"].invoke(
        {"input": question, "tools": tools_dict}
    ).get("output")
    # config.status.update(label=":brain: Analysing data..")
    logger.info(f"Function agent output: {response}")
    return {"function_agent_output": response}
def output_node(state: GraphState) -> Dict[str, str]:
    """Generate the final output."""
    logger.info("Generating output")
    # config.status.update(label=":bulb: Preparing response..")
    generation = agents["output_agent"].invoke(
        {
            "context": state["function_agent_output"],
            "question": state["rephrased_question"],
        }
    )
    return {"generation": generation}
# Conditional Edge ------------------------------------------------------------
def route_question(state: GraphState) -> Literal["vectorstore", "websearch"]:
    """Route the question to web search or RAG."""
    logger.info("Routing question")
    # config.status.update(label=":chart_with_upwards_trend: Routing question")
    question = state["question"]
    logger.info(f"Question: {question}")
    source = agents["router_agent"].invoke({"question": question})
    logger.info(source)
    logger.info(source["datasource"])
    if source["datasource"] == "vectorstore":
        return "vectorstore"
    elif source["datasource"] == "websearch":
        return "websearch"
    # Fail loudly instead of silently returning None on an unknown route.
    raise ValueError(f"Unexpected datasource: {source['datasource']}")
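
# NOTE: route_question is not currently wired into the graph below. A minimal
# sketch of how it could be attached with LangGraph's conditional edges,
# assuming hypothetical "vectorstore" and "websearch" nodes were added:
#
#     workflow.add_conditional_edges(
#         "question_rephrase",
#         route_question,
#         {"vectorstore": "vectorstore", "websearch": "websearch"},
#     )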
# Graph -----------------------------------------------------------------------
workflow = StateGraph(GraphState)
workflow.add_node("question_rephrase", question_node)
workflow.add_node("function_agent", function_agent_node)
workflow.add_node("output_node", output_node)
workflow.set_entry_point("question_rephrase")
workflow.add_edge("question_rephrase", "function_agent")
workflow.add_edge("function_agent", "output_node")
workflow.set_finish_point("output_node")
flow = workflow.compile()
progress_map = {
"question_rephrase": ":mag: Collecting data",
"function_agent": ":bulb: Preparing response",
"output_node": ":bulb: Done!",
}
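
# flow.stream yields one {node_name: state_update} mapping per executed node,
# so the keys in progress_map must match the node names registered above.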
def main():
st.title("LLM-ADE 9B Demo")
input_text = st.text_area("Enter your text here:", value="", height=200)
    def get_response(input_text: str, retries: int = 3) -> str:
        try:
            value = None
            for output in flow.stream({"question": input_text}):
                for key, value in output.items():
                    config.status.update(label=progress_map[key])
                    pprint(f"Finished running: {key}")
            # The last streamed update comes from "output_node" and carries
            # the final generation.
            return value["generation"]
        except Exception as e:
            logger.error(e)
            if retries <= 0:
                raise
            logger.info("Retrying..")
            # Bound the retries so a persistent failure cannot recurse forever.
            return get_response(input_text, retries - 1)
if st.button("Generate"):
if input_text:
with st.status("Generating response...") as status:
config.status = status
config.status.update(label=":question: Breaking down question")
response = get_response(input_text)
st.write(response)
config.status.update(label="Finished!", state="complete", expanded=True)
else:
st.warning("Please enter some text to generate a response.")
def main_headless(prompt: str):
    start = time()
    value = None
    for output in flow.stream({"question": prompt}):
        for key, value in output.items():
            pprint(f"Finished running: {key}")
    # Print the final generation in blue, followed by the elapsed time.
    print("\033[94m" + value["generation"] + "\033[0m")
    print(f"Time taken: {time() - start:.2f}s\n" + "-" * 20)
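
# Headless usage via python-fire, e.g. (assuming config.headless is set):
#     python app.py "What was AAPL's revenue last quarter?"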
agents = init_agents()
if __name__ == "__main__":
if config.headless:
import fire
fire.Fire(main_headless)
else:
main()