Create app.py
app.py
ADDED
@@ -0,0 +1,102 @@
from typing import TypedDict, Annotated, Sequence, Literal
from functools import lru_cache

from langchain_core.messages import BaseMessage
from langchain_anthropic import ChatAnthropic
from langchain_openai import ChatOpenAI
from langchain_community.tools.tavily_search import TavilySearchResults
from langgraph.prebuilt import ToolNode
from langgraph.graph import StateGraph, END, add_messages

tools = [TavilySearchResults(max_results=1)]
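# NOTE: TavilySearchResults above reads TAVILY_API_KEY from the environment,
# and the chat models below read OPENAI_API_KEY / ANTHROPIC_API_KEY; these
# must be set (e.g. as Space secrets) for the app to run.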

@lru_cache(maxsize=4)
def _get_model(model_name: str):
    if model_name == "openai":
        model = ChatOpenAI(temperature=0, model_name="gpt-4o")
    elif model_name == "anthropic":
        model = ChatAnthropic(temperature=0, model_name="claude-3-sonnet-20240229")
    else:
        raise ValueError(f"Unsupported model type: {model_name}")

    model = model.bind_tools(tools)
    return model
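# Thanks to @lru_cache, repeated requests for the same model name reuse the
# same tool-bound client rather than re-instantiating it, e.g.
#   _get_model("anthropic") is _get_model("anthropic")  # -> True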

class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], add_messages]

# Define the function that determines whether to continue or not
def should_continue(state):
    messages = state["messages"]
    last_message = messages[-1]
    # If there are no tool calls, then we finish
    if not last_message.tool_calls:
        return "end"
    # Otherwise, if there are tool calls, we continue
    else:
        return "continue"
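# For example, an AIMessage whose tool_calls list is non-empty routes to
# "continue" (the tool node runs next), while a plain-text AIMessage with no
# tool calls routes to "end" and the run finishes.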

system_prompt = """Be a helpful assistant"""


# Define the function that calls the model
def call_model(state, config):
    messages = state["messages"]
    messages = [{"role": "system", "content": system_prompt}] + messages
    model_name = config.get("configurable", {}).get("model_name", "anthropic")
    model = _get_model(model_name)
    response = model.invoke(messages)
    # We return a list, because this will get added to the existing list
    return {"messages": [response]}
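# The model is chosen per invocation from config["configurable"], falling back
# to "anthropic" when none is given, so callers can switch providers at
# runtime, e.g.
#   graph.invoke(inputs, config={"configurable": {"model_name": "openai"}})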

# Define the node that executes tools
tool_node = ToolNode(tools)

# Define the config
class GraphConfig(TypedDict):
    model_name: Literal["anthropic", "openai"]

# Define a new graph
workflow = StateGraph(AgentState, config_schema=GraphConfig)

# Define the two nodes we will cycle between
workflow.add_node("agent", call_model)
workflow.add_node("action", tool_node)

# Set the entrypoint as `agent`
# This means that this node is the first one called
workflow.set_entry_point("agent")

# We now add a conditional edge
workflow.add_conditional_edges(
    # First, we define the start node. We use `agent`.
    # This means these are the edges taken after the `agent` node is called.
    "agent",
    # Next, we pass in the function that will determine which node is called next.
    should_continue,
    # Finally we pass in a mapping.
    # The keys are strings, and the values are other nodes.
    # END is a special node marking that the graph should finish.
    # What will happen is we will call `should_continue`, and then the output of that
    # will be matched against the keys in this mapping.
    # Based on which one it matches, that node will then be called.
    {
        # If `continue`, then we call the tool node.
        "continue": "action",
        # Otherwise we finish.
        "end": END,
    },
)

# We now add a normal edge from `action` to `agent`.
# This means that after `action` is called, the `agent` node is called next.
workflow.add_edge("action", "agent")

# Finally, we compile it!
# This compiles it into a LangChain Runnable,
# meaning you can use it as you would any other runnable
graph = workflow.compile()
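
# A minimal smoke test of the compiled graph. This block is illustrative only:
# it assumes the API keys above are set, and the question is just an example.
if __name__ == "__main__":
    from langchain_core.messages import HumanMessage

    result = graph.invoke(
        {"messages": [HumanMessage(content="What is LangGraph?")]},
        config={"configurable": {"model_name": "anthropic"}},
    )
    print(result["messages"][-1].content)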