# Spaces: Runtime error
from typing import List, Optional, Dict, Any

from phi.tools import Toolkit
from phi.utils.log import logger

# Both third-party packages are hard requirements: fail at import time with a
# short, explicit message when one of them is missing.
try:
    import simplejson as json
except ImportError:
    raise ImportError("`simplejson` not installed")

try:
    from sqlalchemy import create_engine, Engine, Row
    from sqlalchemy.orm import Session, sessionmaker
    from sqlalchemy.inspection import inspect
    from sqlalchemy.sql.expression import text
except ImportError:
    raise ImportError("`sqlalchemy` not installed")
class SQLTools(Toolkit):
    """Toolkit that exposes SQL database helpers as registered functions.

    Depending on the constructor flags, the toolkit registers functions to
    list the tables of a database, describe a table and run a SQL query.

    The database engine is resolved in this order:
      1. ``db_engine`` when it is given,
      2. an engine created from ``db_url``,
      3. an engine created from ``dialect``, ``user``, ``password``, ``host``
         and ``port`` (plus ``schema`` when given).
    """

    def __init__(
        self,
        db_url: Optional[str] = None,
        db_engine: Optional[Engine] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
        host: Optional[str] = None,
        port: Optional[int] = None,
        schema: Optional[str] = None,
        dialect: Optional[str] = None,
        tables: Optional[Dict[str, Any]] = None,
        list_tables: bool = True,
        describe_table: bool = True,
        run_sql_query: bool = True,
    ):
        """Build the database connection and register the enabled functions.

        Args:
            db_url: Connection URL passed to ``create_engine``.
            db_engine: Ready-made engine; takes precedence over every other
                connection argument.
            user: User name placed in the generated connection URL.
            password: Password placed in the generated connection URL.
            host: Host placed in the generated connection URL.
            port: Port placed in the generated connection URL.
            schema: Appended as the path component of the generated URL.
            dialect: Scheme of the generated connection URL.
            tables: Tables this toolkit can access; when set, ``list_tables``
                returns this mapping instead of inspecting the database.
            list_tables: Register the ``list_tables`` function.
            describe_table: Register the ``describe_table`` function.
            run_sql_query: Register the ``run_sql_query`` function.

        Raises:
            ValueError: If no engine was given and none could be built from
                the remaining arguments.
        """
        super().__init__(name="sql_tools")

        # Resolve the engine. The fallbacks are only consulted when no engine
        # was supplied, so an explicit `db_engine` is never replaced by one
        # built from the individual credential parts.
        engine: Optional[Engine] = db_engine
        if engine is None:
            if db_url is not None:
                engine = create_engine(db_url)
            elif user and password and host and port and dialect:
                # NOTE: the parts are interpolated verbatim. Special characters
                # in user/password must already be URL-encoded by the caller.
                url = f"{dialect}://{user}:{password}@{host}:{port}"
                if schema is not None:
                    url = f"{url}/{schema}"
                engine = create_engine(url)

        if engine is None:
            raise ValueError("Could not build the database connection")

        # Database connection
        self.db_engine: Engine = engine
        self.Session: sessionmaker[Session] = sessionmaker(bind=self.db_engine)

        # Tables this toolkit can access
        self.tables: Optional[Dict[str, Any]] = tables

        # Register functions in the toolkit
        if list_tables:
            self.register(self.list_tables)
        if describe_table:
            self.register(self.describe_table)
        if run_sql_query:
            self.register(self.run_sql_query)

    def list_tables(self) -> str:
        """Use this function to get a list of table names in the database.

        Returns:
            str: list of tables in the database.
        """
        # A configured table mapping wins over database inspection.
        if self.tables is not None:
            return json.dumps(self.tables)

        try:
            table_names = inspect(self.db_engine).get_table_names()
            logger.debug(f"table_names: {table_names}")
            return json.dumps(table_names)
        except Exception as e:
            # Errors are reported as text instead of being raised.
            logger.error(f"Error getting tables: {e}")
            return f"Error getting tables: {e}"

    def describe_table(self, table_name: str) -> str:
        """Use this function to describe a table.

        Args:
            table_name (str): The name of the table to get the schema for.

        Returns:
            str: schema of a table
        """
        try:
            inspector = inspect(self.db_engine)
            columns = inspector.get_columns(table_name)
            # Each column description is stringified before JSON encoding.
            return json.dumps([str(column) for column in columns])
        except Exception as e:
            logger.error(f"Error getting table schema: {e}")
            return f"Error getting table schema: {e}"

    def run_sql_query(self, query: str, limit: Optional[int] = 10) -> str:
        """Use this function to run a SQL query and return the result.

        Args:
            query (str): The query to run.
            limit (int, optional): The number of rows to return. Defaults to 10. Use `None` to show all results.

        Returns:
            str: Result of the SQL query.

        Notes:
            - The result may be empty if the query does not return any data.
        """
        try:
            rows = self.run_sql(sql=query, limit=limit)
            # `default=str` lets values without a native JSON form (dates,
            # datetimes, UUIDs, ...) be returned as text instead of turning a
            # successful query into a serialization error.
            return json.dumps(rows, default=str)
        except Exception as e:
            logger.error(f"Error running query: {e}")
            return f"Error running query: {e}"

    def run_sql(self, sql: str, limit: Optional[int] = None) -> List[dict]:
        """Internal function to run a sql query.

        Args:
            sql (str): The sql query to run.
            limit (int, optional): The number of rows to return. Defaults to None.

        Returns:
            List[dict]: The result of the query.
        """
        logger.debug(f"Running sql |\n{sql}")

        result = None
        # NOTE(review): fetching from a statement that returns no rows (e.g.
        # INSERT or DDL) presumably raises inside this transaction, which would
        # roll it back - confirm whether write statements are meant to work.
        with self.Session() as sess, sess.begin():
            cursor = sess.execute(text(sql))
            # A falsy limit (None or 0) means "fetch everything".
            result = cursor.fetchmany(limit) if limit else cursor.fetchall()

        logger.debug(f"SQL result: {result}")

        if result is None:
            return []
        if isinstance(result, list):
            return [row._asdict() for row in result]
        if isinstance(result, Row):
            return [result._asdict()]

        logger.debug(f"SQL result type: {type(result)}")
        return []