import os
from typing import Optional, Tuple, List, Dict, Any

from phi.tools import Toolkit
from phi.utils.log import logger

try:
    import duckdb
except ImportError:
    raise ImportError("`duckdb` not installed. Please install using `pip install duckdb`.")


class DuckDbTools(Toolkit):
    """Toolkit that exposes DuckDB operations (query, describe, load, export, full text search).

    ``show_tables`` and ``describe_table`` are always registered; the remaining
    tools are registered according to the boolean flags passed to ``__init__``.
    The DuckDB connection is created lazily on first access of ``connection``.
    """

    def __init__(
        self,
        db_path: Optional[str] = None,
        connection: Optional[duckdb.DuckDBPyConnection] = None,
        init_commands: Optional[List] = None,
        read_only: bool = False,
        config: Optional[dict] = None,
        run_queries: bool = True,
        inspect_queries: bool = False,
        create_tables: bool = True,
        summarize_tables: bool = True,
        export_tables: bool = False,
    ):
        """Store connection settings and register the enabled tools.

        :param db_path: Database path passed to ``duckdb.connect`` as ``database`` (omitted when None)
        :param connection: Existing connection to use instead of creating one
        :param init_commands: SQL commands run once right after a new connection is created
        :param read_only: When True, ``read_only=True`` is passed to ``duckdb.connect``
        :param config: Passed to ``duckdb.connect`` as ``config`` (omitted when None)
        :param run_queries: Register ``run_query``
        :param inspect_queries: Register ``inspect_query``
        :param create_tables: Register ``create_table_from_path``
        :param summarize_tables: Register ``summarize_table``
        :param export_tables: Register ``export_table_to_path``
        """
        super().__init__(name="duckdb_tools")

        self.db_path: Optional[str] = db_path
        self.read_only: bool = read_only
        self.config: Optional[dict] = config
        self._connection: Optional[duckdb.DuckDBPyConnection] = connection
        self.init_commands: Optional[List] = init_commands

        self.register(self.show_tables)
        self.register(self.describe_table)
        if inspect_queries:
            self.register(self.inspect_query)
        if run_queries:
            self.register(self.run_query)
        if create_tables:
            self.register(self.create_table_from_path)
        if summarize_tables:
            self.register(self.summarize_table)
        if export_tables:
            self.register(self.export_table_to_path)

    @staticmethod
    def _quote_literal(value: str) -> str:
        """Return ``value`` as a single-quoted SQL string literal, doubling embedded single quotes."""
        escaped = str(value).replace("'", "''")
        return f"'{escaped}'"

    @property
    def connection(self) -> duckdb.DuckDBPyConnection:
        """
        Returns the duckdb connection, creating it on first access.

        :return duckdb.DuckDBPyConnection: duckdb connection
        """
        if self._connection is None:
            # Only pass the options that were actually set
            connection_kwargs: Dict[str, Any] = {}
            if self.db_path is not None:
                connection_kwargs["database"] = self.db_path
            if self.read_only:
                connection_kwargs["read_only"] = self.read_only
            if self.config is not None:
                connection_kwargs["config"] = self.config
            self._connection = duckdb.connect(**connection_kwargs)
            # Init commands are best-effort: a failure is logged and the connection is still returned
            try:
                if self.init_commands is not None:
                    for command in self.init_commands:
                        self._connection.sql(command)
            except Exception as e:
                logger.exception(e)
                logger.warning("Failed to run duckdb init commands")
        return self._connection

    def show_tables(self) -> str:
        """Function to show tables in the database

        :return: List of tables in the database
        """
        stmt = "SHOW TABLES;"
        tables = self.run_query(stmt)
        logger.debug(f"Tables: {tables}")
        return tables

    def describe_table(self, table: str) -> str:
        """Function to describe a table

        :param table: Table to describe
        :return: Description of the table
        """
        stmt = f"DESCRIBE {table};"
        table_description = self.run_query(stmt)

        logger.debug(f"Table description: {table_description}")
        return f"{table}\n{table_description}"

    def inspect_query(self, query: str) -> str:
        """Function to inspect a query and return the query plan. Always inspect your query before running them.

        :param query: Query to inspect
        :return: Query plan
        """
        stmt = f"explain {query};"
        explain_plan = self.run_query(stmt)

        logger.debug(f"Explain plan: {explain_plan}")
        return explain_plan

    def run_query(self, query: str) -> str:
        """Function that runs a query and returns the result.

        Backticks are removed and only the text before the first ``;`` is executed.
        Errors are not raised: the exception text is returned instead.

        :param query: SQL query to run
        :return: Result of the query
        """
        # -*- Format the SQL Query
        # Remove backticks
        formatted_sql = query.replace("`", "")
        # If there are multiple statements, only run the first one.
        # NOTE: this is a plain split, so a ';' inside a string literal also truncates the query.
        formatted_sql = formatted_sql.split(";")[0]

        try:
            logger.info(f"Running: {formatted_sql}")

            query_result = self.connection.sql(formatted_sql)
            result_output = "No output"
            if query_result is not None:
                try:
                    rows = query_result.fetchall()
                    # One line per row, values comma separated, preceded by a header line of column names
                    result_data = "\n".join(",".join(str(value) for value in row) for row in rows)
                    result_output = ",".join(query_result.columns) + "\n" + result_data
                except AttributeError:
                    # The result object has no fetchall/columns: fall back to its string form
                    result_output = str(query_result)

            logger.debug(f"Query result: {result_output}")
            return result_output
        except Exception as e:
            # Report any failure (duckdb errors included) as text rather than raising
            return str(e)

    def summarize_table(self, table: str) -> str:
        """Function to compute a number of aggregates over a table.
        The function launches a query that computes a number of aggregates over all columns,
        including min, max, avg, std and approx_unique.

        :param table: Table to summarize
        :return: Summary of the table
        """
        table_summary = self.run_query(f"SUMMARIZE {table};")

        logger.debug(f"Table summary: {table_summary}")
        return table_summary

    def get_table_name_from_path(self, path: str) -> str:
        """Get the table name from a path

        :param path: Path to get the table name from
        :return: Table name
        """
        # Get the file name from the path
        file_name = path.split("/")[-1]
        # Get the file name without extension from the path
        table, _ = os.path.splitext(file_name)
        # If the table isn't a valid SQL identifier, we'll need to use something else
        table = table.replace("-", "_").replace(".", "_").replace(" ", "_").replace("/", "_")

        return table

    def create_table_from_path(self, path: str, table: Optional[str] = None, replace: bool = False) -> str:
        """Creates a table from a path

        :param path: Path to load
        :param table: Optional table name to use
        :param replace: Whether to replace the table if it already exists
        :return: Table name created
        """
        if table is None:
            table = self.get_table_name_from_path(path)

        logger.debug(f"Creating table {table} from {path}")
        create_statement = "CREATE TABLE IF NOT EXISTS"
        if replace:
            create_statement = "CREATE OR REPLACE TABLE"

        create_statement += f" '{table}' AS SELECT * FROM {self._quote_literal(path)};"
        self.run_query(create_statement)
        logger.debug(f"Created table {table} from {path}")
        return table

    def export_table_to_path(self, table: str, format: Optional[str] = "PARQUET", path: Optional[str] = None) -> str:
        """Save a table in a desired format (default: parquet)
        If the path is provided, the table will be saved under that path.
            Eg: If path is /tmp, the table will be saved as /tmp/<table>.<format>
        Otherwise it will be saved in the current directory.
        The file extension is the format exactly as passed (PARQUET by default).

        :param table: Table to export
        :param format: Format to export in (default: parquet)
        :param path: Path to export to
        :return: Result of the export query
        """
        if format is None:
            format = "PARQUET"

        logger.debug(f"Exporting Table {table} as {format.upper()} to path {path}")
        if path is None:
            path = f"{table}.{format}"
        else:
            path = f"{path}/{table}.{format}"
        export_statement = f"COPY (SELECT * FROM {table}) TO {self._quote_literal(path)} (FORMAT {format.upper()});"
        result = self.run_query(export_statement)
        # `path` already holds the full target file at this point
        logger.debug(f"Exported {table} to {path}")
        return result

    def _load_path_to_table(self, path: str, table: Optional[str] = None) -> Tuple[str, str]:
        """Create or replace a table from ``SELECT * FROM '<path>'`` and return (table, statement)."""
        logger.debug(f"Loading {path} into duckdb")

        if table is None:
            table = self.get_table_name_from_path(path)

        create_statement = f"CREATE OR REPLACE TABLE '{table}' AS SELECT * FROM {self._quote_literal(path)};"
        self.run_query(create_statement)

        logger.debug(f"Loaded {path} into duckdb as {table}")
        return table, create_statement

    def _load_csv_to_table(
        self, path: str, table: Optional[str] = None, delimiter: Optional[str] = None
    ) -> Tuple[str, str]:
        """Create or replace a table from ``read_csv('<path>')`` and return (table, statement)."""
        logger.debug(f"Loading {path} into duckdb")

        if table is None:
            table = self.get_table_name_from_path(path)

        select_statement = f"SELECT * FROM read_csv({self._quote_literal(path)}"
        if delimiter is not None:
            select_statement += f", delim={self._quote_literal(delimiter)})"
        else:
            select_statement += ")"

        create_statement = f"CREATE OR REPLACE TABLE '{table}' AS {select_statement};"
        self.run_query(create_statement)

        logger.debug(f"Loaded CSV {path} into duckdb as {table}")
        return table, create_statement

    def load_local_path_to_table(self, path: str, table: Optional[str] = None) -> Tuple[str, str]:
        """Load a local file into duckdb

        :param path: Path to load
        :param table: Optional table name to use
        :return: Table name, SQL statement used to load the file
        """
        return self._load_path_to_table(path, table)

    def load_local_csv_to_table(
        self, path: str, table: Optional[str] = None, delimiter: Optional[str] = None
    ) -> Tuple[str, str]:
        """Load a local CSV file into duckdb

        :param path: Path to load
        :param table: Optional table name to use
        :param delimiter: Optional delimiter to use
        :return: Table name, SQL statement used to load the file
        """
        return self._load_csv_to_table(path, table, delimiter)

    def load_s3_path_to_table(self, path: str, table: Optional[str] = None) -> Tuple[str, str]:
        """Load a file from S3 into duckdb

        :param path: S3 path to load
        :param table: Optional table name to use
        :return: Table name, SQL statement used to load the file
        """
        return self._load_path_to_table(path, table)

    def load_s3_csv_to_table(
        self, path: str, table: Optional[str] = None, delimiter: Optional[str] = None
    ) -> Tuple[str, str]:
        """Load a CSV file from S3 into duckdb

        :param path: S3 path to load
        :param table: Optional table name to use
        :param delimiter: Optional delimiter to use
        :return: Table name, SQL statement used to load the file
        """
        return self._load_csv_to_table(path, table, delimiter)

    def create_fts_index(self, table: str, unique_key: str, input_values: List[str]) -> str:
        """Create a full text search index on a table

        :param table: Table to create the index on
        :param unique_key: Unique key to use
        :param input_values: Values to index
        :return: Result of the index creation query
        """
        # Tolerate a single column name passed as a bare string
        if isinstance(input_values, str):
            input_values = [input_values]

        logger.debug(f"Creating FTS index on {table} for {input_values}")
        self.run_query("INSTALL fts;")
        logger.debug("Installed FTS extension")
        self.run_query("LOAD fts;")
        logger.debug("Loaded FTS extension")

        # Each column is its own quoted argument; formatting the list directly would emit its Python repr
        pragma_args = ", ".join(self._quote_literal(arg) for arg in (table, unique_key, *input_values))
        create_fts_index_statement = f"PRAGMA create_fts_index({pragma_args});"
        logger.debug(f"Running {create_fts_index_statement}")
        result = self.run_query(create_fts_index_statement)
        logger.debug(f"Created FTS index on {table} for {input_values}")

        return result

    def full_text_search(self, table: str, unique_key: str, search_text: str) -> str:
        """Full text Search in a table column for a specific text/keyword

        :param table: Table to search
        :param unique_key: Unique key to use
        :param search_text: Text to search
        :return: Result of the search query
        """
        logger.debug(f"Running full_text_search for {search_text} in {table}")
        # The FTS schema is derived from the table name.
        # NOTE(review): assumes the table lives in the `main` schema - confirm for schema-qualified tables.
        search_text_statement = (
            f"SELECT fts_main_{table}.match_bm25({unique_key}, {self._quote_literal(search_text)}) AS score,* "
            f"FROM {table} "
            "WHERE score IS NOT NULL "
            "ORDER BY score;"
        )
        logger.debug(f"Running {search_text_statement}")
        result = self.run_query(search_text_statement)
        logger.debug(f"Search results for {search_text} in {table}")

        return result