from typing import Optional, Any, List try: from sqlalchemy.dialects import sqlite from sqlalchemy.engine import create_engine, Engine from sqlalchemy.engine.row import Row from sqlalchemy.inspection import inspect from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.schema import MetaData, Table, Column from sqlalchemy.sql.expression import select from sqlalchemy.types import String except ImportError: raise ImportError("`sqlalchemy` not installed") from sqlite3 import OperationalError from phi.assistant.run import AssistantRun from phi.storage.assistant.base import AssistantStorage from phi.utils.dttm import current_datetime from phi.utils.log import logger class SqlAssistantStorage(AssistantStorage): def __init__( self, table_name: str, db_url: Optional[str] = None, db_file: Optional[str] = None, db_engine: Optional[Engine] = None, ): """ This class provides assistant storage using a sqlite database. The following order is used to determine the database connection: 1. Use the db_engine if provided 2. Use the db_url 3. Use the db_file 4. Create a new in-memory database :param table_name: The name of the table to store assistant runs. :param db_url: The database URL to connect to. :param db_file: The database file to connect to. :param db_engine: The database engine to use. """ _engine: Optional[Engine] = db_engine if _engine is None and db_url is not None: _engine = create_engine(db_url) elif _engine is None and db_file is not None: _engine = create_engine(f"sqlite:///{db_file}") else: _engine = create_engine("sqlite://") if _engine is None: raise ValueError("Must provide either db_url, db_file or db_engine") # Database attributes self.table_name: str = table_name self.db_url: Optional[str] = db_url self.db_engine: Engine = _engine self.metadata: MetaData = MetaData() # Database session self.Session: sessionmaker[Session] = sessionmaker(bind=self.db_engine) # Database table for storage self.table: Table = self.get_table() def get_table(self) -> Table: return Table( self.table_name, self.metadata, # Database ID/Primary key for this run Column("run_id", String, primary_key=True), # Assistant name Column("name", String), # Run name Column("run_name", String), # ID of the user participating in this run Column("user_id", String), # -*- LLM data (name, model, etc.) Column("llm", sqlite.JSON), # -*- Assistant memory Column("memory", sqlite.JSON), # Metadata associated with this assistant Column("assistant_data", sqlite.JSON), # Metadata associated with this run Column("run_data", sqlite.JSON), # Metadata associated the user participating in this run Column("user_data", sqlite.JSON), # Metadata associated with the assistant tasks Column("task_data", sqlite.JSON), # The timestamp of when this run was created. Column("created_at", sqlite.DATETIME, default=current_datetime()), # The timestamp of when this run was last updated. Column("updated_at", sqlite.DATETIME, onupdate=current_datetime()), extend_existing=True, sqlite_autoincrement=True, ) def table_exists(self) -> bool: logger.debug(f"Checking if table exists: {self.table.name}") try: return inspect(self.db_engine).has_table(self.table.name) except Exception as e: logger.error(e) return False def create(self) -> None: if not self.table_exists(): logger.debug(f"Creating table: {self.table.name}") self.table.create(self.db_engine) def _read(self, session: Session, run_id: str) -> Optional[Row[Any]]: stmt = select(self.table).where(self.table.c.run_id == run_id) try: return session.execute(stmt).first() except OperationalError: # Create table if it does not exist self.create() except Exception as e: logger.warning(e) return None def read(self, run_id: str) -> Optional[AssistantRun]: with self.Session() as sess: existing_row: Optional[Row[Any]] = self._read(session=sess, run_id=run_id) return AssistantRun.model_validate(existing_row) if existing_row is not None else None def get_all_run_ids(self, user_id: Optional[str] = None) -> List[str]: run_ids: List[str] = [] try: with self.Session() as sess: # get all run_ids for this user stmt = select(self.table) if user_id is not None: stmt = stmt.where(self.table.c.user_id == user_id) # order by created_at desc stmt = stmt.order_by(self.table.c.created_at.desc()) # execute query rows = sess.execute(stmt).fetchall() for row in rows: if row is not None and row.run_id is not None: run_ids.append(row.run_id) except OperationalError: logger.debug(f"Table does not exist: {self.table.name}") pass return run_ids def get_all_runs(self, user_id: Optional[str] = None) -> List[AssistantRun]: conversations: List[AssistantRun] = [] try: with self.Session() as sess: # get all runs for this user stmt = select(self.table) if user_id is not None: stmt = stmt.where(self.table.c.user_id == user_id) # order by created_at desc stmt = stmt.order_by(self.table.c.created_at.desc()) # execute query rows = sess.execute(stmt).fetchall() for row in rows: if row.run_id is not None: conversations.append(AssistantRun.model_validate(row)) except OperationalError: logger.debug(f"Table does not exist: {self.table.name}") pass return conversations def upsert(self, row: AssistantRun) -> Optional[AssistantRun]: """ Create a new assistant run if it does not exist, otherwise update the existing conversation. """ with self.Session() as sess: # Create an insert statement stmt = sqlite.insert(self.table).values( run_id=row.run_id, name=row.name, run_name=row.run_name, user_id=row.user_id, llm=row.llm, memory=row.memory, assistant_data=row.assistant_data, run_data=row.run_data, user_data=row.user_data, task_data=row.task_data, ) # Define the upsert if the run_id already exists # See: https://docs.sqlalchemy.org/en/20/dialects/sqlite.html#insert-on-conflict-upsert stmt = stmt.on_conflict_do_update( index_elements=["run_id"], set_=dict( name=row.name, run_name=row.run_name, user_id=row.user_id, llm=row.llm, memory=row.memory, assistant_data=row.assistant_data, run_data=row.run_data, user_data=row.user_data, task_data=row.task_data, ), # The updated value for each column ) try: sess.execute(stmt) except OperationalError: # Create table if it does not exist self.create() sess.execute(stmt) return self.read(run_id=row.run_id) def delete(self) -> None: if self.table_exists(): logger.debug(f"Deleting table: {self.table_name}") self.table.drop(self.db_engine)