ric9176 committed
Commit d6e8585 · 1 Parent(s): e3030c3

feat: fix and test chainlit interface for cross thread memory

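Context for the changes below: both Chainlit handlers open a LangGraph SQLite checkpointer and pass a per-session thread_id, which is what carries conversation state across turns. As a rough, hypothetical sketch of what the get_checkpointer helper used in app.py might wrap (its real implementation lives in agent/ and is not shown in this diff), assuming the langgraph-checkpoint-sqlite package is installed:

# Hypothetical sketch of the get_checkpointer helper used in app.py below;
# the actual implementation in agent/ may differ.
from contextlib import asynccontextmanager

from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver


@asynccontextmanager
async def get_checkpointer(db_path: str):
    # AsyncSqliteSaver.from_conn_string is an async context manager that
    # yields a checkpointer bound to the SQLite file at db_path.
    async with AsyncSqliteSaver.from_conn_string(db_path) as saver:
        yield saver

With a saver like this, every invocation that reuses the same thread_id in its config resumes from the checkpoint stored in data/short_term.db.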
.gitignore CHANGED
@@ -1,4 +1,6 @@
 __pycache__/
 .chainlit/
 .venv/
-.env
+.env
+.langgraph_api/
+data/
.langgraph_api/.langgraph_checkpoint.1.pckl DELETED
Binary file (957 kB)
 
.langgraph_api/.langgraph_checkpoint.2.pckl DELETED
Binary file (280 kB)
 
.langgraph_api/.langgraph_ops.pckl DELETED
Binary file (167 kB)
 
.langgraph_api/.langgraph_retry_counter.pckl DELETED
Binary file (83 Bytes)
 
.langgraph_api/store.pckl DELETED
Binary file (608 Bytes)
 
.langgraph_api/store.vectors.pckl DELETED
Binary file (90 Bytes)
 
agent/graph.py CHANGED
@@ -26,7 +26,6 @@ def create_graph_builder():
     # Set entry point
     builder.set_entry_point("agent")
 
-    # builder.add_edge("agent", "write_memory")
 
     # Add conditional edges from agent
    builder.add_conditional_edges(
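For reference, builder.add_conditional_edges takes a source node, a routing callable, and a map from the callable's return values to target nodes; the call is truncated in the hunk above. A minimal, illustrative sketch of that wiring (should_continue, the node stubs, and MessagesState standing in for this repo's AgentState are assumptions, not code from this commit):

# Illustrative only; not the repository's actual graph definition.
from langgraph.graph import StateGraph, MessagesState, END

def agent(state: MessagesState) -> dict:
    # Placeholder agent node; the real node would call the LLM.
    return {"messages": []}

def action(state: MessagesState) -> dict:
    # Placeholder tool-execution node.
    return {"messages": []}

def should_continue(state: MessagesState) -> str:
    # Route to the tool node if the last message requested a tool call.
    last_message = state["messages"][-1]
    return "action" if getattr(last_message, "tool_calls", None) else "end"

builder = StateGraph(MessagesState)
builder.add_node("agent", agent)
builder.add_node("action", action)
builder.set_entry_point("agent")
builder.add_conditional_edges(
    "agent",
    should_continue,
    {"action": "action", "end": END},
)
builder.add_edge("action", "agent")
graph = builder.compile()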
app.py CHANGED
@@ -1,151 +1,99 @@
 import uuid
-from langchain_core.messages import HumanMessage, AIMessage
+from langchain_core.messages import HumanMessage, AIMessage, AIMessageChunk
 from langchain.schema.runnable.config import RunnableConfig
 import chainlit as cl
-from agent import create_agent_graph, create_agent_graph_without_memory, get_checkpointer
+from agent import create_agent_graph, get_checkpointer
 from agent.utils.state import AgentState
 import os
 import json
 
-# Path to SQLite database for short-term memory
 SHORT_TERM_MEMORY_DB_PATH = "data/short_term.db"
 
-# Ensure the data directory exists
 os.makedirs(os.path.dirname(SHORT_TERM_MEMORY_DB_PATH), exist_ok=True)
 
 @cl.on_chat_start
 async def on_chat_start():
-    # Generate and store a session ID
-    session_id = str(uuid.uuid4())
-    cl.user_session.set("session_id", session_id)
-
-    # Initialize empty message history
-    cl.user_session.set("message_history", [])
-
-    # Initialize config using stored session ID
-    config = RunnableConfig(
-        configurable={
-            "thread_id": session_id,
-            "session_id": session_id,
-            "checkpoint_ns": session_id
-        }
-    )
-
-    # Initialize empty state with auth
+    """Initialize the chat session"""
     try:
-        async with get_checkpointer(SHORT_TERM_MEMORY_DB_PATH) as saver:
-            graph = await create_agent_graph(saver)
-            initial_state = AgentState(
-                messages=[],
-                context=[]
-            )
-
-            await graph.ainvoke(initial_state, config=config)
-
-            # Store initial state
-            cl.user_session.set("last_state", {
-                "messages": [],
-                "context": []
-            })
+        # Generate and store a session ID
+        session_id = str(uuid.uuid4())
+        cl.user_session.set("session_id", session_id)
+
+        # Initialize empty message history
+        cl.user_session.set("message_history", [])
+
+        welcome_message = cl.Message(
+            content="Hello! I'm your chief joy officer, here to help you with finding fun things to do in London!",
+            author="Assistant"
+        )
+        await welcome_message.send()
+
     except Exception as e:
-        print(f"Error initializing state: {str(e)}")
-
-    await cl.Message(
-        content="Hello! I'm your chief joy officer, here to help you with finding fun things to do in London!",
-        author="Assistant"
-    ).send()
+        print(f"Error in chat initialization: {str(e)}")
+        error_message = cl.Message(
+            content="I apologize, but I encountered an error during initialization. Please try refreshing the page.",
+            author="System"
+        )
+        await error_message.send()
 
 @cl.on_message
 async def on_message(message: cl.Message):
+    """Handle incoming messages and stream responses"""
     # Get or create session ID
     session_id = cl.user_session.get("session_id")
    if not session_id:
         session_id = str(uuid.uuid4())
         cl.user_session.set("session_id", session_id)
 
-    print(f"Session ID: {session_id}")
-
-    # Get message history
-    message_history = cl.user_session.get("message_history", [])
-
-    # Add new message to history
-    current_message = HumanMessage(content=message.content)
-    message_history.append(current_message)
-    cl.user_session.set("message_history", message_history)
-
-    config = RunnableConfig(
-        configurable={
-            "thread_id": session_id,
-            "session_id": session_id,
-            "checkpoint_ns": session_id
-        }
-    )
+    # Initialize response message
+    msg = cl.Message(content="")
 
     try:
-        async with get_checkpointer(SHORT_TERM_MEMORY_DB_PATH) as saver:
-            # Create graph with memory
-            graph = await create_agent_graph(saver)
-
-            # Get the last state or create new one
-            last_state_dict = cl.user_session.get("last_state", {"messages": [], "context": []})
-
-            # Create new state with current message history
-            current_state = AgentState(
-                messages=message_history,
-                context=last_state_dict.get("context", [])
-            )
-
-            # Setup callback handler and final answer message
-            cb = cl.LangchainCallbackHandler()
-            final_answer = cl.Message(content="")
-            await final_answer.send()
-
-            loading_msg = None  # Initialize reference to loading message
-            last_state = None  # Track the final state
-
-            # Stream the response
-            async for chunk in graph.astream(
-                current_state,
-                config=config
-            ):
-                for node, values in chunk.items():
-                    if node == "retrieve":
-                        if loading_msg:
-                            await loading_msg.remove()
-                        loading_msg = cl.Message(content="🔍 Searching knowledge base...", author="System")
-                        await loading_msg.send()
-                    elif values.get("messages"):
-                        last_message = values["messages"][-1]
-                        # Check for tool calls in additional_kwargs
-                        if hasattr(last_message, "additional_kwargs") and last_message.additional_kwargs.get("tool_calls"):
-                            tool_name = last_message.additional_kwargs["tool_calls"][0]["function"]["name"]
-                            if loading_msg:
-                                await loading_msg.remove()
-                            loading_msg = cl.Message(
-                                content=f"🔍 Using {tool_name}...",
-                                author="Tool"
-                            )
-                            await loading_msg.send()
-                        # Only stream AI messages, skip tool outputs
-                        elif isinstance(last_message, AIMessage):
-                            if loading_msg:
-                                await loading_msg.remove()
-                                loading_msg = None
-                            await final_answer.stream_token(last_message.content)
-                            # Add AI message to history
-                            message_history.append(last_message)
-                            cl.user_session.set("message_history", message_history)
-                        # Update last state
-                        last_state = values
-
-            # Update the last state as a serializable dict
-            if last_state:
-                cl.user_session.set("last_state", {
-                    "messages": [msg.content for msg in message_history],
-                    "context": last_state.get("context", [])
-                })
-            await final_answer.send()
-
+        async with cl.Step(type="run"):
+            async with get_checkpointer(SHORT_TERM_MEMORY_DB_PATH) as saver:
+                # Create graph with memory
+                graph = await create_agent_graph(saver)
+
+                # Get message history and add current message
+                message_history = cl.user_session.get("message_history", [])
+                current_message = HumanMessage(content=message.content)
+                message_history.append(current_message)
+
+                # Create current state
+                current_state = AgentState(
+                    messages=message_history,
+                    context=cl.user_session.get("last_context", [])
+                )
+
+                # Stream the response
+                async for chunk in graph.astream(
+                    current_state,
+                    config={"configurable": {"thread_id": session_id}},
+                    stream_mode="messages"
+                ):
+                    # Handle different node outputs
+                    if isinstance(chunk[0], AIMessageChunk):
+                        await msg.stream_token(chunk[0].content)
+                    elif isinstance(chunk[0], AIMessage):
+                        if chunk[0] not in message_history:
+                            message_history.append(chunk[0])
+
+                # Get final state
+                final_state = await graph.aget_state(
+                    config={"configurable": {"thread_id": session_id}}
+                )
+
+                # Update session state
+                if final_state:
+                    cl.user_session.set("message_history", message_history)
+                    cl.user_session.set("last_context", final_state.values.get("context", []))
+
+                # Send the final message
+                await msg.send()
+
     except Exception as e:
         print(f"Error in message handler: {str(e)}")
-        await cl.Message(content="I apologize, but I encountered an error processing your message. Please try again.").send()
+        await cl.Message(
+            content="I apologize, but I encountered an error processing your message. Please try again.",
+            author="System"
+        ).send()
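One behavioral note on the rewritten handler: with stream_mode="messages", graph.astream yields (message_chunk, metadata) tuples rather than per-node state dicts, which is why the code above indexes chunk[0]. A stripped-down sketch of that consumption pattern, with graph construction elided (graph and thread_id are placeholders for the compiled graph and session id used in app.py):

# Sketch of consuming LangGraph's "messages" stream mode; `graph` and
# `thread_id` stand in for the objects used in app.py.
from langchain_core.messages import AIMessageChunk, HumanMessage

async def stream_reply(graph, user_input: str, thread_id: str) -> str:
    reply = ""
    async for message_chunk, metadata in graph.astream(
        {"messages": [HumanMessage(content=user_input)]},
        config={"configurable": {"thread_id": thread_id}},
        stream_mode="messages",
    ):
        # Accumulate only LLM token chunks; tool and other messages are skipped.
        if isinstance(message_chunk, AIMessageChunk):
            reply += message_chunk.content
    return reply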
data/short_term.db-shm DELETED
Binary file (32.8 kB)