# NOTE: page-scrape residue, not part of the module -- "Spaces: Runtime error".
from sqlite3 import OperationalError
from typing import Optional, Any, List

try:
    from sqlalchemy.dialects import sqlite
    from sqlalchemy.engine import create_engine, Engine
    from sqlalchemy.engine.row import Row
    from sqlalchemy.exc import OperationalError as SqlAlchemyOperationalError
    from sqlalchemy.inspection import inspect
    from sqlalchemy.orm import Session, sessionmaker
    from sqlalchemy.schema import MetaData, Table, Column
    from sqlalchemy.sql.expression import select
    from sqlalchemy.types import String
except ImportError:
    raise ImportError("`sqlalchemy` not installed")

from phi.assistant.run import AssistantRun
from phi.storage.assistant.base import AssistantStorage
from phi.utils.dttm import current_datetime
from phi.utils.log import logger


class SqlAssistantStorage(AssistantStorage):
    """Assistant-run storage backed by a SQLite table accessed through SQLAlchemy."""

    def __init__(
        self,
        table_name: str,
        db_url: Optional[str] = None,
        db_file: Optional[str] = None,
        db_engine: Optional[Engine] = None,
    ):
        """
        This class provides assistant storage using a sqlite database.

        The following order is used to determine the database connection:
            1. Use the db_engine if provided
            2. Use the db_url
            3. Use the db_file
            4. Create a new in-memory database

        :param table_name: The name of the table to store assistant runs.
        :param db_url: The database URL to connect to.
        :param db_file: The database file to connect to.
        :param db_engine: The database engine to use.
        """
        _engine: Optional[Engine] = db_engine
        # Only build an engine when the caller did not supply one; a provided
        # engine must never be replaced by the in-memory fallback.
        if _engine is None:
            if db_url is not None:
                _engine = create_engine(db_url)
            elif db_file is not None:
                _engine = create_engine(f"sqlite:///{db_file}")
            else:
                _engine = create_engine("sqlite://")

        # Database attributes
        self.table_name: str = table_name
        self.db_url: Optional[str] = db_url
        self.db_engine: Engine = _engine
        self.metadata: MetaData = MetaData()

        # Database session
        self.Session: sessionmaker[Session] = sessionmaker(bind=self.db_engine)

        # Database table for storage
        self.table: Table = self.get_table()

    def get_table(self) -> Table:
        """Return the SQLAlchemy ``Table`` definition used to store assistant runs."""
        return Table(
            self.table_name,
            self.metadata,
            # Database ID/Primary key for this run
            Column("run_id", String, primary_key=True),
            # Assistant name
            Column("name", String),
            # Run name
            Column("run_name", String),
            # ID of the user participating in this run
            Column("user_id", String),
            # -*- LLM data (name, model, etc.)
            Column("llm", sqlite.JSON),
            # -*- Assistant memory
            Column("memory", sqlite.JSON),
            # Metadata associated with this assistant
            Column("assistant_data", sqlite.JSON),
            # Metadata associated with this run
            Column("run_data", sqlite.JSON),
            # Metadata associated the user participating in this run
            Column("user_data", sqlite.JSON),
            # Metadata associated with the assistant tasks
            Column("task_data", sqlite.JSON),
            # The timestamp of when this run was created.
            # Pass the callable (not its result) so it is evaluated per INSERT
            # instead of once when the table object is built.
            Column("created_at", sqlite.DATETIME, default=current_datetime),
            # The timestamp of when this run was last updated.
            Column("updated_at", sqlite.DATETIME, onupdate=current_datetime),
            extend_existing=True,
            sqlite_autoincrement=True,
        )

    def table_exists(self) -> bool:
        """Return True if the storage table exists; log the error and return False on failure."""
        logger.debug(f"Checking if table exists: {self.table.name}")
        try:
            return inspect(self.db_engine).has_table(self.table.name)
        except Exception as e:
            logger.error(e)
            return False

    def create(self) -> None:
        """Create the storage table if it does not already exist."""
        if not self.table_exists():
            logger.debug(f"Creating table: {self.table.name}")
            self.table.create(self.db_engine)

    def _read(self, session: Session, run_id: str) -> Optional[Row[Any]]:
        """
        Fetch the raw row for ``run_id`` using ``session``.

        Returns None when no row matches or when the query fails. An
        OperationalError triggers creation of the table so later calls succeed.
        """
        stmt = select(self.table).where(self.table.c.run_id == run_id)
        try:
            return session.execute(stmt).first()
        except (SqlAlchemyOperationalError, OperationalError) as e:
            # SQLAlchemy wraps the driver's error in its own OperationalError,
            # so both types are caught. Create table if it does not exist.
            logger.debug(f"Read failed, ensuring table exists: {e}")
            self.create()
        except Exception as e:
            logger.warning(e)
        return None

    def read(self, run_id: str) -> Optional[AssistantRun]:
        """Return the stored ``AssistantRun`` for ``run_id``, or None if there is none."""
        with self.Session() as sess:
            existing_row: Optional[Row[Any]] = self._read(session=sess, run_id=run_id)
            return AssistantRun.model_validate(existing_row) if existing_row is not None else None

    def _filter_and_order(self, stmt: Any, user_id: Optional[str]) -> Any:
        """Restrict ``stmt`` to ``user_id`` (when given) and order by newest ``created_at`` first."""
        if user_id is not None:
            stmt = stmt.where(self.table.c.user_id == user_id)
        return stmt.order_by(self.table.c.created_at.desc())

    def get_all_run_ids(self, user_id: Optional[str] = None) -> List[str]:
        """
        Return all run ids, newest first, optionally restricted to ``user_id``.

        Returns an empty list when the query raises an OperationalError
        (e.g. the table has not been created yet).
        """
        run_ids: List[str] = []
        try:
            with self.Session() as sess:
                # Only the run_id column is needed; avoid loading the JSON columns.
                stmt = self._filter_and_order(select(self.table.c.run_id), user_id)
                rows = sess.execute(stmt).fetchall()
                run_ids = [row.run_id for row in rows if row.run_id is not None]
        except (SqlAlchemyOperationalError, OperationalError):
            logger.debug(f"Table does not exist: {self.table.name}")
        return run_ids

    def get_all_runs(self, user_id: Optional[str] = None) -> List[AssistantRun]:
        """
        Return all stored runs, newest first, optionally restricted to ``user_id``.

        Returns an empty list when the query raises an OperationalError
        (e.g. the table has not been created yet).
        """
        conversations: List[AssistantRun] = []
        try:
            with self.Session() as sess:
                stmt = self._filter_and_order(select(self.table), user_id)
                rows = sess.execute(stmt).fetchall()
                conversations = [AssistantRun.model_validate(row) for row in rows if row.run_id is not None]
        except (SqlAlchemyOperationalError, OperationalError):
            logger.debug(f"Table does not exist: {self.table.name}")
        return conversations

    def upsert(self, row: AssistantRun) -> Optional[AssistantRun]:
        """
        Create a new assistant run if it does not exist, otherwise update the existing conversation.

        :param row: The assistant run to persist.
        :return: The run as read back from the database after the write.
        """
        # Values shared by the INSERT and the ON CONFLICT ... DO UPDATE clause.
        run_values = dict(
            name=row.name,
            run_name=row.run_name,
            user_id=row.user_id,
            llm=row.llm,
            memory=row.memory,
            assistant_data=row.assistant_data,
            run_data=row.run_data,
            user_data=row.user_data,
            task_data=row.task_data,
        )
        # Create an insert statement
        stmt = sqlite.insert(self.table).values(run_id=row.run_id, **run_values)
        # Define the upsert if the run_id already exists
        # See: https://docs.sqlalchemy.org/en/20/dialects/sqlite.html#insert-on-conflict-upsert
        stmt = stmt.on_conflict_do_update(
            index_elements=["run_id"],
            # Column.onupdate is not applied by on_conflict_do_update, so the
            # update timestamp has to be set explicitly here.
            set_={**run_values, "updated_at": current_datetime()},
        )

        with self.Session() as sess:
            try:
                sess.execute(stmt)
            except (SqlAlchemyOperationalError, OperationalError):
                # Reset the failed transaction before retrying.
                sess.rollback()
                # Create table if it does not exist
                self.create()
                sess.execute(stmt)
            # Without an explicit commit the write is rolled back when the session closes.
            sess.commit()
        return self.read(run_id=row.run_id)

    def delete(self) -> None:
        """Drop the storage table if it exists."""
        if self.table_exists():
            logger.debug(f"Deleting table: {self.table_name}")
            self.table.drop(self.db_engine)