from time import time
from pprint import pprint
import huggingface_hub
import streamlit as st
from typing import Literal, Dict
from typing_extensions import TypedDict
import langchain
from langgraph.graph import END, StateGraph
from langchain_community.chat_models import ChatOllama
from logger import logger
from config import config
from agents import get_agents, tools_dict
class GraphState(TypedDict):
    """Represents the state shared across the graph's nodes."""

    question: str                # Original user question
    rephrased_question: str      # Question rewritten for the function agent
    function_agent_output: str   # Raw output from the tool-calling agent
    generation: str              # Final answer shown to the user


@st.cache_resource(show_spinner="Loading model..")
def init_agents() -> dict[str, langchain.agents.AgentExecutor]:
    """Log in to Hugging Face and build the agent executors (cached across Streamlit reruns)."""
    huggingface_hub.login(token=config.hf_token, new_session=False)
    llm = ChatOllama(model=config.ollama_model, temperature=0.8)
    return get_agents(llm)
# Nodes -----------------------------------------------------------------------
def question_node(state: GraphState) -> Dict[str, str]:
    """
    Rephrase the user question for the function agent.
    """
    logger.info("Generating question for function agent")
    # config.status.update(label=":question: Breaking down question")
    question = state["question"]
    logger.info(f"Original question: {question}")
    rephrased_question = agents["rephrase_agent"].invoke({"question": question})
    logger.info(f"Rephrased question: {rephrased_question}")
    return {"rephrased_question": rephrased_question}
def function_agent_node(state: GraphState) -> Dict[str, str]:
    """
    Call the function agent on the rephrased question.
    """
    logger.info("Calling function agent")
    question = state["rephrased_question"]
    response = agents["function_agent"].invoke({"input": question, "tools": tools_dict}).get("output")
    # config.status.update(label=":brain: Analysing data..")
    logger.info(f"Function agent output: {response}")
    return {"function_agent_output": response}
def output_node(state: GraphState) -> Dict[str, str]:
    """
    Generate the final output from the function agent's context.
    """
    logger.info("Generating output")
    # config.status.update(label=":bulb: Preparing response..")
    generation = agents["output_agent"].invoke({"context": state["function_agent_output"],
                                                "question": state["rephrased_question"]})
    return {"generation": generation}
# Conditional Edge ------------------------------------------------------------
def route_question(state: GraphState) -> Literal["vectorstore", "websearch"]:
    """
    Route the question to web search or RAG over the vectorstore.
    """
    logger.info("Routing question")
    # config.status.update(label=":chart_with_upwards_trend: Routing question")
    question = state["question"]
    logger.info(f"Question: {question}")
    source = agents["router_agent"].invoke({"question": question})
    logger.info(source)
    logger.info(source["datasource"])
    if source["datasource"] == "vectorstore":
        return "vectorstore"
    # Fall back to web search for any other routing decision.
    return "websearch"
# Graph -----------------------------------------------------------------------
workflow = StateGraph(GraphState)
workflow.add_node("question_rephrase", question_node)
workflow.add_node("function_agent", function_agent_node)
workflow.add_node("output_node", output_node)
workflow.set_entry_point("question_rephrase")
workflow.add_edge("question_rephrase", "function_agent")
workflow.add_edge("function_agent", "output_node")
workflow.set_finish_point("output_node")
flow = workflow.compile()
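# Optional sanity check (assumes a langgraph version exposing get_graph()):
#     print(flow.get_graph().draw_mermaid())  # Renders the compiled topology as a Mermaid diagram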
progress_map = {
    "question_rephrase": ":mag: Collecting data",
    "function_agent": ":bulb: Preparing response",
    "output_node": ":bulb: Done!",
}
def main():
    st.title("LLM-ADE 9B Demo")
    input_text = st.text_area("Enter your text here:", value="", height=200)

    def get_response(input_text: str, depth: int = 1) -> str:
        try:
            for output in flow.stream({"question": input_text}):
                for key, value in output.items():
                    config.status.update(label=progress_map[key])
                    pprint(f"Finished running: {key}")
            # After the stream ends, "value" holds the final node's output.
            return value["generation"]
        except Exception as e:
            logger.error(e)
            if depth < 5:
                logger.info("Retrying..")
                return get_response(input_text, depth + 1)
            raise  # Give up after five attempts instead of silently returning None.

    if st.button("Generate") or input_text:
        start = time()
        if input_text:
            with st.status("Generating response...") as status:
                config.status = status
                config.status.update(label=":question: Breaking down question")
                response = get_response(input_text)
                response = response.replace("$", "\\$")  # Escape $ so Streamlit doesn't render it as LaTeX
                st.info(response)
                config.status.update(label=f"Finished! ({time() - start:.2f}s)", state="complete", expanded=True)
        else:
            st.warning("Please enter some text to generate a response.")
def main_headless(prompt: str):
    """Run the graph once from the command line, without the Streamlit UI."""
    start = time()
    for output in flow.stream({"question": prompt}):
        for key, value in output.items():
            pprint(f"Finished running: {key}")
    print("\033[94m" + value["generation"] + "\033[0m")
    print(f"Time taken: {time() - start:.2f}s\n" + "-" * 20)
# Build the agents once at import time; st.cache_resource avoids reloading the model on reruns.
agents = init_agents()
if __name__ == "__main__":
    if config.headless:
        import fire
        fire.Fire(main_headless)
    else:
        main()