# Page-status residue captured during extraction ("Spaces: Running"); not part of the module.
from langgraph.graph import StateGraph, END | |
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver | |
from langgraph.store.memory import InMemoryStore | |
import aiosqlite | |
from types import TracebackType | |
from typing import Optional, Type | |
from agent.utils.state import AgentState | |
from agent.utils.nodes import ( | |
call_model, | |
tool_node, | |
write_memory, | |
should_continue | |
) | |
def create_graph_builder():
    """Build the uncompiled agent workflow shared by every graph variant.

    The workflow enters at the ``agent`` node. After ``agent`` runs, the value
    returned by ``should_continue`` is looked up in the routing map to choose
    the next node: ``action`` (which always loops back to ``agent``) or
    ``write_memory`` (which leads to ``END``).

    Returns:
        The configured ``StateGraph`` builder, not yet compiled.
    """
    workflow = StateGraph(AgentState)

    # Register each node under the name the edges below refer to.
    node_table = {
        "agent": call_model,
        "action": tool_node,
        "write_memory": write_memory,
    }
    for node_name, node_fn in node_table.items():
        workflow.add_node(node_name, node_fn)

    workflow.set_entry_point("agent")

    # Routing out of "agent": should_continue's result -> next node name.
    routes = {
        "action": "action",
        "write_memory": "write_memory",
    }
    workflow.add_conditional_edges("agent", should_continue, routes)

    # Fixed edges: the action node hands control back to the agent, and
    # writing memory is the last step of a run.
    workflow.add_edge("action", "agent")
    workflow.add_edge("write_memory", END)
    return workflow
def create_agent_graph_without_memory():
    """Compile the agent workflow with no checkpointer and no store.

    Returns:
        The compiled graph obtained by calling ``compile()`` with no
        arguments on a fresh builder from ``create_graph_builder``.
    """
    return create_graph_builder().compile()
class SQLiteCheckpointer:
    """Async context manager that opens a SQLite-backed ``AsyncSqliteSaver``.

    Entering the context opens an ``aiosqlite`` connection to ``db_path`` and
    wraps it in an ``AsyncSqliteSaver``; leaving the context closes that
    connection. While the context is active the saver is also available as
    ``self.saver``; outside of it ``self.saver`` is ``None``.
    """

    def __init__(self, db_path: str):
        """Remember the database path; no connection is opened yet.

        Args:
            db_path: Path passed to ``aiosqlite.connect`` on entry.
        """
        self.db_path = db_path
        self.saver: Optional[AsyncSqliteSaver] = None
        # The connection is tracked here, independently of the saver, so that
        # cleanup never depends on which attribute the saver stores it under.
        self._conn: Optional[aiosqlite.Connection] = None

    async def __aenter__(self) -> AsyncSqliteSaver:
        """Open the connection, build the saver, and return the saver.

        Raises:
            Whatever ``aiosqlite.connect`` or ``AsyncSqliteSaver`` raise. If
            the saver cannot be constructed, the freshly opened connection is
            closed before the exception propagates.
        """
        conn = await aiosqlite.connect(self.db_path)
        try:
            saver = AsyncSqliteSaver(conn)
        except BaseException:
            # Nothing else holds this connection yet; close it here or it leaks.
            await conn.close()
            raise
        self._conn = conn
        self.saver = saver
        return saver

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        """Close the SQLite connection and drop the saver reference.

        Exceptions raised inside the ``async with`` body are not suppressed.
        Calling this without a prior successful ``__aenter__`` is a no-op.
        """
        # Detach state first so the instance is reset even if close() raises.
        conn, self._conn = self._conn, None
        self.saver = None
        if conn is not None:
            await conn.close()
def get_checkpointer(db_path: str = "data/short_term.db") -> SQLiteCheckpointer:
    """Build a ``SQLiteCheckpointer`` for the database at ``db_path``.

    Args:
        db_path: Database path handed to ``SQLiteCheckpointer``. Defaults to
            ``"data/short_term.db"``.

    Returns:
        A new ``SQLiteCheckpointer``; it is not entered here.
    """
    checkpointer = SQLiteCheckpointer(db_path)
    return checkpointer
# One module-level InMemoryStore for across-thread memory; create_agent_graph
# passes this same instance as ``store`` to every graph it compiles.
across_thread_memory: InMemoryStore = InMemoryStore()
async def create_agent_graph(checkpointer: AsyncSqliteSaver):
    """Compile the agent workflow with persistence attached.

    Args:
        checkpointer: Saver passed to ``compile`` as ``checkpointer``
            (within-thread memory).

    Returns:
        The compiled graph, built with ``checkpointer`` and with the
        module-level ``across_thread_memory`` as its ``store``
        (across-thread memory).
    """
    workflow = create_graph_builder()
    return workflow.compile(
        checkpointer=checkpointer,
        store=across_thread_memory,
    )
# Memory-less graph compiled once at import time and exposed as a module
# attribute. NOTE(review): presumably the entry point LangGraph Studio loads;
# confirm against the project's Studio configuration.
langgraph_studio_graph = create_agent_graph_without_memory()

# Names exported by a star-import of this module: the graph factory functions
# and the checkpointer factory.
__all__ = [
    "create_agent_graph",
    "create_agent_graph_without_memory",
    "get_checkpointer",
]