AmmarFahmy
adding all files
105b369
from typing import List, Optional, Dict, Any
from phi.tools import Toolkit
from phi.utils.log import logger
try:
import simplejson as json
except ImportError:
raise ImportError("`simplejson` not installed")
try:
from sqlalchemy import create_engine, Engine, Row
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.inspection import inspect
from sqlalchemy.sql.expression import text
except ImportError:
raise ImportError("`sqlalchemy` not installed")
class SQLTools(Toolkit):
    """Toolkit exposing a SQL database to an agent through SQLAlchemy.

    Registers up to three tools: ``list_tables``, ``describe_table`` and
    ``run_sql_query``. Each tool returns a string: JSON on success, or an
    error message on failure, so the output can be handed straight back to
    the model.
    """

    def __init__(
        self,
        db_url: Optional[str] = None,
        db_engine: Optional[Engine] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
        host: Optional[str] = None,
        port: Optional[int] = None,
        schema: Optional[str] = None,
        dialect: Optional[str] = None,
        tables: Optional[Dict[str, Any]] = None,
        list_tables: bool = True,
        describe_table: bool = True,
        run_sql_query: bool = True,
    ):
        """Create the toolkit and resolve its database engine.

        The engine is resolved in priority order:
          1. ``db_engine`` if given;
          2. otherwise ``db_url``;
          3. otherwise a URL built from ``dialect``, ``user``, ``password``,
             ``host``, ``port`` and optional ``schema``. ``schema`` is used
             as the database component of the URL.

        Args:
            db_url: SQLAlchemy database URL.
            db_engine: Pre-built SQLAlchemy engine; takes precedence over all
                other connection arguments.
            user: Database user name.
            password: Database password (raw, not URL-encoded).
            host: Database host.
            port: Database port.
            schema: Database name placed in the URL path.
            dialect: SQLAlchemy drivername, e.g. ``"postgresql+psycopg"``.
            tables: Value returned verbatim by ``list_tables`` instead of
                inspecting the database. NOTE: only ``list_tables`` consults
                it; the other tools do not restrict access to these tables.
            list_tables: Register the ``list_tables`` tool.
            describe_table: Register the ``describe_table`` tool.
            run_sql_query: Register the ``run_sql_query`` tool.

        Raises:
            ValueError: If no engine could be built from the arguments.
        """
        super().__init__(name="sql_tools")

        # Resolve the database engine. An explicitly supplied engine always
        # takes precedence: connection params must never silently replace it.
        _engine: Optional[Engine] = db_engine
        if _engine is None:
            if db_url is not None:
                _engine = create_engine(db_url)
            elif user and password and host and port and dialect:
                # URL.create escapes reserved characters ("@", "/", ":") in
                # credentials, which plain string interpolation would not.
                from sqlalchemy.engine import URL

                _engine = create_engine(
                    URL.create(
                        drivername=dialect,
                        username=user,
                        password=password,
                        host=host,
                        port=int(port),
                        database=schema,
                    )
                )
        if _engine is None:
            raise ValueError("Could not build the database connection")

        # Database connection
        self.db_engine: Engine = _engine
        self.Session: sessionmaker[Session] = sessionmaker(bind=self.db_engine)

        # Tables this toolkit reports from list_tables (see docstring note)
        self.tables: Optional[Dict[str, Any]] = tables

        # Register functions in the toolkit
        if list_tables:
            self.register(self.list_tables)
        if describe_table:
            self.register(self.describe_table)
        if run_sql_query:
            self.register(self.run_sql_query)

    def list_tables(self) -> str:
        """Use this function to get a list of table names in the database.

        Returns:
            str: list of tables in the database.
        """
        # A user-configured table mapping short-circuits database inspection.
        if self.tables is not None:
            return json.dumps(self.tables)

        try:
            table_names = inspect(self.db_engine).get_table_names()
            logger.debug(f"table_names: {table_names}")
            return json.dumps(table_names)
        except Exception as e:
            logger.error(f"Error getting tables: {e}")
            return f"Error getting tables: {e}"

    def describe_table(self, table_name: str) -> str:
        """Use this function to describe a table.

        Args:
            table_name (str): The name of the table to get the schema for.

        Returns:
            str: schema of a table
        """
        try:
            inspector = inspect(self.db_engine)
            columns = inspector.get_columns(table_name)
            # Column dicts contain SQLAlchemy type objects, so stringify each one.
            return json.dumps([str(column) for column in columns])
        except Exception as e:
            logger.error(f"Error getting table schema: {e}")
            return f"Error getting table schema: {e}"

    def run_sql_query(self, query: str, limit: Optional[int] = 10) -> str:
        """Use this function to run a SQL query and return the result.

        Args:
            query (str): The query to run.
            limit (int, optional): The number of rows to return. Defaults to 10. Use `None` to show all results.

        Returns:
            str: Result of the SQL query.

        Notes:
            - The result may be empty if the query does not return any data.
        """
        try:
            # default=str: row values such as datetime/date/UUID are not
            # JSON-native. Without it every such query is reported as an error.
            return json.dumps(self.run_sql(sql=query, limit=limit), default=str)
        except Exception as e:
            logger.error(f"Error running query: {e}")
            return f"Error running query: {e}"

    def run_sql(self, sql: str, limit: Optional[int] = None) -> List[dict]:
        """Internal function to run a sql query.

        The statement runs inside a transaction that is committed when the
        block exits normally and rolled back if an exception propagates.

        Args:
            sql (str): The sql query to run.
            limit (int, optional): The number of rows to return. A falsy value
                (None or 0) returns all rows. Defaults to None.

        Returns:
            List[dict]: One dict per row (column name -> value). Empty list
            for statements that return no rows (INSERT/UPDATE/DDL).
        """
        logger.debug(f"Running sql |\n{sql}")
        with self.Session() as sess, sess.begin():
            result = sess.execute(text(sql))
            # Calling fetch*() on a statement without a result set raises
            # ResourceClosedError, which would also roll back the transaction.
            if not result.returns_rows:
                logger.debug("SQL statement returned no rows")
                return []
            rows = result.fetchmany(limit) if limit else result.fetchall()
        logger.debug(f"SQL result: {rows}")
        return [row._asdict() for row in rows]