Spaces:
Sleeping
Sleeping
Merge pull request #1 from ric9176/memory-store
Browse filesAdd long term memory store, refactor chainlit interface, re-structure project
- .gitignore +3 -1
- .langgraph_api/.langgraph_checkpoint.1.pckl +0 -0
- .langgraph_api/.langgraph_checkpoint.2.pckl +0 -0
- .langgraph_api/.langgraph_ops.pckl +0 -0
- .langgraph_api/.langgraph_retry_counter.pckl +0 -0
- .langgraph_api/store.pckl +0 -0
- .langgraph_api/store.vectors.pckl +0 -0
- agent/__init__.py +2 -2
- agent/agent.py +0 -50
- agent/graph.py +96 -0
- agent/utils/nodes.py +187 -23
- agent/utils/state.py +1 -1
- app.py +78 -87
- docker-compose.yml +23 -0
- langgraph.json +1 -1
- studio_example_message.json +4 -0
- tools.py +0 -28
.gitignore
CHANGED
@@ -1,4 +1,6 @@
|
|
1 |
__pycache__/
|
2 |
.chainlit/
|
3 |
.venv/
|
4 |
-
.env
|
|
|
|
|
|
1 |
__pycache__/
|
2 |
.chainlit/
|
3 |
.venv/
|
4 |
+
.env
|
5 |
+
.langgraph_api/
|
6 |
+
data/
|
.langgraph_api/.langgraph_checkpoint.1.pckl
DELETED
Binary file (25.2 kB)
|
|
.langgraph_api/.langgraph_checkpoint.2.pckl
DELETED
Binary file (6.73 kB)
|
|
.langgraph_api/.langgraph_ops.pckl
DELETED
Binary file (7.94 kB)
|
|
.langgraph_api/.langgraph_retry_counter.pckl
DELETED
Binary file (83 Bytes)
|
|
.langgraph_api/store.pckl
DELETED
Binary file (6 Bytes)
|
|
.langgraph_api/store.vectors.pckl
DELETED
Binary file (6 Bytes)
|
|
agent/__init__.py
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
-
from agent.
|
2 |
|
3 |
-
__all__ = ["
|
|
|
1 |
+
from agent.graph import create_agent_graph, create_agent_graph_without_memory, get_checkpointer, langgraph_studio_graph
|
2 |
|
3 |
+
__all__ = ["create_agent_graph", "create_agent_graph_without_memory", "get_checkpointer", "langgraph_studio_graph"]
|
agent/agent.py
DELETED
@@ -1,50 +0,0 @@
|
|
1 |
-
from langgraph.graph import StateGraph
|
2 |
-
from langgraph.checkpoint.memory import MemorySaver
|
3 |
-
|
4 |
-
from agent.utils.state import AgentState
|
5 |
-
from agent.utils.nodes import call_model, tool_node, should_continue
|
6 |
-
|
7 |
-
def create_agent_graph():
|
8 |
-
# Create the graph
|
9 |
-
builder = StateGraph(AgentState)
|
10 |
-
|
11 |
-
# Add nodes
|
12 |
-
builder.add_node("agent", call_model)
|
13 |
-
builder.add_node("action", tool_node)
|
14 |
-
|
15 |
-
# Update edges
|
16 |
-
builder.set_entry_point("agent")
|
17 |
-
builder.add_conditional_edges(
|
18 |
-
"agent",
|
19 |
-
should_continue,
|
20 |
-
)
|
21 |
-
builder.add_edge("action", "agent")
|
22 |
-
|
23 |
-
# Initialize memory saver for conversation persistence
|
24 |
-
memory = MemorySaver()
|
25 |
-
|
26 |
-
# Compile the graph with memory
|
27 |
-
return builder.compile(checkpointer=memory)
|
28 |
-
|
29 |
-
def create_agent_graph_without_memory():
|
30 |
-
# Create the graph
|
31 |
-
builder = StateGraph(AgentState)
|
32 |
-
|
33 |
-
# Add nodes
|
34 |
-
builder.add_node("agent", call_model)
|
35 |
-
builder.add_node("action", tool_node)
|
36 |
-
|
37 |
-
# Update edges
|
38 |
-
builder.set_entry_point("agent")
|
39 |
-
builder.add_conditional_edges(
|
40 |
-
"agent",
|
41 |
-
should_continue,
|
42 |
-
)
|
43 |
-
builder.add_edge("action", "agent")
|
44 |
-
|
45 |
-
# Compile the graph without memory
|
46 |
-
return builder.compile()
|
47 |
-
|
48 |
-
# Create both graph variants
|
49 |
-
graph_with_memory = create_agent_graph()
|
50 |
-
graph = create_agent_graph_without_memory()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
agent/graph.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langgraph.graph import StateGraph, END
|
2 |
+
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
|
3 |
+
from langgraph.store.memory import InMemoryStore
|
4 |
+
import aiosqlite
|
5 |
+
from types import TracebackType
|
6 |
+
from typing import Optional, Type
|
7 |
+
|
8 |
+
from agent.utils.state import AgentState
|
9 |
+
from agent.utils.nodes import (
|
10 |
+
call_model,
|
11 |
+
tool_node,
|
12 |
+
write_memory,
|
13 |
+
should_continue
|
14 |
+
)
|
15 |
+
|
16 |
+
def create_graph_builder():
|
17 |
+
"""Create a base graph builder with nodes and edges configured."""
|
18 |
+
builder = StateGraph(AgentState)
|
19 |
+
|
20 |
+
|
21 |
+
# Add nodes
|
22 |
+
builder.add_node("agent", call_model)
|
23 |
+
builder.add_node("action", tool_node)
|
24 |
+
builder.add_node("write_memory", write_memory)
|
25 |
+
|
26 |
+
# Set entry point
|
27 |
+
builder.set_entry_point("agent")
|
28 |
+
|
29 |
+
|
30 |
+
# Add conditional edges from agent
|
31 |
+
builder.add_conditional_edges(
|
32 |
+
"agent",
|
33 |
+
should_continue,
|
34 |
+
{
|
35 |
+
"action": "action",
|
36 |
+
"write_memory": "write_memory",
|
37 |
+
}
|
38 |
+
)
|
39 |
+
|
40 |
+
# Connect action back to agent
|
41 |
+
builder.add_edge("action", "agent")
|
42 |
+
builder.add_edge("write_memory", END)
|
43 |
+
|
44 |
+
return builder
|
45 |
+
|
46 |
+
def create_agent_graph_without_memory():
|
47 |
+
"""Create an agent graph without memory persistence."""
|
48 |
+
builder = create_graph_builder()
|
49 |
+
return builder.compile()
|
50 |
+
|
51 |
+
class SQLiteCheckpointer:
|
52 |
+
"""Context manager for SQLite checkpointing."""
|
53 |
+
|
54 |
+
def __init__(self, db_path: str):
|
55 |
+
self.db_path = db_path
|
56 |
+
self.saver: Optional[AsyncSqliteSaver] = None
|
57 |
+
|
58 |
+
async def __aenter__(self) -> AsyncSqliteSaver:
|
59 |
+
"""Initialize and return the AsyncSqliteSaver."""
|
60 |
+
conn = await aiosqlite.connect(self.db_path)
|
61 |
+
self.saver = AsyncSqliteSaver(conn)
|
62 |
+
return self.saver
|
63 |
+
|
64 |
+
async def __aexit__(
|
65 |
+
self,
|
66 |
+
exc_type: Optional[Type[BaseException]],
|
67 |
+
exc_val: Optional[BaseException],
|
68 |
+
exc_tb: Optional[TracebackType],
|
69 |
+
) -> None:
|
70 |
+
"""Clean up the SQLite connection."""
|
71 |
+
if self.saver and hasattr(self.saver, 'conn'):
|
72 |
+
await self.saver.conn.close()
|
73 |
+
self.saver = None
|
74 |
+
|
75 |
+
def get_checkpointer(db_path: str = "data/short_term.db") -> SQLiteCheckpointer:
|
76 |
+
"""Create and return a SQLiteCheckpointer instance."""
|
77 |
+
return SQLiteCheckpointer(db_path)
|
78 |
+
|
79 |
+
# Initialize store for across-thread memory
|
80 |
+
across_thread_memory = InMemoryStore()
|
81 |
+
|
82 |
+
async def create_agent_graph(checkpointer: AsyncSqliteSaver):
|
83 |
+
"""Create an agent graph with memory persistence."""
|
84 |
+
builder = create_graph_builder()
|
85 |
+
# Compile with both SQLite checkpointer for within-thread memory
|
86 |
+
# and InMemoryStore for across-thread memory
|
87 |
+
graph = builder.compile(
|
88 |
+
checkpointer=checkpointer,
|
89 |
+
store=across_thread_memory
|
90 |
+
)
|
91 |
+
return graph
|
92 |
+
|
93 |
+
langgraph_studio_graph = create_agent_graph_without_memory()
|
94 |
+
|
95 |
+
# Export the graph builder functions
|
96 |
+
__all__ = ["create_agent_graph", "create_agent_graph_without_memory", "get_checkpointer"]
|
agent/utils/nodes.py
CHANGED
@@ -1,35 +1,199 @@
|
|
1 |
from langchain_openai import ChatOpenAI
|
2 |
-
from langchain_core.messages import SystemMessage
|
3 |
from langgraph.graph import END
|
4 |
from langgraph.prebuilt import ToolNode
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
from agent.utils.tools import tool_belt
|
7 |
from agent.utils.state import AgentState
|
8 |
|
9 |
-
# Initialize LLM
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
return {"messages": [response]}
|
26 |
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
32 |
last_message = state["messages"][-1]
|
33 |
-
if last_message
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
return "action"
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from langchain_openai import ChatOpenAI
|
2 |
+
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
|
3 |
from langgraph.graph import END
|
4 |
from langgraph.prebuilt import ToolNode
|
5 |
+
from langchain.memory import ConversationBufferMemory
|
6 |
+
from langchain_core.runnables.config import RunnableConfig
|
7 |
+
from langgraph.store.base import BaseStore
|
8 |
+
from typing import Literal
|
9 |
+
# from chainlit.logger import logger
|
10 |
+
|
11 |
|
12 |
from agent.utils.tools import tool_belt
|
13 |
from agent.utils.state import AgentState
|
14 |
|
15 |
+
# Initialize LLM for memory operations
|
16 |
+
model = ChatOpenAI(model="gpt-4", temperature=0)
|
17 |
+
|
18 |
+
# Define system prompt with memory
|
19 |
+
SYSTEM_PROMPT = """You are a Chief Joy Officer, an AI assistant focused on helping people find fun and enriching activities in London.
|
20 |
+
You have access to memory about the user's preferences and past interactions.
|
21 |
+
|
22 |
+
Here is what you remember about this user:
|
23 |
+
{memory}
|
24 |
+
|
25 |
+
Your core objectives are to:
|
26 |
+
1. Understand and remember user preferences and interests
|
27 |
+
2. Provide personalized activity recommendations based on their interests
|
28 |
+
3. Be engaging and enthusiastic while maintaining professionalism
|
29 |
+
4. Give clear, actionable suggestions
|
30 |
+
|
31 |
+
Key tools at your disposal:
|
32 |
+
- retrieve_context: For finding specific information about events and activities
|
33 |
+
- tavily_search: For general web searches about London activities
|
34 |
+
|
35 |
+
Always aim to provide value while being mindful of the user's time and interests."""
|
36 |
+
|
37 |
+
# Define memory creation/update prompt
|
38 |
+
MEMORY_UPDATE_PROMPT = """You are analyzing the conversation to update the user's profile and preferences.
|
39 |
+
|
40 |
+
CURRENT USER INFORMATION:
|
41 |
+
{memory}
|
42 |
+
|
43 |
+
INSTRUCTIONS:
|
44 |
+
1. Review the chat history carefully
|
45 |
+
2. Identify new information about the user, such as:
|
46 |
+
- Activity preferences (indoor/outdoor, cultural/sports, etc.)
|
47 |
+
- Specific interests (art, music, food, etc.)
|
48 |
+
- Location preferences in London
|
49 |
+
- Time/schedule constraints
|
50 |
+
- Past experiences with activities
|
51 |
+
- Budget considerations
|
52 |
+
3. Merge new information with existing memory
|
53 |
+
4. Format as a clear, bulleted list
|
54 |
+
5. If new information conflicts with existing memory, keep the most recent
|
55 |
+
|
56 |
+
Remember: Only include factual information directly stated by the user. Do not make assumptions.
|
57 |
+
|
58 |
+
Based on the conversation, please update the user information:"""
|
59 |
+
|
60 |
+
def get_last_human_message(state: AgentState):
|
61 |
+
"""Get the last human message from the state."""
|
62 |
+
for message in reversed(state["messages"]):
|
63 |
+
if isinstance(message, HumanMessage):
|
64 |
+
return message
|
65 |
+
return None
|
66 |
+
|
67 |
+
def call_model(state: AgentState, config: RunnableConfig, store: BaseStore):
|
68 |
+
"""Process messages using memory from the store."""
|
69 |
+
# Get the user ID from the config
|
70 |
+
user_id = config["configurable"].get("session_id", "default")
|
71 |
+
|
72 |
+
# Retrieve memory from the store
|
73 |
+
namespace = ("memory", user_id)
|
74 |
+
existing_memory = store.get(namespace, "user_memory")
|
75 |
+
|
76 |
+
# Extract memory content or use default
|
77 |
+
memory_content = existing_memory.value.get('memory') if existing_memory else "No previous information about this user."
|
78 |
+
|
79 |
+
# Create messages list with system prompt including memory
|
80 |
+
messages = [
|
81 |
+
SystemMessage(content=SYSTEM_PROMPT.format(memory=memory_content))
|
82 |
+
] + state["messages"]
|
83 |
+
tool_calling_model = model.bind_tools(tool_belt)
|
84 |
+
response = tool_calling_model.invoke(messages)
|
85 |
return {"messages": [response]}
|
86 |
|
87 |
+
def update_memory(state: AgentState, config: RunnableConfig, store: BaseStore):
|
88 |
+
"""Update user memory based on conversation."""
|
89 |
+
user_id = config["configurable"].get("session_id", "default")
|
90 |
+
namespace = ("memory", user_id)
|
91 |
+
existing_memory = store.get(namespace, "user_memory")
|
92 |
+
|
93 |
+
memory_content = existing_memory.value.get('memory') if existing_memory else "No previous information about this user."
|
94 |
+
|
95 |
+
update_prompt = MEMORY_UPDATE_PROMPT.format(memory=memory_content)
|
96 |
+
new_memory = model.invoke([
|
97 |
+
SystemMessage(content=update_prompt)
|
98 |
+
] + state["messages"])
|
99 |
+
|
100 |
+
store.put(namespace, "user_memory", {"memory": new_memory.content})
|
101 |
+
return state
|
102 |
|
103 |
+
def should_continue(state: AgentState) -> Literal["action", "write_memory"]:
|
104 |
+
"""Determine the next node in the graph."""
|
105 |
+
if not state["messages"]:
|
106 |
+
return END
|
107 |
+
|
108 |
last_message = state["messages"][-1]
|
109 |
+
if isinstance(last_message, list):
|
110 |
+
last_message = last_message[-1]
|
111 |
+
|
112 |
+
last_human_message = get_last_human_message(state)
|
113 |
+
|
114 |
+
# Handle tool calls
|
115 |
+
if hasattr(last_message, "additional_kwargs") and last_message.additional_kwargs.get("tool_calls"):
|
116 |
return "action"
|
117 |
+
|
118 |
+
return "write_memory"
|
119 |
+
|
120 |
+
# # Handle memory operations for human messages
|
121 |
+
# if last_human_message:
|
122 |
+
|
123 |
+
# # Write memory for longer messages that might contain personal information
|
124 |
+
# if len(last_human_message.content.split()) > 3:
|
125 |
+
# return "write_memory"
|
126 |
+
|
127 |
+
# return END
|
128 |
+
|
129 |
+
# Define the memory creation prompt
|
130 |
+
MEMORY_CREATION_PROMPT = """"You are collecting information about the user to personalize your responses.
|
131 |
+
|
132 |
+
CURRENT USER INFORMATION:
|
133 |
+
{memory}
|
134 |
+
|
135 |
+
INSTRUCTIONS:
|
136 |
+
1. Review the chat history below carefully
|
137 |
+
2. Identify new information about the user, such as:
|
138 |
+
- Personal details (name, location)
|
139 |
+
- Preferences (likes, dislikes)
|
140 |
+
- Interests and hobbies
|
141 |
+
- Past experiences
|
142 |
+
- Goals or future plans
|
143 |
+
3. Merge any new information with existing memory
|
144 |
+
4. Format the memory as a clear, bulleted list
|
145 |
+
5. If new information conflicts with existing memory, keep the most recent version
|
146 |
+
|
147 |
+
Remember: Only include factual information directly stated by the user. Do not make assumptions or inferences.
|
148 |
+
|
149 |
+
Based on the chat history below, please update the user information:"""
|
150 |
+
|
151 |
+
async def write_memory(state: AgentState, config: RunnableConfig, store: BaseStore) -> AgentState:
|
152 |
+
"""Reflect on the chat history and save a memory to the store."""
|
153 |
+
|
154 |
+
# Get the session ID from config
|
155 |
+
session_id = config["configurable"].get("session_id", "default")
|
156 |
+
|
157 |
+
# Define the namespace for this user's memory
|
158 |
+
namespace = ("memory", session_id)
|
159 |
+
|
160 |
+
# Get existing memory using async interface
|
161 |
+
existing_memory = await store.aget(namespace, "user_memory")
|
162 |
+
memory_content = existing_memory.value.get('memory') if existing_memory else "No previous information about this user."
|
163 |
+
|
164 |
+
# Create system message with memory context
|
165 |
+
system_msg = SystemMessage(content=MEMORY_CREATION_PROMPT.format(memory=memory_content))
|
166 |
+
|
167 |
+
# Get messages and ensure we're working with the correct format
|
168 |
+
messages = state.get("messages", [])
|
169 |
+
if not messages:
|
170 |
+
return state
|
171 |
+
|
172 |
+
# Create memory using the model
|
173 |
+
new_memory = await model.ainvoke([system_msg] + messages)
|
174 |
+
|
175 |
+
# Store the updated memory using async interface
|
176 |
+
await store.aput(namespace, "user_memory", {"memory": new_memory.content})
|
177 |
+
|
178 |
+
|
179 |
+
return state
|
180 |
+
|
181 |
+
# Initialize tool node
|
182 |
+
tool_node = ToolNode(tool_belt)
|
183 |
+
|
184 |
+
# def route_message(state: MessagesState, config: RunnableConfig, store: BaseStore) -> Literal[END, "update_todos", "update_instructions", "update_profile"]:
|
185 |
+
|
186 |
+
# """Reflect on the memories and chat history to decide whether to update the memory collection."""
|
187 |
+
# message = state['messages'][-1]
|
188 |
+
# if len(message.tool_calls) ==0:
|
189 |
+
# return END
|
190 |
+
# else:
|
191 |
+
# tool_call = message.tool_calls[0]
|
192 |
+
# if tool_call['args']['update_type'] == "user":
|
193 |
+
# return "update_profile"
|
194 |
+
# elif tool_call['args']['update_type'] == "todo":
|
195 |
+
# return "update_todos"
|
196 |
+
# elif tool_call['args']['update_type'] == "instructions":
|
197 |
+
# return "update_instructions"
|
198 |
+
# else:
|
199 |
+
# raise ValueError
|
agent/utils/state.py
CHANGED
@@ -3,4 +3,4 @@ from langgraph.graph.message import add_messages
|
|
3 |
|
4 |
class AgentState(TypedDict):
|
5 |
messages: Annotated[list, add_messages]
|
6 |
-
context: list # Store retrieved context
|
|
|
3 |
|
4 |
class AgentState(TypedDict):
|
5 |
messages: Annotated[list, add_messages]
|
6 |
+
context: list # Store retrieved context
|
app.py
CHANGED
@@ -1,108 +1,99 @@
|
|
1 |
import uuid
|
2 |
-
from langchain_core.messages import HumanMessage, AIMessage
|
3 |
from langchain.schema.runnable.config import RunnableConfig
|
4 |
import chainlit as cl
|
5 |
-
from agent import
|
6 |
from agent.utils.state import AgentState
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
@cl.on_chat_start
|
9 |
async def on_chat_start():
|
10 |
-
|
11 |
-
session_id = str(uuid.uuid4())
|
12 |
-
cl.user_session.set("session_id", session_id)
|
13 |
-
|
14 |
-
# Initialize the conversation state with proper auth
|
15 |
-
cl.user_session.set("messages", [])
|
16 |
-
|
17 |
-
# Initialize config using stored session ID
|
18 |
-
config = RunnableConfig(
|
19 |
-
configurable={
|
20 |
-
"thread_id": session_id,
|
21 |
-
"sessionId": session_id
|
22 |
-
}
|
23 |
-
)
|
24 |
-
|
25 |
-
# Initialize empty state with auth
|
26 |
try:
|
27 |
-
|
28 |
-
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
)
|
|
|
|
|
31 |
except Exception as e:
|
32 |
-
print(f"Error
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
|
39 |
@cl.on_message
|
40 |
async def on_message(message: cl.Message):
|
|
|
|
|
41 |
session_id = cl.user_session.get("session_id")
|
42 |
-
print(f"Session ID: {session_id}")
|
43 |
if not session_id:
|
44 |
session_id = str(uuid.uuid4())
|
45 |
cl.user_session.set("session_id", session_id)
|
46 |
|
47 |
-
|
48 |
-
|
49 |
-
"thread_id": session_id,
|
50 |
-
"checkpoint_ns": "default_namespace",
|
51 |
-
"sessionId": session_id
|
52 |
-
}
|
53 |
-
)
|
54 |
|
55 |
-
# Try to retrieve previous conversation state
|
56 |
try:
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
except Exception as e:
|
66 |
-
print(f"Error
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
final_answer = cl.Message(content="")
|
72 |
-
await final_answer.send()
|
73 |
-
|
74 |
-
loading_msg = None # Initialize reference to loading message
|
75 |
-
|
76 |
-
# Stream the response
|
77 |
-
async for chunk in graph.astream(
|
78 |
-
AgentState(messages=current_messages, context=[]),
|
79 |
-
config=RunnableConfig(
|
80 |
-
configurable={
|
81 |
-
"thread_id": session_id,
|
82 |
-
}
|
83 |
-
)
|
84 |
-
):
|
85 |
-
for node, values in chunk.items():
|
86 |
-
if node == "retrieve":
|
87 |
-
loading_msg = cl.Message(content="🔍 Searching knowledge base...", author="System")
|
88 |
-
await loading_msg.send()
|
89 |
-
elif values.get("messages"):
|
90 |
-
last_message = values["messages"][-1]
|
91 |
-
# Check for tool calls in additional_kwargs
|
92 |
-
if hasattr(last_message, "additional_kwargs") and last_message.additional_kwargs.get("tool_calls"):
|
93 |
-
tool_name = last_message.additional_kwargs["tool_calls"][0]["function"]["name"]
|
94 |
-
if loading_msg:
|
95 |
-
await loading_msg.remove()
|
96 |
-
loading_msg = cl.Message(
|
97 |
-
content=f"🔍 Using {tool_name}...",
|
98 |
-
author="Tool"
|
99 |
-
)
|
100 |
-
await loading_msg.send()
|
101 |
-
# Only stream AI messages, skip tool outputs
|
102 |
-
elif isinstance(last_message, AIMessage):
|
103 |
-
if loading_msg:
|
104 |
-
await loading_msg.remove()
|
105 |
-
loading_msg = None
|
106 |
-
await final_answer.stream_token(last_message.content)
|
107 |
-
|
108 |
-
await final_answer.send()
|
|
|
1 |
import uuid
|
2 |
+
from langchain_core.messages import HumanMessage, AIMessage, AIMessageChunk
|
3 |
from langchain.schema.runnable.config import RunnableConfig
|
4 |
import chainlit as cl
|
5 |
+
from agent import create_agent_graph, get_checkpointer
|
6 |
from agent.utils.state import AgentState
|
7 |
+
import os
|
8 |
+
import json
|
9 |
+
|
10 |
+
SHORT_TERM_MEMORY_DB_PATH = "data/short_term.db"
|
11 |
+
|
12 |
+
os.makedirs(os.path.dirname(SHORT_TERM_MEMORY_DB_PATH), exist_ok=True)
|
13 |
|
14 |
@cl.on_chat_start
|
15 |
async def on_chat_start():
|
16 |
+
"""Initialize the chat session"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
try:
|
18 |
+
# Generate and store a session ID
|
19 |
+
session_id = str(uuid.uuid4())
|
20 |
+
cl.user_session.set("session_id", session_id)
|
21 |
+
|
22 |
+
# Initialize empty message history
|
23 |
+
cl.user_session.set("message_history", [])
|
24 |
+
|
25 |
+
welcome_message = cl.Message(
|
26 |
+
content="Hello! I'm your chief joy officer, here to help you with finding fun things to do in London!",
|
27 |
+
author="Assistant"
|
28 |
)
|
29 |
+
await welcome_message.send()
|
30 |
+
|
31 |
except Exception as e:
|
32 |
+
print(f"Error in chat initialization: {str(e)}")
|
33 |
+
error_message = cl.Message(
|
34 |
+
content="I apologize, but I encountered an error during initialization. Please try refreshing the page.",
|
35 |
+
author="System"
|
36 |
+
)
|
37 |
+
await error_message.send()
|
38 |
|
39 |
@cl.on_message
|
40 |
async def on_message(message: cl.Message):
|
41 |
+
"""Handle incoming messages and stream responses"""
|
42 |
+
# Get or create session ID
|
43 |
session_id = cl.user_session.get("session_id")
|
|
|
44 |
if not session_id:
|
45 |
session_id = str(uuid.uuid4())
|
46 |
cl.user_session.set("session_id", session_id)
|
47 |
|
48 |
+
# Initialize response message
|
49 |
+
msg = cl.Message(content="")
|
|
|
|
|
|
|
|
|
|
|
50 |
|
|
|
51 |
try:
|
52 |
+
async with cl.Step(type="run"):
|
53 |
+
async with get_checkpointer(SHORT_TERM_MEMORY_DB_PATH) as saver:
|
54 |
+
# Create graph with memory
|
55 |
+
graph = await create_agent_graph(saver)
|
56 |
+
|
57 |
+
# Get message history and add current message
|
58 |
+
message_history = cl.user_session.get("message_history", [])
|
59 |
+
current_message = HumanMessage(content=message.content)
|
60 |
+
message_history.append(current_message)
|
61 |
+
|
62 |
+
# Create current state
|
63 |
+
current_state = AgentState(
|
64 |
+
messages=message_history,
|
65 |
+
context=cl.user_session.get("last_context", [])
|
66 |
+
)
|
67 |
+
|
68 |
+
# Stream the response
|
69 |
+
async for chunk in graph.astream(
|
70 |
+
current_state,
|
71 |
+
config={"configurable": {"thread_id": session_id}},
|
72 |
+
stream_mode="messages"
|
73 |
+
):
|
74 |
+
# Handle different node outputs
|
75 |
+
if isinstance(chunk[0], AIMessageChunk):
|
76 |
+
await msg.stream_token(chunk[0].content)
|
77 |
+
elif isinstance(chunk[0], AIMessage):
|
78 |
+
if chunk[0] not in message_history:
|
79 |
+
message_history.append(chunk[0])
|
80 |
+
|
81 |
+
# Get final state
|
82 |
+
final_state = await graph.aget_state(
|
83 |
+
config={"configurable": {"thread_id": session_id}}
|
84 |
+
)
|
85 |
+
|
86 |
+
# Update session state
|
87 |
+
if final_state:
|
88 |
+
cl.user_session.set("message_history", message_history)
|
89 |
+
cl.user_session.set("last_context", final_state.values.get("context", []))
|
90 |
+
|
91 |
+
# Send the final message
|
92 |
+
await msg.send()
|
93 |
+
|
94 |
except Exception as e:
|
95 |
+
print(f"Error in message handler: {str(e)}")
|
96 |
+
await cl.Message(
|
97 |
+
content="I apologize, but I encountered an error processing your message. Please try again.",
|
98 |
+
author="System"
|
99 |
+
).send()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docker-compose.yml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
version: "3.8"
|
2 |
+
|
3 |
+
services:
|
4 |
+
qdrant:
|
5 |
+
image: qdrant/qdrant:latest
|
6 |
+
ports:
|
7 |
+
- "6333:6333"
|
8 |
+
volumes:
|
9 |
+
- ./data/long_term_memory:/qdrant/storage
|
10 |
+
restart: unless-stopped
|
11 |
+
|
12 |
+
app:
|
13 |
+
build: .
|
14 |
+
ports:
|
15 |
+
- "8000:8000"
|
16 |
+
volumes:
|
17 |
+
- .:/app
|
18 |
+
- ./data/short_term_memory:/app/data
|
19 |
+
environment:
|
20 |
+
- QDRANT_URL=http://qdrant:6333
|
21 |
+
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
22 |
+
depends_on:
|
23 |
+
- qdrant
|
langgraph.json
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
{
|
2 |
"dependencies": ".",
|
3 |
"graphs": {
|
4 |
-
"agent": "agent:
|
5 |
},
|
6 |
"env": ".env"
|
7 |
}
|
|
|
1 |
{
|
2 |
"dependencies": ".",
|
3 |
"graphs": {
|
4 |
+
"agent": "agent:langgraph_studio_graph"
|
5 |
},
|
6 |
"env": ".env"
|
7 |
}
|
studio_example_message.json
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"role": "user",
|
3 |
+
"content": "I like sailing a lot, tell me about some activities I can do and remember this fact about me"
|
4 |
+
}
|
tools.py
DELETED
@@ -1,28 +0,0 @@
|
|
1 |
-
from langchain_core.tools import tool
|
2 |
-
from langchain_community.tools.tavily_search import TavilySearchResults
|
3 |
-
from rag import create_rag_pipeline, add_urls_to_vectorstore
|
4 |
-
|
5 |
-
# Initialize RAG pipeline
|
6 |
-
rag_components = create_rag_pipeline(collection_name="london_events")
|
7 |
-
|
8 |
-
# Add some initial URLs to the vector store
|
9 |
-
urls = [
|
10 |
-
"https://www.timeout.com/london/things-to-do-in-london-this-weekend",
|
11 |
-
"https://www.timeout.com/london/london-events-in-march"
|
12 |
-
]
|
13 |
-
add_urls_to_vectorstore(
|
14 |
-
rag_components["vector_store"],
|
15 |
-
rag_components["text_splitter"],
|
16 |
-
urls
|
17 |
-
)
|
18 |
-
|
19 |
-
@tool
|
20 |
-
def retrieve_context(query: str) -> list[str]:
|
21 |
-
"""Searches the knowledge base for relevant information about events and activities. Use this when you need specific details about events."""
|
22 |
-
return [doc.page_content for doc in rag_components["retriever"].get_relevant_documents(query)]
|
23 |
-
|
24 |
-
# Initialize Tavily search tool
|
25 |
-
tavily_tool = TavilySearchResults(max_results=5)
|
26 |
-
|
27 |
-
# Create tool belt
|
28 |
-
tool_belt = [tavily_tool, retrieve_context]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|