ric9176's picture
feat: fix and test chainlit interface for cross thread memory
d6e8585
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from langgraph.store.memory import InMemoryStore
import aiosqlite
from types import TracebackType
from typing import Optional, Type
from agent.utils.state import AgentState
from agent.utils.nodes import (
call_model,
tool_node,
write_memory,
should_continue
)
def create_graph_builder():
"""Create a base graph builder with nodes and edges configured."""
builder = StateGraph(AgentState)
# Add nodes
builder.add_node("agent", call_model)
builder.add_node("action", tool_node)
builder.add_node("write_memory", write_memory)
# Set entry point
builder.set_entry_point("agent")
# Add conditional edges from agent
builder.add_conditional_edges(
"agent",
should_continue,
{
"action": "action",
"write_memory": "write_memory",
}
)
# Connect action back to agent
builder.add_edge("action", "agent")
builder.add_edge("write_memory", END)
return builder
def create_agent_graph_without_memory():
"""Create an agent graph without memory persistence."""
builder = create_graph_builder()
return builder.compile()
class SQLiteCheckpointer:
"""Context manager for SQLite checkpointing."""
def __init__(self, db_path: str):
self.db_path = db_path
self.saver: Optional[AsyncSqliteSaver] = None
async def __aenter__(self) -> AsyncSqliteSaver:
"""Initialize and return the AsyncSqliteSaver."""
conn = await aiosqlite.connect(self.db_path)
self.saver = AsyncSqliteSaver(conn)
return self.saver
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
"""Clean up the SQLite connection."""
if self.saver and hasattr(self.saver, 'conn'):
await self.saver.conn.close()
self.saver = None
def get_checkpointer(db_path: str = "data/short_term.db") -> SQLiteCheckpointer:
"""Create and return a SQLiteCheckpointer instance."""
return SQLiteCheckpointer(db_path)
# Initialize store for across-thread memory
across_thread_memory = InMemoryStore()
async def create_agent_graph(checkpointer: AsyncSqliteSaver):
"""Create an agent graph with memory persistence."""
builder = create_graph_builder()
# Compile with both SQLite checkpointer for within-thread memory
# and InMemoryStore for across-thread memory
graph = builder.compile(
checkpointer=checkpointer,
store=across_thread_memory
)
return graph
langgraph_studio_graph = create_agent_graph_without_memory()
# Export the graph builder functions
__all__ = ["create_agent_graph", "create_agent_graph_without_memory", "get_checkpointer"]