File size: 3,683 Bytes
10757ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from .sql_runtime import SQLRuntime
from pydantic import BaseModel, Field
from .load_llm import load_llm
from .prompts import sql_query_prompt, sql_query_summary_prompt, sql_query_visualization_prompt
from langchain_core.runnables import chain
from typing import Optional
from dotenv import load_dotenv

class Generated_query(BaseModel):
    """
    The SQL query to execute, make sure to use semicolon at the end of the query, do not execute harmful queries
    """
    # Structured-output schema for the query-generating LLM (bound with
    # `with_structured_output` in `sql_generator`). The class docstring and the
    # Field description are part of the schema the LLM is given, so treat them
    # as prompt text: rewording them changes what the model produces.
    # NOTE(review): the docstring says "query" (singular) although the model
    # carries a list; the non-PascalCase class name is kept so existing
    # imports keep working.

    # SQL statements passed as-is to `SQLRuntime.execute_batch` in `sql_generator`.
    queries: list[str] = Field(description="List of SQL queries to execute, use title case for strings, make sure to use semicolon at the end of each query, do not execute harmful queries")

class QuerySummary(BaseModel):
    """
    The summary of the SQL query results
    """
    # Structured-output schema for the summarizing LLM (bound with
    # `with_structured_output` in `analyze_results`). The docstring and Field
    # descriptions are part of the schema the LLM is given, so treat them as
    # prompt text. No field has a default, so all three are required.

    # Free-text analysis written by the LLM.
    summary: str = Field(description="The analysis of the SQL query results")
    # Execution errors as reported back by the LLM.
    errors: list[str] = Field(description="The errors in the execution of the queries")
    # NOTE(review): described as "queries ... and their results", so entries
    # may mix SQL text and result text — confirm what consumers expect.
    queries: list[str] = Field(description="The SQL queries executed and their results")

@chain
def sql_generator(input: dict) -> dict:
    """
    Generate SQL for a natural-language question with the LLM and execute it.

    Args:
        input: dict with keys
            - "query": the natural-language question.
            - "db_path": database name/path handed to ``SQLRuntime(dbname=...)``.

    Returns:
        dict with keys
            - "input": the original question, unchanged.
            - "results": the value returned by ``SQLRuntime.execute_batch`` for
              the generated statements (consumed by ``sql_formatter``).

    Raises:
        KeyError: if "query" or "db_path" is missing from ``input``.
    """
    question, db_path = input["query"], input["db_path"]

    runtime = SQLRuntime(dbname=db_path)
    # NOTE(review): the runtime is created per call and never explicitly
    # closed here — confirm whether SQLRuntime holds a connection to release.

    query_llm = load_llm().with_structured_output(Generated_query)

    # The live schema lets the LLM use real table and column names.
    schemas = runtime.get_schemas()

    # Named `query_chain` so it does not shadow the imported `chain` decorator.
    query_chain = sql_query_prompt | query_llm
    generated = query_chain.invoke({
        "db_schema": schemas,
        "input": question,
    })

    # SECURITY(review): LLM-generated SQL is executed verbatim; the only guard
    # is the "do not execute harmful queries" instruction in the schema. If the
    # database must not be modified, enforce it at the connection level (e.g.
    # read-only) instead of relying on the prompt.
    results = runtime.execute_batch(generated.queries)

    return {
        "input": question,
        "results": results,
    }

@chain
def sql_formatter(input):
    """
    Render each executed statement's outcome as a one-line string.

    ``input`` is the dict produced by ``sql_generator``: the question under
    "input" and the execution results under "results". Each result item has
    a "code" (0 selects the success rendering) and a "msg" dict holding the
    executed SQL under "input"; success items are rendered with their "data",
    the rest with "msg"["traceback"].

    Returns a dict with the question under "query" and the rendered lines
    under "results" — the keys ``analyze_results`` reads.
    """
    def _render(entry):
        # One line per statement: the SQL plus either its data or its traceback.
        if entry["code"] == 0:
            return f"Query: {entry['msg']['input']}, Result: {entry['data']}"
        return f"Query: {entry['msg']['input']}, Error: {entry['msg']['traceback']}"

    # Render all items before reading the question, as the loop version did.
    rendered = [_render(entry) for entry in input["results"]]

    return {"query": input["input"], "results": rendered}

@chain
def analyze_results(input) -> QuerySummary:
    """
    Ask the LLM to analyze the formatted SQL results.

    Args:
        input: dict from ``sql_formatter`` with the user's question under
            "query" and the rendered per-statement lines under "results".

    Returns:
        The ``QuerySummary`` produced by the structured-output LLM.
    """
    # Named `summarizer` so nothing shadows the imported `chain` decorator.
    summarizer = sql_query_summary_prompt | load_llm().with_structured_output(QuerySummary)

    payload = {"query": input["query"], "results": input["results"]}
    return summarizer.invoke(payload)

if __name__ == '__main__':
    # Load variables from a local .env file (presumably the LLM credentials
    # that load_llm needs — confirm).
    load_dotenv()

    # End-to-end pipeline: generate + execute SQL -> format results -> summarize.
    pipeline = sql_generator | sql_formatter | analyze_results

    # analyze_results returns ONE QuerySummary, so bind a single name.
    # Unpacking it into two names raises ValueError, because iterating a
    # pydantic model yields its three (field, value) pairs.
    summary = pipeline.invoke(
        {
            "query": "What are the different party symbols in Maharashtra elections 2019, create a list of all the symbols",
            "db_path": "./data/elections.db",
        }
    )
    print(summary.summary)
    print(summary.errors)
    print(summary.queries)