# implementations/async_memory.py
from datetime import datetime
from typing import Dict, List, Optional
from zoneinfo import ZoneInfo

from sqlalchemy import delete
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import sessionmaker

from app.settings import DatabaseSettings, MemorySettings
from app.memory.memory import ConversationMemoryInterface
from app.utils.token_counter import SimpleTokenCounter, TikTokenCounter
from app.memory.models.base import Base
from app.memory.models.message import Message
from app.memory.models.user import User


class AsyncPostgresConversationMemory(ConversationMemoryInterface):
    """Conversation memory stored in a database through an async SQLAlchemy engine.

    Each message is a ``Message`` row linked to a ``User`` row. After every
    insert, the oldest messages are deleted until the token count of all stored
    messages is within ``memory_settings.token_limit``.
    """

    def __init__(self, db_settings: DatabaseSettings, memory_settings: MemorySettings):
        """Create the async engine/session factory and pick a token counter.

        Args:
            db_settings: Connection URL and pool sizing for the engine.
            memory_settings: ``token_limit`` used for trimming, and
                ``token_counter`` ("tiktoken" selects ``TikTokenCounter`` for
                ``model_name``; any other value uses ``SimpleTokenCounter``).
        """
        self.engine = create_async_engine(
            db_settings.url,
            pool_size=db_settings.pool_size,
            max_overflow=db_settings.max_overflow,
            pool_timeout=db_settings.pool_timeout,
        )
        # expire_on_commit=False keeps loaded attributes readable after commit.
        self.async_session = sessionmaker(
            self.engine, class_=AsyncSession, expire_on_commit=False
        )
        self.token_limit = memory_settings.token_limit
        if memory_settings.token_counter == "tiktoken":
            self.token_counter = TikTokenCounter(memory_settings.model_name)
        else:
            self.token_counter = SimpleTokenCounter()

    async def initialize(self):
        """Initialize the database by creating all tables."""
        async with self.engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    async def add_message(
        self,
        username: str,
        role: str,
        message: str,
        timestamp: Optional[datetime] = None,
    ) -> None:
        """Store a message for an existing user, then trim memory if needed.

        Args:
            username: Username of the message author; must already exist.
            role: Role label stored with the message.
            message: Message text.
            timestamp: Message time; defaults to now in Asia/Jakarta time.

        Raises:
            ValueError: If no user with ``username`` exists.
        """
        async with self.async_session() as session:
            result = await session.execute(select(User).filter_by(username=username))
            user = result.scalars().first()
            if user is None:
                raise ValueError(f"User with username '{username}' not found")

            if timestamp is None:
                # NOTE(review): timezone is hard-coded; confirm this is intended.
                timestamp = datetime.now(ZoneInfo("Asia/Jakarta"))

            msg = Message(user_id=user.id, role=role, message=message, timestamp=timestamp)
            session.add(msg)
            await session.commit()
            await self.trim_memory_if_needed(session)

    async def get_all_history(self) -> List[Dict]:
        """Return every stored message (all users), oldest first, as role/content dicts."""
        async with self.async_session() as session:
            result = await session.execute(select(Message).order_by(Message.timestamp))
            messages = result.scalars().all()
            return [{"role": msg.role, "content": msg.message} for msg in messages]

    async def get_history(
        self,
        username: Optional[str] = None,
        token_limit: Optional[int] = None,
        last_n: Optional[int] = None,
    ) -> List[Dict]:
        """Return the most recent messages in chronological order.

        Messages are taken newest-first until ``token_limit`` or ``last_n``
        would be exceeded. If the newest message alone exceeds
        ``token_limit``, it is still returned (alone).

        Args:
            username: Restrict to this user's messages; ``None`` means all users.
            token_limit: Maximum total tokens to return; ``None`` for no limit.
            last_n: Maximum number of messages; ``None`` for no limit,
                ``0`` (or negative) returns an empty list.

        Returns:
            List of ``{"role": ..., "parts": ...}`` dicts, oldest first.
        """
        async with self.async_session() as session:
            query = select(Message).order_by(Message.timestamp)
            if username is not None:
                query = query.join(User).filter(User.username == username)

            result = await session.execute(query)
            messages = result.scalars().all()

            selected = []
            total_tokens = 0
            for msg in reversed(messages):  # newest first
                # Checked before adding so that last_n=0 yields no messages.
                if last_n is not None and len(selected) >= last_n:
                    break
                tokens = self.token_counter.count_tokens(msg.message)
                if token_limit is not None:
                    if not selected and tokens > token_limit:
                        # Always return at least the newest message, even if oversized.
                        selected.append(msg)
                        total_tokens = tokens
                        continue
                    if total_tokens + tokens > token_limit:
                        break
                selected.append(msg)
                total_tokens += tokens

            selected.reverse()  # back to chronological order
            return [{"role": msg.role, "parts": msg.message} for msg in selected]

    async def clear_memory(self) -> None:
        """Delete every stored message for all users."""
        async with self.async_session() as session:
            # A Select object has no .delete(); a bulk DELETE needs sqlalchemy.delete().
            await session.execute(delete(Message))
            await session.commit()

    async def get_total_tokens(self) -> int:
        """Return the token count summed over all stored messages."""
        async with self.async_session() as session:
            # Only the text column is needed for counting.
            result = await session.execute(select(Message.message))
            texts = result.scalars().all()
            return sum(self.token_counter.count_tokens(text) for text in texts)

    async def trim_memory_if_needed(self, session: AsyncSession) -> None:
        """Delete the oldest messages until total tokens are within ``self.token_limit``.

        NOTE(review): this counts and trims across *all* users' messages, not
        only the user who just added one. Confirm this is intended.

        Args:
            session: Open session used for the query, deletes and commit.
        """
        result = await session.execute(select(Message).order_by(Message.timestamp))
        messages = result.scalars().all()
        counts = [self.token_counter.count_tokens(msg.message) for msg in messages]
        total_tokens = sum(counts)

        # Walk oldest-first; a linear scan avoids the quadratic list.pop(0).
        for msg, tokens in zip(messages, counts):
            if total_tokens <= self.token_limit:
                break
            await session.delete(msg)
            total_tokens -= tokens
        await session.commit()