# Spaces:
# Runtime error
# Runtime error
from typing import List, Dict, Any, Optional, Type | |
from langchain_core.tools import BaseTool | |
from pydantic import BaseModel, Field | |
import pandas as pd | |
from .sql_runtime import SQLRuntime | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.prompts import ChatPromptTemplate | |
from .load_llm import load_llm | |
from langchain_core.messages import SystemMessage | |
from langchain_core.prompts import HumanMessagePromptTemplate | |
from langchain.agents import AgentExecutor, create_react_agent | |
from dotenv import load_dotenv | |
from react import run_agent_executor | |
from prompts import react_prompt | |
# Defining the input schemas
class QueryInput(BaseModel):
    # Argument schema with a single required field: the SQL text to run.
    query: str = Field(
        default=...,
        description=(
            "The SQL query to execute, make sure to use semicolon at the end "
            "of the query, do not execute harmful queries"
        ),
    )
class TableNameInput(BaseModel):
    # Argument schema with a single required field: the table to inspect.
    table_name: str = Field(
        default=...,
        description="The name of the table to analyze",
    )
class ColumnSearchInput(BaseModel):
    # Argument schema for a distinct-value lookup on one column of one table.

    # Both names are required.
    table_name: str = Field(
        default=...,
        description="The name of the table to search",
    )
    column_name: str = Field(
        default=...,
        description="The name of the column to search",
    )
    # Optional cap on how many values come back; defaults to 10.
    limit: int = Field(
        10,
        description="Maximum number of distinct values to return",
    )
class SQLQueryTool(BaseTool):
    # Tool that runs one SQL statement through SQLRuntime and returns the
    # outcome to the agent as plain text.
    name: str = "sql_query"
    description: str = """
    Execute a SQL query and return the results.
    Use this when you need to run a specific SQL query on the elections database.
    The query should be a valid SQL statement and should end with a semicolon.
    There should be no harmful queries executed.
    There are three tables in the database: elections_2019, elections_2024, maha_2019
    """
    args_schema: Type[BaseModel] = QueryInput

    def _run(self, query: str) -> str:
        """Execute ``query`` and render the outcome as text.

        Args:
            query: The SQL statement to pass to ``SQLRuntime.execute``.

        Returns:
            The rows rendered with ``DataFrame.to_string()``, the text
            ``"Query returned no results"`` for an empty result, or an error
            message when the runtime reports a non-zero code or an exception
            is raised.
        """
        runtime = SQLRuntime('../data/elections.db')
        try:
            outcome = runtime.execute(query)
            if outcome["code"] == 0:
                # A DataFrame gives a readable, column-aligned rendering.
                frame = pd.DataFrame(outcome["data"])
                if frame.empty:
                    return "Query returned no results"
                return frame.to_string()
            return f"Error executing query: {outcome['msg']['reason']}"
        except Exception as exc:
            return f"Error: {exc}"
class TableInfoTool(BaseTool):
    # Tool that reports a table's columns, its row count and a three-row
    # sample, all gathered through SQLRuntime.
    name: str = "get_table_info"
    description: str = """
    Get information about a specific table including its schema and basic statistics.
    Use this when you need to understand the structure of a table or get basic statistics about it.
    """
    args_schema: Type[BaseModel] = TableNameInput

    def _run(self, table_name: str) -> str:
        """Describe ``table_name``: columns, row count and sample rows.

        Args:
            table_name: Name of the table to inspect. Surrounding whitespace
                and quote characters are removed; what remains must be a
                bare identifier because it is interpolated into SQL text.

        Returns:
            A text report, or a message starting with
            ``"Error getting table info:"`` when the name is rejected or an
            exception is raised while the report is being built.
        """
        # The name is supplied by the agent and spliced into the SQL below,
        # so only a bare identifier (no spaces, quotes, semicolons) is
        # accepted; anything else could smuggle in extra SQL.
        table_name = str(table_name).strip().strip("\"'`").strip()
        if not table_name.isidentifier():
            return f"Error getting table info: invalid table name {table_name!r}"

        sql_runtime = SQLRuntime('../data/elections.db')
        try:
            # Column names; reused below as the sample DataFrame's header.
            schema = sql_runtime.get_schema_for_table(table_name)

            # Row count, or the literal "Error" when the runtime reports failure.
            count_result = sql_runtime.execute(f"SELECT COUNT(*) FROM {table_name}")
            row_count = count_result["data"][0][0] if count_result["code"] == 0 else "Error"

            # A few sample rows so the agent can see typical values.
            sample_result = sql_runtime.execute(f"SELECT * FROM {table_name} LIMIT 3")
            if sample_result["code"] == 0:
                sample = pd.DataFrame(sample_result["data"], columns=schema).to_string()
            else:
                sample = "Error getting sample data"

            info = f"""
            Table: {table_name}
            Columns: {', '.join(schema)}
            Row Count: {row_count}
            Sample Data:
            {sample}
            """
            return info
        except Exception as e:
            return f"Error getting table info: {str(e)}"
