abderrahimbrighal commited on
Commit
9ecfc26
·
verified ·
1 Parent(s): 228261b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -0
app.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import TypedDict, Annotated, Sequence, Literal
2
+
3
+ from functools import lru_cache
4
+ from langchain_core.messages import BaseMessage
5
+ from langchain_anthropic import ChatAnthropic
6
+ from langchain_openai import ChatOpenAI
7
+ from langchain_community.tools.tavily_search import TavilySearchResults
8
+ from langgraph.prebuilt import ToolNode
9
+ from langgraph.graph import StateGraph, END, add_messages
10
+
11
+ tools = [TavilySearchResults(max_results=1)]
12
+
13
+ @lru_cache(maxsize=4)
14
+ def _get_model(model_name: str):
15
+ if model_name == "openai":
16
+ model = ChatOpenAI(temperature=0, model_name="gpt-4o")
17
+ elif model_name == "anthropic":
18
+ model = ChatAnthropic(temperature=0, model_name="claude-3-sonnet-20240229")
19
+ else:
20
+ raise ValueError(f"Unsupported model type: {model_name}")
21
+
22
+ model = model.bind_tools(tools)
23
+ return model
24
+
25
+
26
+ class AgentState(TypedDict):
27
+ messages: Annotated[Sequence[BaseMessage], add_messages]
28
+
29
+
30
+ # Define the function that determines whether to continue or not
31
+ def should_continue(state):
32
+ messages = state["messages"]
33
+ last_message = messages[-1]
34
+ # If there are no tool calls, then we finish
35
+ if not last_message.tool_calls:
36
+ return "end"
37
+ # Otherwise if there is, we continue
38
+ else:
39
+ return "continue"
40
+
41
+
42
+ system_prompt = """Be a helpful assistant"""
43
+
44
+ # Define the function that calls the model
45
+ def call_model(state, config):
46
+ messages = state["messages"]
47
+ messages = [{"role": "system", "content": system_prompt}] + messages
48
+ model_name = config.get('configurable', {}).get("model_name", "anthropic")
49
+ model = _get_model(model_name)
50
+ response = model.invoke(messages)
51
+ # We return a list, because this will get added to the existing list
52
+ return {"messages": [response]}
53
+
54
+
55
+ # Define the function to execute tools
56
+ tool_node = ToolNode(tools)
57
+
58
+ # Define the config
59
+ class GraphConfig(TypedDict):
60
+ model_name: Literal["anthropic", "openai"]
61
+
62
+
63
+ # Define a new graph
64
+ workflow = StateGraph(AgentState, config_schema=GraphConfig)
65
+
66
+ # Define the two nodes we will cycle between
67
+ workflow.add_node("agent", call_model)
68
+ workflow.add_node("action", tool_node)
69
+
70
+ # Set the entrypoint as `agent`
71
+ # This means that this node is the first one called
72
+ workflow.set_entry_point("agent")
73
+
74
+ # We now add a conditional edge
75
+ workflow.add_conditional_edges(
76
+ # First, we define the start node. We use `agent`.
77
+ # This means these are the edges taken after the `agent` node is called.
78
+ "agent",
79
+ # Next, we pass in the function that will determine which node is called next.
80
+ should_continue,
81
+ # Finally we pass in a mapping.
82
+ # The keys are strings, and the values are other nodes.
83
+ # END is a special node marking that the graph should finish.
84
+ # What will happen is we will call `should_continue`, and then the output of that
85
+ # will be matched against the keys in this mapping.
86
+ # Based on which one it matches, that node will then be called.
87
+ {
88
+ # If `tools`, then we call the tool node.
89
+ "continue": "action",
90
+ # Otherwise we finish.
91
+ "end": END,
92
+ },
93
+ )
94
+
95
+ # We now add a normal edge from `tools` to `agent`.
96
+ # This means that after `tools` is called, `agent` node is called next.
97
+ workflow.add_edge("action", "agent")
98
+
99
+ # Finally, we compile it!
100
+ # This compiles it into a LangChain Runnable,
101
+ # meaning you can use it as you would any other runnable
102
+ graph = workflow.compile()