Liss, Alex (NYC-HUG) committed · Commit 79a1c17 · Parent(s): 9afe931

WIP fix to memory almost there

Files changed:
- gradio_agent.py (+71, -10)
- z_utils/zep_test.py (+29, -0)
gradio_agent.py CHANGED

@@ -6,10 +6,14 @@ import os
 from langchain.agents import AgentExecutor, create_react_agent
 from langchain_core.prompts import PromptTemplate
 from langchain.tools import Tool
+from langchain_community.chat_message_histories import ZepCloudChatMessageHistory
+from langchain_community.memory.zep_cloud_memory import ZepCloudMemory
 from langchain_core.runnables.history import RunnableWithMessageHistory
 from langchain_neo4j import Neo4jChatMessageHistory
 from langchain.callbacks.manager import CallbackManager
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain_community.chat_message_histories import ChatMessageHistory
+from langchain.memory import ConversationBufferMemory

 # Import Gradio-specific modules directly
 from gradio_llm import llm

@@ -38,6 +42,9 @@ chat_prompt = ChatPromptTemplate.from_messages(
 # Create a non-streaming LLM for the agent
 from langchain_openai import ChatOpenAI

+# Import Zep client
+from zep_cloud.client import Zep
+
 # Get API key from environment only (no Streamlit)
 def get_api_key(key_name):
     """Get API key from environment variables only (no Streamlit)"""

@@ -58,11 +65,12 @@ if not OPENAI_API_KEY:
 else:
     raise ValueError(f"OPENAI_API_KEY not found in environment variables")

+
 agent_llm = ChatOpenAI(
     openai_api_key=OPENAI_API_KEY,
     model=OPENAI_MODEL,
     temperature=0.1,
-    streaming=True  # Enable streaming for agent
+    streaming=True,  # Enable streaming for agent
 )

 movie_chat = chat_prompt | llm | StrOutputParser()

@@ -123,10 +131,16 @@ Do NOT use for any 49ers-specific questions.""",
     )
 ]

+session_id = "241b3478c7634492abee9f178b5341cb"
+
 # Create the memory manager
 def get_memory(session_id):
-    """Get the chat history from Neo4j for the given session"""
-    return Neo4jChatMessageHistory(session_id=session_id, graph=graph)
+    """Get the chat history from Zep for the given session"""
+    return ZepCloudChatMessageHistory(
+        session_id=session_id,
+        api_key=os.environ.get("ZEP_API_KEY")
+        # No memory_type parameter
+    )

 # Create the agent prompt
 agent_prompt = PromptTemplate.from_template(AGENT_SYSTEM_PROMPT)

@@ -149,6 +163,43 @@ chat_agent = RunnableWithMessageHistory(
     history_messages_key="chat_history",
 )

+# Create a function to initialize memory with Zep history
+def initialize_memory_from_zep(session_id):
+    """Initialize a LangChain memory object with history from Zep"""
+    try:
+        # Get history from Zep
+        zep = Zep(api_key=os.environ.get("ZEP_API_KEY"))
+        memory = zep.memory.get(session_id=session_id)
+
+        # Create a conversation memory with the history
+        conversation_memory = ConversationBufferMemory(
+            memory_key="chat_history",
+            return_messages=True
+        )
+
+        if memory and memory.messages:
+            print(f"Loading {len(memory.messages)} messages from Zep for session {session_id}")
+
+            # Add messages to the conversation memory
+            for msg in memory.messages:
+                if msg.role_type == "user":
+                    conversation_memory.chat_memory.add_user_message(msg.content)
+                elif msg.role_type == "assistant":
+                    conversation_memory.chat_memory.add_ai_message(msg.content)
+
+            print("Successfully loaded message history from Zep")
+        else:
+            print("No message history found in Zep, starting fresh")
+
+        return conversation_memory
+    except Exception as e:
+        print(f"Error loading history from Zep: {e}")
+        # Return empty memory if there's an error
+        return ConversationBufferMemory(
+            memory_key="chat_history",
+            return_messages=True
+        )
+
 def generate_response(user_input, session_id=None):
     """
     Generate a response using the agent and tools

@@ -167,17 +218,27 @@ def generate_response(user_input, session_id=None):
     if not session_id:
         session_id = get_session_id()
         print(f'Generated new session ID: {session_id}')
-
+
+    # Initialize memory with Zep history
+    memory = initialize_memory_from_zep(session_id)
+
+    # Create an agent executor with memory for this session
+    session_agent_executor = AgentExecutor(
+        agent=agent,
+        tools=tools,
+        verbose=True,
+        memory=memory,  # Use the memory we initialized
+        handle_parsing_errors=True,
+        max_iterations=5
+    )
+
     # Add retry logic
     max_retries = 3
     for attempt in range(max_retries):
         try:
-            print('Invoking chat_agent...')
-            response = chat_agent.invoke(
-                {"input": user_input},
-                {"configurable": {"session_id": session_id}},
-            )
-            print(f'Raw response from chat_agent: {response}')
+            print('Invoking session agent executor...')
+            # The agent will now have access to the loaded history
+            response = session_agent_executor.invoke({"input": user_input})

            # Extract the output and format it for Streamlit
            if isinstance(response, dict):
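For context, here is a minimal usage sketch of how the pieces this commit adds fit together at call time. It assumes `agent` and `tools` are the objects already defined in gradio_agent.py, that `ZEP_API_KEY` is set in the environment, and that the session ID and input shown are hypothetical:

# Usage sketch (assumptions: `agent` and `tools` are importable from
# gradio_agent.py, ZEP_API_KEY is set, and the session ID is hypothetical).
from langchain.agents import AgentExecutor

from gradio_agent import initialize_memory_from_zep, agent, tools

session_id = "example-session-id"  # hypothetical; real code uses get_session_id()

# Rebuild the conversation buffer from Zep, then give a per-session
# executor access to it via the `memory` parameter.
memory = initialize_memory_from_zep(session_id)
executor = AgentExecutor(
    agent=agent,
    tools=tools,
    memory=memory,  # supplies chat_history to the agent prompt
    handle_parsing_errors=True,
    max_iterations=5,
)

result = executor.invoke({"input": "What did we talk about last time?"})
print(result["output"])

Note the design choice: this path bypasses the RunnableWithMessageHistory wrapper defined above and instead loads history from Zep once, into a ConversationBufferMemory, before each agent run — which appears to be the approach this WIP commit is trying.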
z_utils/zep_test.py CHANGED

@@ -6,6 +6,7 @@ import os
 import json
 from dotenv import load_dotenv
 from zep_cloud.client import Zep
+from langchain_core.messages import HumanMessage, AIMessage

 # Load environment variables from .env file
 load_dotenv()

@@ -41,6 +42,34 @@ def retrieve_chat_history(session_id):
         print(f"Error retrieving chat history: {e}")
         return None

+def get_zep_history(session_id):
+    """
+    Retrieve chat history directly from Zep using the client.
+
+    Args:
+        session_id (str): The session ID to retrieve history for
+
+    Returns:
+        list: Formatted messages for LangChain
+    """
+    try:
+        zep = Zep(api_key=os.environ.get("ZEP_API_KEY"))
+        memory = zep.memory.get(session_id=session_id)
+
+        # Convert Zep messages to LangChain format
+        formatted_messages = []
+        if memory and memory.messages:
+            for msg in memory.messages:
+                if msg.role_type == "user":
+                    formatted_messages.append(HumanMessage(content=msg.content))
+                elif msg.role_type == "assistant":
+                    formatted_messages.append(AIMessage(content=msg.content))
+
+        return formatted_messages
+    except Exception as e:
+        print(f"Error retrieving Zep history: {e}")
+        return []
+
 def main():
     print(f"Retrieving chat history for session ID: {SESSION_ID}")
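A small, hypothetical driver for the new helper, assuming the module-level `SESSION_ID` that the existing `main()` already uses; the print format is illustrative only:

# Illustrative driver (assumes SESSION_ID is defined at the top of
# z_utils/zep_test.py, as main() suggests).
messages = get_zep_history(SESSION_ID)
print(f"Fetched {len(messages)} messages from Zep")
for m in messages:
    # HumanMessage and AIMessage expose .type ("human"/"ai") and .content
    print(f"[{m.type}] {m.content}")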