"""Natural-language-to-SQL pipeline.

Three runnables are defined here and are meant to be piped together as
``sql_generator | sql_formatter | analyze_results``:

1. ``sql_generator`` asks the LLM for SQL queries and executes them.
2. ``sql_formatter`` flattens the execution results into plain strings.
3. ``analyze_results`` asks the LLM to summarise those strings.
"""
from typing import Optional

from dotenv import load_dotenv
from langchain_core.runnables import chain
from pydantic import BaseModel, Field

from .load_llm import load_llm
from .prompts import sql_query_prompt, sql_query_summary_prompt, sql_query_visualization_prompt
from .sql_runtime import SQLRuntime


# NOTE(review): this model is handed to ``with_structured_output``; the
# docstring and field description below are presumably forwarded to the LLM as
# part of the output schema, so their wording is deliberately left untouched.
class Generated_query(BaseModel):
    """
    The SQL query to execute, make sure to use semicolon at the end of the query, do not execute harmful queries
    """

    queries: list[str] = Field(description="List of SQL queries to execute, use title case for strings, make sure to use semicolon at the end of each query, do not execute harmful queries")


# NOTE(review): same as above - docstring and descriptions are presumably
# prompt material for the LLM, so they are kept verbatim.
class QuerySummary(BaseModel):
    """
    The summary of the SQL query results
    """

    summary: str = Field(description="The analysis of the SQL query results")
    errors: list[str] = Field(description="The errors in the execution of the queries")
    queries: list[str] = Field(description="The SQL queries executed and their results")


@chain
def sql_generator(input: dict) -> dict:
    """Generate SQL for a natural-language question and execute it.

    Args:
        input: Mapping with two keys: ``"query"`` (the user's question) and
            ``"db_path"`` (passed to ``SQLRuntime`` as ``dbname``).

    Returns:
        A dict with ``"input"`` (the original question) and ``"results"``
        (whatever ``SQLRuntime.execute_batch`` returns for the generated
        queries).
    """
    query, db_path = input["query"], input["db_path"]

    sql_runtime = SQLRuntime(dbname=db_path)
    query_generator_llm = load_llm().with_structured_output(Generated_query)

    # The database schemas give the model the context it needs to write SQL.
    schemas = sql_runtime.get_schemas()

    # Named ``pipeline`` rather than ``chain`` so the local does not shadow
    # the imported ``chain`` decorator.
    pipeline = sql_query_prompt | query_generator_llm
    gen_queries = pipeline.invoke({
        "db_schema": schemas,
        "input": query,
    })

    res = sql_runtime.execute_batch(gen_queries.queries)

    return {
        "input": query,
        "results": res,
    }


@chain
def sql_formatter(input):
    """
    Formats the output of the SQL queries

    Args:
        input: Mapping with ``"input"`` (the original question) and
            ``"results"``, an iterable of per-query result dicts. Each result
            is read via ``item["code"]``, ``item["msg"]["input"]`` and either
            ``item["data"]`` (when ``code == 0``) or
            ``item["msg"]["traceback"]`` (otherwise).

    Returns:
        A dict with ``"query"`` (the original question) and ``"results"``
        (one formatted string per executed query).
    """
    output = []
    for item in input["results"]:
        # A ``code`` of 0 is treated as success; anything else is reported
        # together with its traceback.
        if item["code"] == 0:
            output.append(f"Query: {item['msg']['input']}, Result: {item['data']}")
        else:
            output.append(f"Query: {item['msg']['input']}, Error: {item['msg']['traceback']}")

    return {
        "query": input["input"],
        "results": output,
    }


@chain
def analyze_results(input) -> QuerySummary:
    """
    Analyzes the results of the SQL queries executed on the election database

    Args:
        input: Mapping with ``"query"`` (the original question) and
            ``"results"`` (the formatted result strings).

    Returns:
        The structured ``QuerySummary`` produced by the LLM.
    """
    # Named ``pipeline`` rather than ``chain`` so the local does not shadow
    # the imported ``chain`` decorator.
    pipeline = sql_query_summary_prompt | load_llm().with_structured_output(QuerySummary)
    # NOTE: a second chain built from ``sql_query_visualization_prompt`` was
    # sketched here previously but is not wired in.

    return pipeline.invoke({
        "query": input["query"],
        "results": input["results"],
    })


def main() -> None:
    """Run the full pipeline once against the sample elections database."""
    load_dotenv()

    pipeline = sql_generator | sql_formatter | analyze_results

    # ``analyze_results`` yields exactly one ``QuerySummary``. The previous
    # code unpacked the result into two names, which raises ``ValueError``
    # because iterating a pydantic model yields its three (field, value) pairs.
    formatted_output = pipeline.invoke(
        {
            "query": "What are the different party symbols in Maharashtra elections 2019, create a list of all the symbols",
            "db_path": "./data/elections.db",
        }
    )

    print(formatted_output.summary)
    print(formatted_output.errors)
    print(formatted_output.queries)


if __name__ == '__main__':
    main()