Liss, Alex (NYC-HUG) committed on
Commit 79a1c17 · 1 Parent(s): 9afe931

WIP fix to memory almost there

Files changed (2):
  1. gradio_agent.py +71 -10
  2. z_utils/zep_test.py +29 -0
gradio_agent.py CHANGED
@@ -6,10 +6,14 @@ import os
 from langchain.agents import AgentExecutor, create_react_agent
 from langchain_core.prompts import PromptTemplate
 from langchain.tools import Tool
+from langchain_community.chat_message_histories import ZepCloudChatMessageHistory
+from langchain_community.memory.zep_cloud_memory import ZepCloudMemory
 from langchain_core.runnables.history import RunnableWithMessageHistory
 from langchain_neo4j import Neo4jChatMessageHistory
 from langchain.callbacks.manager import CallbackManager
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain_community.chat_message_histories import ChatMessageHistory
+from langchain.memory import ConversationBufferMemory
 
 # Import Gradio-specific modules directly
 from gradio_llm import llm
@@ -38,6 +42,9 @@ chat_prompt = ChatPromptTemplate.from_messages(
 # Create a non-streaming LLM for the agent
 from langchain_openai import ChatOpenAI
 
+# Import Zep client
+from zep_cloud.client import Zep
+
 # Get API key from environment only (no Streamlit)
 def get_api_key(key_name):
     """Get API key from environment variables only (no Streamlit)"""
@@ -58,11 +65,12 @@ if not OPENAI_API_KEY:
 else:
     raise ValueError(f"OPENAI_API_KEY not found in environment variables")
 
+
 agent_llm = ChatOpenAI(
     openai_api_key=OPENAI_API_KEY,
     model=OPENAI_MODEL,
     temperature=0.1,
-    streaming=True  # Enable streaming for agent
+    streaming=True,  # Enable streaming for agent
 )
 
 movie_chat = chat_prompt | llm | StrOutputParser()
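Note: the new Zep-backed memory below reads ZEP_API_KEY from the environment, but the diff adds no startup check for it the way OPENAI_API_KEY is checked above. A minimal sketch of such a guard, reusing the file's existing get_api_key helper (not part of this commit; the ZEP_API_KEY variable name is illustrative):

# Sketch only: fail fast if the Zep credential is missing, mirroring the
# OPENAI_API_KEY handling earlier in this file.
ZEP_API_KEY = get_api_key("ZEP_API_KEY")
if not ZEP_API_KEY:
    raise ValueError("ZEP_API_KEY not found in environment variables")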
 
@@ -123,10 +131,16 @@ Do NOT use for any 49ers-specific questions.""",
     )
 ]
 
+session_id = "241b3478c7634492abee9f178b5341cb"
+
 # Create the memory manager
 def get_memory(session_id):
-    """Get the chat history from Neo4j for the given session"""
-    return Neo4jChatMessageHistory(session_id=session_id, graph=graph)
+    """Get the chat history from Zep for the given session"""
+    return ZepCloudChatMessageHistory(
+        session_id=session_id,
+        api_key=os.environ.get("ZEP_API_KEY")
+        # No memory_type parameter
+    )
 
 # Create the agent prompt
 agent_prompt = PromptTemplate.from_template(AGENT_SYSTEM_PROMPT)
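For reference, a quick way to sanity-check the reworked get_memory in isolation (a sketch, not part of the commit; it assumes ZEP_API_KEY is set and that the hard-coded session_id above exists in Zep):

# Sketch only: ZepCloudChatMessageHistory exposes the standard LangChain
# chat-history interface, so a message added here should be persisted to Zep.
history = get_memory(session_id)
history.add_user_message("Who do the 49ers play next week?")
print(f"{len(history.messages)} messages stored in Zep for session {session_id}")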
 
@@ -149,6 +163,43 @@ chat_agent = RunnableWithMessageHistory(
     history_messages_key="chat_history",
 )
 
+# Create a function to initialize memory with Zep history
+def initialize_memory_from_zep(session_id):
+    """Initialize a LangChain memory object with history from Zep"""
+    try:
+        # Get history from Zep
+        zep = Zep(api_key=os.environ.get("ZEP_API_KEY"))
+        memory = zep.memory.get(session_id=session_id)
+
+        # Create a conversation memory with the history
+        conversation_memory = ConversationBufferMemory(
+            memory_key="chat_history",
+            return_messages=True
+        )
+
+        if memory and memory.messages:
+            print(f"Loading {len(memory.messages)} messages from Zep for session {session_id}")
+
+            # Add messages to the conversation memory
+            for msg in memory.messages:
+                if msg.role_type == "user":
+                    conversation_memory.chat_memory.add_user_message(msg.content)
+                elif msg.role_type == "assistant":
+                    conversation_memory.chat_memory.add_ai_message(msg.content)
+
+            print("Successfully loaded message history from Zep")
+        else:
+            print("No message history found in Zep, starting fresh")
+
+        return conversation_memory
+    except Exception as e:
+        print(f"Error loading history from Zep: {e}")
+        # Return empty memory if there's an error
+        return ConversationBufferMemory(
+            memory_key="chat_history",
+            return_messages=True
+        )
+
 def generate_response(user_input, session_id=None):
     """
     Generate a response using the agent and tools
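A hedged usage sketch of initialize_memory_from_zep on its own (not part of the diff; it assumes ZEP_API_KEY is set and that the session already has messages in Zep):

# Sketch only: because memory_key="chat_history" and return_messages=True,
# the restored history comes back as a list of HumanMessage/AIMessage objects.
memory = initialize_memory_from_zep(session_id)
restored = memory.load_memory_variables({})["chat_history"]
print(f"Restored {len(restored)} prior messages from Zep")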
 
@@ -167,17 +218,27 @@ def generate_response(user_input, session_id=None):
     if not session_id:
         session_id = get_session_id()
         print(f'Generated new session ID: {session_id}')
-
+
+    # Initialize memory with Zep history
+    memory = initialize_memory_from_zep(session_id)
+
+    # Create an agent executor with memory for this session
+    session_agent_executor = AgentExecutor(
+        agent=agent,
+        tools=tools,
+        verbose=True,
+        memory=memory,  # Use the memory we initialized
+        handle_parsing_errors=True,
+        max_iterations=5
+    )
+
     # Add retry logic
     max_retries = 3
     for attempt in range(max_retries):
         try:
-            print('Invoking chat_agent...')
-            response = chat_agent.invoke(
-                {"input": user_input},
-                {"configurable": {"session_id": session_id}},
-            )
-            print(f'Raw response from chat_agent: {response}')
+            print('Invoking session agent executor...')
+            # The agent will now have access to the loaded history
+            response = session_agent_executor.invoke({"input": user_input})
 
             # Extract the output and format it for Streamlit
             if isinstance(response, dict):
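With this hunk, generate_response builds a fresh AgentExecutor per call, seeded with the Zep-backed memory, instead of routing through the chat_agent / RunnableWithMessageHistory wrapper defined above. A rough sketch of how a Gradio handler might call it (illustrative only; the callback name and the assumption that generate_response returns the extracted output string are not part of this diff):

# Sketch only: a hypothetical Gradio chat callback around generate_response.
def respond(user_message, chat_history):
    answer = generate_response(user_message, session_id=session_id)
    chat_history.append((user_message, answer))
    return "", chat_history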
z_utils/zep_test.py CHANGED
@@ -6,6 +6,7 @@ import os
 import json
 from dotenv import load_dotenv
 from zep_cloud.client import Zep
+from langchain_core.messages import HumanMessage, AIMessage
 
 # Load environment variables from .env file
 load_dotenv()
@@ -41,6 +42,34 @@ def retrieve_chat_history(session_id):
         print(f"Error retrieving chat history: {e}")
         return None
 
+def get_zep_history(session_id):
+    """
+    Retrieve chat history directly from Zep using the client.
+
+    Args:
+        session_id (str): The session ID to retrieve history for
+
+    Returns:
+        list: Formatted messages for LangChain
+    """
+    try:
+        zep = Zep(api_key=os.environ.get("ZEP_API_KEY"))
+        memory = zep.memory.get(session_id=session_id)
+
+        # Convert Zep messages to LangChain format
+        formatted_messages = []
+        if memory and memory.messages:
+            for msg in memory.messages:
+                if msg.role_type == "user":
+                    formatted_messages.append(HumanMessage(content=msg.content))
+                elif msg.role_type == "assistant":
+                    formatted_messages.append(AIMessage(content=msg.content))
+
+        return formatted_messages
+    except Exception as e:
+        print(f"Error retrieving Zep history: {e}")
+        return []
+
 def main():
     print(f"Retrieving chat history for session ID: {SESSION_ID}")
 
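A possible follow-up use of get_zep_history, for example inside main(), mirroring how gradio_agent.py consumes the same Zep messages (a sketch, not part of this commit; ChatMessageHistory is the in-memory class already imported in gradio_agent.py):

# Sketch only: seed an in-memory LangChain history from Zep for offline inspection.
from langchain_community.chat_message_histories import ChatMessageHistory

history = ChatMessageHistory()
for message in get_zep_history(SESSION_ID):
    history.add_message(message)
print(f"Loaded {len(history.messages)} LangChain messages for session {SESSION_ID}")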