File size: 5,801 Bytes
7b2511b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# implementations/async_memory.py
from datetime import datetime
from typing import List, Dict, Optional
from zoneinfo import ZoneInfo

from sqlalchemy import delete
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import sessionmaker

from app.settings import DatabaseSettings, MemorySettings
from app.memory.memory import ConversationMemoryInterface
from app.utils.token_counter import SimpleTokenCounter, TikTokenCounter
from app.memory.models.base import Base
from app.memory.models.message import Message
from app.memory.models.user import User

class AsyncPostgresConversationMemory(ConversationMemoryInterface):
    """Conversation memory stored through an async SQLAlchemy engine.

    Messages are persisted as ``Message`` rows linked to a ``User`` row.
    After every insert the store is trimmed, oldest message first, so that the
    summed token count of all stored messages stays within
    ``memory_settings.token_limit``.
    """

    def __init__(self, db_settings: DatabaseSettings, memory_settings: MemorySettings):
        """Create the async engine, the session factory and the token counter.

        Args:
            db_settings: Supplies the database URL and connection-pool sizing
                (``pool_size``, ``max_overflow``, ``pool_timeout``).
            memory_settings: Supplies the trim threshold (``token_limit``) and
                the token-counter choice: ``"tiktoken"`` selects
                ``TikTokenCounter(model_name)``, anything else selects
                ``SimpleTokenCounter``.
        """
        self.engine = create_async_engine(
            db_settings.url,
            pool_size=db_settings.pool_size,
            max_overflow=db_settings.max_overflow,
            pool_timeout=db_settings.pool_timeout,
        )

        # expire_on_commit=False keeps already-loaded attributes readable after
        # a commit, so returned rows can still be read afterwards.
        self.async_session = sessionmaker(
            self.engine, class_=AsyncSession, expire_on_commit=False
        )
        self.token_limit = memory_settings.token_limit

        if memory_settings.token_counter == "tiktoken":
            self.token_counter = TikTokenCounter(memory_settings.model_name)
        else:
            self.token_counter = SimpleTokenCounter()

    async def initialize(self):
        """Initialize the database by creating all tables."""
        async with self.engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    async def add_message(
        self,
        username: str,
        role: str,
        message: str,
        timestamp: Optional[datetime] = None,
    ) -> None:
        """Store one message for ``username``, then trim memory if over budget.

        Args:
            username: Username of an existing ``User`` row.
            role: Role label stored on the message.
            message: Message text.
            timestamp: Time of the message; defaults to "now" in the
                ``Asia/Jakarta`` time zone.

        Raises:
            ValueError: If no user with ``username`` exists.
        """
        async with self.async_session() as session:
            lookup = await session.execute(select(User).filter_by(username=username))
            user = lookup.scalars().first()
            if user is None:
                raise ValueError(f"User with username '{username}' not found")

            if timestamp is None:
                timestamp = datetime.now(ZoneInfo("Asia/Jakarta"))

            session.add(
                Message(user_id=user.id, role=role, message=message, timestamp=timestamp)
            )
            await session.commit()
            await self.trim_memory_if_needed(session)

    async def get_all_history(self) -> List[Dict]:
        """Return every stored message, oldest first, as ``{"role", "content"}`` dicts."""
        async with self.async_session() as session:
            result = await session.execute(
                select(Message).order_by(Message.timestamp)
            )
            messages = result.scalars().all()
            return [{"role": msg.role, "content": msg.message} for msg in messages]

    async def get_history(
        self,
        username: Optional[str] = None,
        token_limit: Optional[int] = None,
        last_n: Optional[int] = None,
    ) -> List[Dict]:
        """Return the most recent messages that fit within the given budgets.

        Messages are collected newest-first until adding another one would
        exceed ``token_limit`` or ``last_n`` messages have been collected. They
        are then returned in chronological order. If the newest message alone
        exceeds ``token_limit``, it is still returned on its own.

        Args:
            username: If given, only that user's messages are considered.
            token_limit: Maximum total tokens; ``None`` means no token cap.
            last_n: Maximum number of messages; ``None`` means no count cap.
                Values ``<= 0`` yield an empty list.

        Returns:
            A list of ``{"role": ..., "parts": ...}`` dicts, oldest first.
        """
        if last_n is not None and last_n <= 0:
            return []

        async with self.async_session() as session:
            query = select(Message).order_by(Message.timestamp)
            if username is not None:
                # Join with User table and filter by username.
                query = query.join(User).filter(User.username == username)
            result = await session.execute(query)
            messages = result.scalars().all()

        selected = []
        total_tokens = 0
        for msg in reversed(messages):
            tokens = self.token_counter.count_tokens(msg.message)
            over_budget = token_limit is not None and total_tokens + tokens > token_limit
            if over_budget and selected:
                break
            # When nothing is selected yet, the newest message is kept even if
            # it alone exceeds token_limit; nothing further can fit after it.
            selected.append(msg)
            total_tokens += tokens
            if over_budget or (last_n is not None and len(selected) >= last_n):
                break

        selected.reverse()
        return [{"role": msg.role, "parts": msg.message} for msg in selected]

    async def clear_memory(self) -> None:
        """Delete every stored message (for all users)."""
        async with self.async_session() as session:
            # A bulk DELETE needs the ``delete()`` construct; ``Select`` objects
            # have no ``.delete()`` method.
            await session.execute(delete(Message))
            await session.commit()

    async def get_total_tokens(self) -> int:
        """Return the summed token count of all stored messages (all users)."""
        async with self.async_session() as session:
            # Only the text column is needed for counting.
            result = await session.execute(select(Message.message))
            return sum(self.token_counter.count_tokens(text) for text in result.scalars())

    async def trim_memory_if_needed(self, session: AsyncSession) -> None:
        """Delete the oldest messages until the total fits within ``self.token_limit``.

        The total is computed over all stored messages, not per user. Deletions
        go through the caller's ``session``, which this method commits.
        """
        # NOTE(review): the budget is shared by all users, so one user's
        # messages can evict another user's history -- confirm this is intended.
        result = await session.execute(select(Message).order_by(Message.timestamp))
        messages = result.scalars().all()
        counts = [self.token_counter.count_tokens(msg.message) for msg in messages]
        total_tokens = sum(counts)

        # Walk oldest-first rather than list.pop(0) (O(n) per pop), reusing
        # the per-message counts computed above.
        for msg, tokens in zip(messages, counts):
            if total_tokens <= self.token_limit:
                break
            await session.delete(msg)
            total_tokens -= tokens

        await session.commit()