class ColumnValuesTool(BaseTool):
    # Tool that lists distinct values found in one column of one table.
    name: str = "find_column_values"
    description: str = """
    Find distinct values in a specific column of a table.
    Use this when you need to know what unique values exist in a particular column.
    """
    args_schema: Type[BaseModel] = ColumnSearchInput

    def _run(self, table_name: str, column_name: str, limit: int = 10) -> str:
        """Return up to ``limit`` distinct values of ``column_name``.

        Args:
            table_name: Table to read from; must be a bare identifier.
            column_name: Column to read; must be a bare identifier.
            limit: Value placed in the query's ``LIMIT`` clause; coerced to
                ``int`` before use.

        Returns:
            A comma-separated listing of the values, or an error message when
            an identifier is rejected, the runtime reports a non-zero code,
            or an exception is raised.
        """
        # All three arguments are spliced into SQL text, so each is checked
        # first: names must be bare identifiers (no spaces, quotes or
        # semicolons) and the limit must be an integer.
        table_name = str(table_name).strip().strip("\"'`").strip()
        column_name = str(column_name).strip().strip("\"'`").strip()
        for label, identifier in (("table", table_name), ("column", column_name)):
            if not identifier.isidentifier():
                return f"Error finding values: invalid {label} name {identifier!r}"

        sql_runtime = SQLRuntime('../data/elections.db')
        try:
            limit = int(limit)
            query = f"""
            SELECT DISTINCT {column_name}
            FROM {table_name}
            LIMIT {limit}
            """
            result = sql_runtime.execute(query)
            if result["code"] != 0:
                return f"Error finding values: {result['msg']['reason']}"
            # Each row holds a single selected column.
            values = [row[0] for row in result["data"]]
            return f"Distinct values in {column_name}: {', '.join(map(str, values))}"
        except Exception as e:
            return f"Error: {str(e)}"
class ListTablesTool(BaseTool):
    # Tool that reports the table names returned by SQLRuntime.list_tables().
    name: str = "list_tables"
    description: str = """
    List all available tables in the database.
    Use this when you need to know what tables are available to query.
    """

    def _run(self, *args, **kwargs) -> str:
        """Return a one-line listing of the tables, or an error message.

        Positional and keyword arguments are accepted and ignored.
        """
        runtime = SQLRuntime('../data/elections.db')
        try:
            table_names = runtime.list_tables()
            listing = ', '.join(table_names)
        except Exception as exc:
            return f"Error listing tables: {exc}"
        return f"Available tables: {listing}"
def create_sql_agent_tools(db_path: Optional[str] = '../data/elections.db') -> List[BaseTool]:
    """Build the SQL tools handed to a LangChain agent.

    Args:
        db_path: Accepted for interface compatibility; this function does not
            currently read it, and the tools are constructed with no
            arguments.

    Returns:
        New instances of ``SQLQueryTool``, ``TableInfoTool`` and
        ``ListTablesTool``, in that order. ``ColumnValuesTool`` is not
        included.
    """
    enabled_tools = (SQLQueryTool, TableInfoTool, ListTablesTool)
    return [tool_cls() for tool_cls in enabled_tools]
if __name__ == "__main__":
    # Manual smoke test: build the tools, wire them into a ReAct agent and
    # ask one sample question.

    # Pull environment variables from a local .env file before building the LLM.
    load_dotenv()

    tools = create_sql_agent_tools()
    for tool in tools:
        print(f"Tool: {tool.name}")
        print(f"Description: {tool.description}")

    # Create the llm
    llm = load_llm()

    # Build the ReAct agent and the executor that runs its tool loop.
    agent = create_react_agent(llm, tools, react_prompt)
    agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
    print("Agent created successfully")

    # Run the agent and show its final answer; previously the result was
    # assigned but never displayed.
    res = agent_executor.invoke(
        {"input": "who won elections in maharashtra in Nandurbar in elections 2019?"}
    )
    print(res["output"])