# MahaNeta / utils/tools.py
# Source: commit 10757ec ("init") by ankush-003.
from typing import List, Dict, Any, Optional, Type
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field
import pandas as pd
from .sql_runtime import SQLRuntime
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from .load_llm import load_llm
from langchain_core.messages import SystemMessage
from langchain_core.prompts import HumanMessagePromptTemplate
from langchain.agents import AgentExecutor, create_react_agent
from dotenv import load_dotenv
from react import run_agent_executor
from prompts import react_prompt
# Define the input schemas (argument models) for the agent tools below.
class QueryInput(BaseModel):
    # Input schema for the ``sql_query`` tool: one SQL statement to run.
    query: str = Field(
        default=...,
        description=(
            "The SQL query to execute, make sure to use semicolon at the end of "
            "the query, do not execute harmful queries"
        ),
    )
class TableNameInput(BaseModel):
    # Input schema for the ``get_table_info`` tool: the table to describe.
    table_name: str = Field(
        default=...,
        description="The name of the table to analyze",
    )
class ColumnSearchInput(BaseModel):
    # Input schema for the ``find_column_values`` tool: which column of which
    # table to scan, plus an upper bound on how many distinct values come back.
    table_name: str = Field(default=..., description="The name of the table to search")
    column_name: str = Field(default=..., description="The name of the column to search")
    limit: int = Field(10, description="Maximum number of distinct values to return")
# Database file opened by every tool unless a different ``db_path`` is given.
# NOTE(review): this relative path is resolved against the process's current
# working directory, not against this file — confirm callers run from the
# expected directory or pass an explicit ``db_path``.
DEFAULT_DB_PATH = "../data/elections.db"


class SQLQueryTool(BaseTool):
    """LangChain tool that runs an agent-supplied SQL statement on the database."""

    name: str = "sql_query"
    description: str = (
        "\n"
        "Execute a SQL query and return the results.\n"
        "Use this when you need to run a specific SQL query on the elections database.\n"
        "The query should be a valid SQL statement and should end with a semicolon.\n"
        "There should be no harmful queries executed.\n"
        "There are three tables in the database: elections_2019, elections_2024, maha_2019\n"
    )
    args_schema: Type[BaseModel] = QueryInput
    # Database the tool connects to on each call.
    db_path: str = DEFAULT_DB_PATH

    def _run(self, query: str) -> str:
        """Execute ``query`` and render the resulting rows as a text table.

        Failures are returned as ``"Error..."`` strings instead of raised, so
        the agent can read them and retry.
        """
        try:
            # Built inside the try so connection/path failures are reported to
            # the agent like any other error instead of escaping the tool.
            sql_runtime = SQLRuntime(self.db_path)
            result = sql_runtime.execute(query)
            if result["code"] != 0:
                return f"Error executing query: {result['msg']['reason']}"
            # A DataFrame gives a readable, aligned string rendering.
            df = pd.DataFrame(result["data"])
            if not df.empty:
                return df.to_string()
            return "Query returned no results"
        except Exception as e:
            return f"Error: {str(e)}"


class TableInfoTool(BaseTool):
    """LangChain tool that reports a table's columns, row count and sample rows."""

    name: str = "get_table_info"
    description: str = (
        "\n"
        "Get information about a specific table including its schema and basic statistics.\n"
        "Use this when you need to understand the structure of a table or get basic statistics about it.\n"
    )
    args_schema: Type[BaseModel] = TableNameInput
    # Database the tool connects to on each call.
    db_path: str = DEFAULT_DB_PATH

    def _run(self, table_name: str) -> str:
        """Describe ``table_name``: column names, row count and up to 3 rows.

        A failed count shows ``Error`` as the row count, and a failed sample
        query shows a placeholder message. Any other failure is returned as an
        ``"Error getting table info: ..."`` string.
        """
        try:
            sql_runtime = SQLRuntime(self.db_path)
            schema = sql_runtime.get_schema_for_table(table_name)
            # NOTE(review): table_name is interpolated unescaped. sql_query
            # already runs arbitrary agent SQL, so this adds no new capability,
            # but validate identifiers before exposing this tool more widely.
            count_result = sql_runtime.execute(f"SELECT COUNT(*) FROM {table_name}")
            row_count = (
                count_result["data"][0][0] if count_result["code"] == 0 else "Error"
            )
            sample_result = sql_runtime.execute(f"SELECT * FROM {table_name} LIMIT 3")
            if sample_result["code"] == 0:
                sample_text = pd.DataFrame(
                    sample_result["data"], columns=schema
                ).to_string()
            else:
                sample_text = "Error getting sample data"
            return (
                f"\nTable: {table_name}\n"
                f"Columns: {', '.join(schema)}\n"
                f"Row Count: {row_count}\n"
                f"Sample Data:\n"
                f"{sample_text}\n"
            )
        except Exception as e:
            return f"Error getting table info: {str(e)}"


