Spaces:
Runtime error
Runtime error
File size: 5,030 Bytes
105b369 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
from typing import List, Optional, Dict, Any
from phi.tools import Toolkit
from phi.utils.log import logger
try:
import simplejson as json
except ImportError:
raise ImportError("`simplejson` not installed")
try:
from sqlalchemy import create_engine, Engine, Row
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.inspection import inspect
from sqlalchemy.sql.expression import text
except ImportError:
raise ImportError("`sqlalchemy` not installed")
class SQLTools(Toolkit):
    """phi Toolkit giving an agent access to a SQL database.

    Registers up to three functions on the toolkit: ``list_tables``,
    ``describe_table`` and ``run_sql_query`` (each can be disabled via the
    matching constructor flag).
    """

    def __init__(
        self,
        db_url: Optional[str] = None,
        db_engine: Optional[Engine] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
        host: Optional[str] = None,
        port: Optional[int] = None,
        schema: Optional[str] = None,
        dialect: Optional[str] = None,
        tables: Optional[Dict[str, Any]] = None,
        list_tables: bool = True,
        describe_table: bool = True,
        run_sql_query: bool = True,
    ):
        """Resolve a database engine and register the enabled functions.

        The engine is chosen in priority order: an explicit ``db_engine``,
        then ``db_url``, then a URL built from ``dialect``, ``user``,
        ``password``, ``host``, ``port`` (all five required) plus the
        optional ``schema``, which becomes the database part of the URL.

        Args:
            db_url: SQLAlchemy connection URL.
            db_engine: Pre-built engine; when given it is always used as-is.
            user: Database user name (connection-parts mode).
            password: Database password (connection-parts mode).
            host: Database host (connection-parts mode).
            port: Database port (connection-parts mode).
            schema: Optional database name appended to the built URL.
            dialect: SQLAlchemy drivername, e.g. ``"postgresql+psycopg"``.
            tables: If given, ``list_tables`` returns this mapping instead of
                inspecting the database.
            list_tables: Register ``list_tables`` on the toolkit.
            describe_table: Register ``describe_table`` on the toolkit.
            run_sql_query: Register ``run_sql_query`` on the toolkit.

        Raises:
            ValueError: If no engine could be built from the arguments.
        """
        super().__init__(name="sql_tools")

        # Get the database engine. Only build one when none was supplied;
        # otherwise a given db_engine would be silently replaced whenever the
        # connection parts were also set.
        _engine: Optional[Engine] = db_engine
        if _engine is None:
            if db_url is not None:
                _engine = create_engine(db_url)
            elif user and password and host and port and dialect:
                # URL.create escapes special characters in the credentials
                # ('@', '/', ':' ...), which a hand-formatted URL string does not.
                from sqlalchemy.engine import URL

                _engine = create_engine(
                    URL.create(
                        drivername=dialect,
                        username=user,
                        password=password,
                        host=host,
                        # int() keeps accepting string ports (e.g. from env vars),
                        # which the old f-string URL tolerated.
                        port=int(port),
                        database=schema,
                    )
                )
        if _engine is None:
            raise ValueError("Could not build the database connection")

        # Database connection
        self.db_engine: Engine = _engine
        self.Session: sessionmaker[Session] = sessionmaker(bind=self.db_engine)

        # Tables this toolkit can access
        self.tables: Optional[Dict[str, Any]] = tables

        # Register functions in the toolkit
        if list_tables:
            self.register(self.list_tables)
        if describe_table:
            self.register(self.describe_table)
        if run_sql_query:
            self.register(self.run_sql_query)

    def list_tables(self) -> str:
        """Use this function to get a list of table names in the database.

        Returns:
            str: list of tables in the database.
        """
        if self.tables is not None:
            # A caller-supplied mapping overrides inspection. default=str keeps
            # arbitrary values from raising here, outside the try below.
            return json.dumps(self.tables, default=str)

        try:
            table_names = inspect(self.db_engine).get_table_names()
            logger.debug(f"table_names: {table_names}")
            return json.dumps(table_names)
        except Exception as e:
            logger.error(f"Error getting tables: {e}")
            return f"Error getting tables: {e}"

    def describe_table(self, table_name: str) -> str:
        """Use this function to describe a table.

        Args:
            table_name (str): The name of the table to get the schema for.

        Returns:
            str: schema of a table
        """
        try:
            inspector = inspect(self.db_engine)
            table_schema = inspector.get_columns(table_name)
            # Each column is a dict holding SQLAlchemy type objects; str() makes
            # it JSON-serializable.
            return json.dumps([str(column) for column in table_schema])
        except Exception as e:
            logger.error(f"Error getting table schema: {e}")
            return f"Error getting table schema: {e}"

    def run_sql_query(self, query: str, limit: Optional[int] = 10) -> str:
        """Use this function to run a SQL query and return the result.

        Args:
            query (str): The query to run.
            limit (int, optional): The number of rows to return. Defaults to 10. Use `None` to show all results.

        Returns:
            str: Result of the SQL query.

        Notes:
            - The result may be empty if the query does not return any data.
        """
        try:
            # Rows can hold dates, timestamps, UUIDs etc. that JSON cannot encode.
            # The output is meant to be read, so stringify them deliberately
            # instead of failing the whole query.
            return json.dumps(self.run_sql(sql=query, limit=limit), default=str)
        except Exception as e:
            logger.error(f"Error running query: {e}")
            return f"Error running query: {e}"

    def run_sql(self, sql: str, limit: Optional[int] = None) -> List[dict]:
        """Internal function to run a sql query.

        The statement runs inside a session transaction. Any exception
        (including fetching from a statement that returns no rows) propagates
        to the caller after the transaction is rolled back.

        Args:
            sql (str): The sql query to run.
            limit (int, optional): The number of rows to return. A falsy value
                (None or 0) fetches all rows. Defaults to None.

        Returns:
            List[dict]: One ``column -> value`` dict per fetched row.
        """
        logger.debug(f"Running sql |\n{sql}")

        with self.Session() as sess, sess.begin():
            cursor_result = sess.execute(text(sql))
            rows = cursor_result.fetchmany(limit) if limit else cursor_result.fetchall()

        logger.debug(f"SQL result: {rows}")
        return [row._asdict() for row in rows]
|