AhmadT198's picture
Adding SQL
e537b1c verified
raw
history blame
1.56 kB
import sqlite3
from pydantic.v1 import BaseModel
from typing import List
from langchain.tools import Tool
# Module-level SQLite connection, opened once at import time on the file
# "db_2.sqlite" and used by every query helper in this module.
conn = sqlite3.connect("db_2.sqlite")
def list_tables():
    """Return the names of all tables in the database, one per line.

    Reads the ``name`` column of every ``sqlite_master`` row whose type is
    'table' through the module-level ``conn`` and joins the non-NULL names
    with newlines.
    """
    cursor = conn.cursor()
    cursor.execute("SELECT name from sqlite_master WHERE type='table';")
    # Each fetched row is a 1-tuple holding the table name.
    table_names = [name for (name,) in cursor.fetchall() if name is not None]
    return "\n".join(table_names)
############################# Run SQLite Query Tool ##########################################
def run_sqlite_query(query):
    """Execute a SQL statement on the module-level connection.

    Args:
        query: SQL text, passed unchanged to ``cursor.execute``.

    Returns:
        The list of row tuples produced by ``fetchall()`` when the statement
        succeeds. If SQLite raises ``sqlite3.OperationalError``, a string
        containing the error message is returned instead of the exception
        being propagated. Other exception types are not caught here.
    """
    c = conn.cursor()
    try:
        c.execute(query)
        return c.fetchall()
    except sqlite3.OperationalError as err:
        # Format the caught exception itself (not the ``str`` builtin) so the
        # actual SQLite message ends up in the returned text.
        return f"The following error occured: {err}"
# Argument schema handed to the "run_sqlite_query" tool via ``args_schema``.
class RunQuaryArgsSchema(BaseModel):
    # The SQL statement to execute.
    query: str
# Tool wrapper that exposes run_sqlite_query under the name
# "run_sqlite_query", with RunQuaryArgsSchema describing its arguments.
run_query_tool = Tool.from_function(
    name="run_sqlite_query",
    description="Run an sqlite query.",
    func=run_sqlite_query,
    args_schema=RunQuaryArgsSchema,
)
############################# Describe Tables Tool ##########################################
def describe_tables(table_names):
    """Return the stored CREATE statements for the given tables.

    Args:
        table_names: Iterable of table-name strings to look up.

    Returns:
        The ``sql`` column of every ``sqlite_master`` row of type 'table'
        whose name is in ``table_names``, joined with newlines. Rows whose
        ``sql`` value is NULL are skipped; names that match no table simply
        contribute nothing to the result.
    """
    # Materialise once so the placeholder count and the bound values agree
    # even if the caller passed a one-shot iterator.
    names = list(table_names)
    # Bind the names as parameters instead of splicing quoted text into the
    # statement, so a name containing a quote cannot alter the SQL.
    placeholders = ", ".join("?" for _ in names)
    c = conn.cursor()
    rows = c.execute(
        "SELECT sql FROM sqlite_master "
        f"WHERE type='table' and name IN ({placeholders});",
        names,
    )
    return "\n".join(row[0] for row in rows if row[0] is not None)
# Argument schema handed to the "describe_tables" tool via ``args_schema``.
class DescribeTablesArgsSchema(BaseModel):
    # Names of the tables whose schema should be returned.
    table_names: List[str]
# Tool wrapper that exposes describe_tables under the name "describe_tables",
# with DescribeTablesArgsSchema describing its arguments.
describe_tables_tool = Tool.from_function(
    name="describe_tables",
    description="Given a list of table names, return the schema of those tables.",
    func=describe_tables,
    args_schema=DescribeTablesArgsSchema,
)