class ColumnValuesTool(BaseTool):
    """LangChain tool that lists distinct values found in one column of a table."""

    name: str = "find_column_values"
    description: str = (
        "\n"
        "Find distinct values in a specific column of a table.\n"
        "Use this when you need to know what unique values exist in a particular column.\n"
    )
    args_schema: Type[BaseModel] = ColumnSearchInput
    # Database the tool connects to on each call.
    db_path: str = DEFAULT_DB_PATH

    def _run(self, table_name: str, column_name: str, limit: int = 10) -> str:
        """Return up to ``limit`` distinct values of ``column_name`` in ``table_name``.

        Failures are returned as ``"Error..."`` strings instead of raised.
        """
        try:
            sql_runtime = SQLRuntime(self.db_path)
            # ``limit`` is spliced into the SQL text, so force it to an int to
            # keep anything but a number out of the query.
            # NOTE(review): table/column names are interpolated unescaped; see
            # the matching note in TableInfoTool.
            query = (
                f"SELECT DISTINCT {column_name} "
                f"FROM {table_name} "
                f"LIMIT {int(limit)}"
            )
            result = sql_runtime.execute(query)
            if result["code"] != 0:
                return f"Error finding values: {result['msg']['reason']}"
            values = [row[0] for row in result["data"]]
            return f"Distinct values in {column_name}: {', '.join(map(str, values))}"
        except Exception as e:
            return f"Error: {str(e)}"


class ListTablesTool(BaseTool):
    """LangChain tool that lists the tables available in the database."""

    name: str = "list_tables"
    description: str = (
        "\n"
        "List all available tables in the database.\n"
        "Use this when you need to know what tables are available to query.\n"
    )
    # Database the tool connects to on each call.
    db_path: str = DEFAULT_DB_PATH

    def _run(self, *args, **kwargs) -> str:
        """Return the table names as a comma-separated string.

        Any arguments passed by the agent are ignored.
        """
        try:
            sql_runtime = SQLRuntime(self.db_path)
            tables = sql_runtime.list_tables()
            return f"Available tables: {', '.join(tables)}"
        except Exception as e:
            return f"Error listing tables: {str(e)}"


def create_sql_agent_tools(db_path: Optional[str] = DEFAULT_DB_PATH) -> List[BaseTool]:
    """Create the SQL tools exposed to a LangChain agent.

    Args:
        db_path: Database file every returned tool opens. ``None`` falls back
            to ``DEFAULT_DB_PATH``.

    Returns:
        The ``sql_query``, ``get_table_info`` and ``list_tables`` tools, all
        bound to the same database. ``ColumnValuesTool`` is defined but not
        currently registered.
    """
    path = DEFAULT_DB_PATH if db_path is None else db_path
    return [
        SQLQueryTool(db_path=path),
        TableInfoTool(db_path=path),
        # ColumnValuesTool(db_path=path),  # currently disabled
        ListTablesTool(db_path=path),
    ]
def main() -> None:
    """Smoke-test the tools: list them, build a ReAct agent and ask one question."""
    # Reads environment variables from a local .env file; presumably needed by
    # load_llm() (e.g. API keys) — confirm.
    load_dotenv()

    tools = create_sql_agent_tools()
    for tool in tools:
        print(f"Tool: {tool.name}")
        print(f"Description: {tool.description}")

    llm = load_llm()
    agent = create_react_agent(llm, tools, react_prompt)
    # The executor drives the ReAct loop: it calls the chosen tools and feeds
    # their results back to the agent.
    agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
    print("Agent created successfully")

    res = agent_executor.invoke(
        {"input": "who won elections in maharashtra in Nandurbar in elections 2019?"}
    )
    # The result was previously computed but never shown.
    print(res["output"])


if __name__ == "__main__":
    main()