File size: 2,930 Bytes
241f177
28f3481
241f177
28f3481
 
 
 
 
241f177
 
 
 
 
 
28f3481
 
 
 
241f177
28f3481
 
 
 
241f177
28f3481
241f177
28f3481
241f177
 
 
28f3481
 
 
241f177
 
 
 
28f3481
241f177
 
28f3481
e3030c3
28f3481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241f177
 
 
28f3481
241f177
28f3481
241f177
 
 
 
 
 
28f3481
 
241f177
 
28f3481
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from langgraph.store.memory import InMemoryStore
import aiosqlite
from types import TracebackType
from typing import Optional, Type

from agent.utils.state import AgentState
from agent.utils.nodes import (
    call_model,
    tool_node,
    write_memory,
    should_continue
)

def create_graph_builder():
    """Build the uncompiled agent ``StateGraph`` shared by every graph variant.

    Topology::

        agent --(should_continue)--> action | write_memory
        action -> agent
        write_memory -> END

    Returns:
        A ``StateGraph`` over ``AgentState`` with all nodes and edges wired
        but not yet compiled, so each caller chooses its own checkpointer
        and store at compile time.
    """
    graph_builder = StateGraph(AgentState)

    # Register nodes in a fixed order: (node name, callable).
    node_table = (
        ("agent", call_model),
        ("action", tool_node),
        ("write_memory", write_memory),
    )
    for node_name, node_fn in node_table:
        graph_builder.add_node(node_name, node_fn)

    # Every run starts at the model-calling node.
    graph_builder.set_entry_point("agent")

    # should_continue's return value selects the next node.
    route_map = {
        "action": "action",
        "write_memory": "write_memory",
    }
    graph_builder.add_conditional_edges("agent", should_continue, route_map)

    # Tool results loop back to the model; memory writing ends the run.
    graph_builder.add_edge("action", "agent")
    graph_builder.add_edge("write_memory", END)

    return graph_builder

def create_agent_graph_without_memory():
    """Compile the agent graph with no checkpointer and no store attached.

    Returns:
        The compiled graph produced by ``create_graph_builder().compile()``.
    """
    return create_graph_builder().compile()

class SQLiteCheckpointer:
    """Async context manager that owns an aiosqlite connection and the
    ``AsyncSqliteSaver`` built on top of it.

    No connection is opened until the context is entered::

        async with SQLiteCheckpointer("data/short_term.db") as saver:
            graph = await create_agent_graph(saver)

    On exit the connection is closed and ``saver`` is reset to ``None``.
    Exceptions raised inside the ``async with`` body are not suppressed.
    """

    def __init__(self, db_path: str):
        self.db_path = db_path
        self.saver: Optional[AsyncSqliteSaver] = None
        # Track the connection ourselves so cleanup never depends on the
        # saver exposing a particular attribute.
        self._conn: Optional[aiosqlite.Connection] = None

    async def __aenter__(self) -> AsyncSqliteSaver:
        """Open the database connection and return an ``AsyncSqliteSaver`` on it."""
        conn = await aiosqlite.connect(self.db_path)
        try:
            saver = AsyncSqliteSaver(conn)
        except BaseException:
            # __aexit__ is not called when __aenter__ fails, so close the
            # connection here to avoid leaking it.
            await conn.close()
            raise
        self._conn = conn
        self.saver = saver
        return saver

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        """Close the SQLite connection and drop the saver reference."""
        conn, self._conn = self._conn, None
        self.saver = None
        if conn is not None:
            await conn.close()

def get_checkpointer(db_path: str = "data/short_term.db") -> SQLiteCheckpointer:
    """Return a not-yet-entered ``SQLiteCheckpointer`` for ``db_path``.

    Use it as ``async with get_checkpointer() as saver:``. The database
    connection is only opened when the context is entered.
    """
    checkpointer = SQLiteCheckpointer(db_path)
    return checkpointer

# Store for across-thread (long-term) memory.
# Module-level singleton: every graph compiled by create_agent_graph shares
# this one instance. NOTE(review): InMemoryStore presumably keeps data only in
# process memory, so these memories would be lost on restart — confirm intended.
across_thread_memory = InMemoryStore()

async def create_agent_graph(checkpointer: AsyncSqliteSaver):
    """Compile the agent graph with both memory layers attached.

    Args:
        checkpointer: Saver for within-thread state, e.g. the value bound by
            ``async with get_checkpointer() as saver``.

    Returns:
        The compiled graph, wired to ``checkpointer`` and to the module-level
        ``across_thread_memory`` store for across-thread memory.

    Note:
        Declared ``async`` although it awaits nothing; callers must ``await`` it.
    """
    return create_graph_builder().compile(
        store=across_thread_memory,
        checkpointer=checkpointer,
    )

# Graph for LangGraph Studio, compiled at import time with no checkpointer or
# store (presumably because the Studio runtime supplies its own persistence —
# TODO confirm).
langgraph_studio_graph = create_agent_graph_without_memory()

# Names exported by ``from <module> import *``. SQLiteCheckpointer,
# create_graph_builder and langgraph_studio_graph are not listed, but remain
# importable by explicit name.
__all__ = ["create_agent_graph", "create_agent_graph_without_memory", "get_checkpointer"]