Spaces:
Build error
Build error
# implementations/async_memory.py | |
from datetime import datetime
from typing import Dict, List, Optional
from zoneinfo import ZoneInfo

from sqlalchemy import delete
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.future import select
from sqlalchemy.orm import sessionmaker

from app.memory.memory import ConversationMemoryInterface
from app.memory.models.base import Base
from app.memory.models.message import Message
from app.memory.models.user import User
from app.settings import DatabaseSettings, MemorySettings
from app.utils.token_counter import SimpleTokenCounter, TikTokenCounter
class AsyncPostgresConversationMemory(ConversationMemoryInterface):
    """Conversation memory stored in a database via SQLAlchemy's asyncio API.

    Messages are persisted as ``Message`` rows linked to a ``User``. After each
    insert, the oldest messages are deleted until the token count of all stored
    messages fits within ``memory_settings.token_limit``.
    """

    def __init__(self, db_settings: DatabaseSettings, memory_settings: MemorySettings):
        """Create the async engine, the session factory and the token counter.

        Args:
            db_settings: Connection URL and pool configuration
                (``pool_size``, ``max_overflow``, ``pool_timeout``).
            memory_settings: Token budget (``token_limit``) and counter choice;
                ``token_counter == "tiktoken"`` selects ``TikTokenCounter`` for
                ``model_name``, any other value uses ``SimpleTokenCounter``.
        """
        self.engine = create_async_engine(
            db_settings.url,
            pool_size=db_settings.pool_size,
            max_overflow=db_settings.max_overflow,
            pool_timeout=db_settings.pool_timeout,
        )
        # expire_on_commit=False keeps attributes of committed objects readable
        # without an implicit refresh (which async sessions cannot do lazily).
        self.async_session = sessionmaker(
            self.engine, class_=AsyncSession, expire_on_commit=False
        )
        self.token_limit = memory_settings.token_limit
        if memory_settings.token_counter == "tiktoken":
            self.token_counter = TikTokenCounter(memory_settings.model_name)
        else:
            self.token_counter = SimpleTokenCounter()

    async def initialize(self):
        """Initialize the database by creating all tables registered on ``Base``."""
        async with self.engine.begin() as conn:
            # create_all is synchronous, so it runs through the async connection's run_sync.
            await conn.run_sync(Base.metadata.create_all)

    async def add_message(
        self,
        username: str,
        role: str,
        message: str,
        timestamp: Optional[datetime] = None,
    ) -> None:
        """Store a message for an existing user, then trim memory if needed.

        Args:
            username: Username of the author; the user must already exist.
            role: Role label stored with the message.
            message: Message text.
            timestamp: Creation time; defaults to "now" in the Asia/Jakarta zone.

        Raises:
            ValueError: If no user with ``username`` exists.
        """
        async with self.async_session() as session:
            result = await session.execute(select(User).filter_by(username=username))
            user = result.scalars().first()
            if user is None:
                raise ValueError(f"User with username '{username}' not found")
            if timestamp is None:
                # NOTE(review): this is a timezone-aware datetime; confirm that
                # Message.timestamp is a timezone-aware column type.
                timestamp = datetime.now(ZoneInfo("Asia/Jakarta"))
            session.add(
                Message(user_id=user.id, role=role, message=message, timestamp=timestamp)
            )
            await session.commit()
            await self.trim_memory_if_needed(session)

    async def get_all_history(self) -> List[Dict]:
        """Return every stored message (all users), oldest first.

        Returns:
            A list of ``{"role": ..., "content": ...}`` dicts.
        """
        async with self.async_session() as session:
            result = await session.execute(select(Message).order_by(Message.timestamp))
            messages = result.scalars().all()
            return [{"role": msg.role, "content": msg.message} for msg in messages]

    async def get_history(
        self,
        username: Optional[str] = None,
        token_limit: Optional[int] = None,
        last_n: Optional[int] = None,
    ) -> List[Dict]:
        """Return the most recent messages that fit the given limits.

        Args:
            username: If given, only messages of this user are considered.
            token_limit: Maximum total tokens of the returned messages. The
                newest message is always returned, even if it alone exceeds it.
            last_n: Maximum number of messages; ``0`` or less returns ``[]``.

        Returns:
            Chronologically ordered ``{"role": ..., "parts": ...}`` dicts.
            NOTE(review): this uses the key ``"parts"`` while
            ``get_all_history`` uses ``"content"``; confirm callers expect this.
        """
        if last_n is not None and last_n <= 0:
            return []
        async with self.async_session() as session:
            query = select(Message).order_by(Message.timestamp)
            if username is not None:
                query = query.join(User).filter(User.username == username)
            result = await session.execute(query)
            messages = result.scalars().all()
            selected = self._select_recent(messages, token_limit, last_n)
            return [{"role": msg.role, "parts": msg.message} for msg in selected]

    def _select_recent(self, messages, token_limit: Optional[int], last_n: Optional[int]) -> list:
        """Pick the newest messages within ``token_limit``/``last_n``, returned oldest first."""
        selected = []
        total_tokens = 0
        for msg in reversed(messages):
            if last_n is not None and len(selected) >= last_n:
                break
            tokens = self.token_counter.count_tokens(msg.message)
            if token_limit is not None and total_tokens + tokens > token_limit:
                # Always keep at least the newest message, even if it alone is over budget.
                if not selected:
                    selected.append(msg)
                break
            selected.append(msg)
            total_tokens += tokens
        selected.reverse()
        return selected

    async def clear_memory(self) -> None:
        """Delete every stored message (for all users)."""
        async with self.async_session() as session:
            # A Select has no .delete(); a bulk DELETE is built with sqlalchemy.delete().
            await session.execute(delete(Message))
            await session.commit()

    async def get_total_tokens(self) -> int:
        """Return the token count of all stored messages (all users)."""
        async with self.async_session() as session:
            result = await session.execute(select(Message))
            messages = result.scalars().all()
            return sum(self.token_counter.count_tokens(msg.message) for msg in messages)

    async def trim_memory_if_needed(self, session: AsyncSession) -> None:
        """Delete the oldest messages until the stored total fits ``self.token_limit``.

        The query is not filtered by user, so trimming is global.
        NOTE(review): one user's activity can evict another user's history;
        confirm this is intended. Always commits ``session``.

        Args:
            session: An open session; it is committed before returning.
        """
        result = await session.execute(select(Message).order_by(Message.timestamp))
        messages = result.scalars().all()
        # Count each message once; reuse the counts while deleting oldest-first.
        counts = [self.token_counter.count_tokens(msg.message) for msg in messages]
        total_tokens = sum(counts)
        for msg, tokens in zip(messages, counts):
            if total_tokens <= self.token_limit:
                break
            await session.delete(msg)
            total_tokens -= tokens
        await session.commit()