ankush-003 committed
Commit 10757ec · 1 Parent(s): cff415c
.Dockerfile ADDED
@@ -0,0 +1,36 @@
+ # Use Python 3.11 slim as the base image
+ FROM python:3.11-slim
+
+ # Set working directory
+ WORKDIR /app
+
+ # Install system dependencies
+ RUN apt-get update && \
+     apt-get install -y --no-install-recommends \
+     build-essential \
+     python3-dev \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Create and activate virtual environment
+ ENV VIRTUAL_ENV=/opt/venv
+ RUN python3 -m venv $VIRTUAL_ENV
+ ENV PATH="$VIRTUAL_ENV/bin:$PATH"
+
+ # Copy requirements first to leverage Docker cache
+ COPY requirements.txt /app/requirements.txt
+
+ # Install dependencies in virtual environment
+ RUN pip install --no-cache-dir --upgrade -r /app/requirements.txt
+
+ # Copy the rest of the application
+ COPY . .
+
+ ENV PYTHONPATH=/app
+ ENV PYTHONUNBUFFERED=1
+
+ # Expose the port the app listens on (7860 matches the CMD below and the Hugging Face Spaces convention)
+ EXPOSE 7860
+
+ # Command to run the application
+ # CMD . /opt/venv/bin/activate && exec chainlit run app.py --port 8000
+ CMD ["chainlit", "run", "app.py", "--port", "7860"]
.Dockerignore ADDED
@@ -0,0 +1,19 @@
+ # python venv
+ llm
+
+ # envs
+ .env
+
+ # idx
+ .idx
+ .vscode
+
+ # database
+ elections.db
+
+ # __pycache__
+ __pycache__
+
+ # Chainlit files
+ .chainlit
+ .files
app.py ADDED
@@ -0,0 +1,118 @@
+ import chainlit as cl
+ from utils import load_details_dataset, load_election_dataset, load_maha_election_dataset
+ from sqlite3 import connect
+ from typing import cast
+ from utils.load_llm import load_llm
+ from dotenv import load_dotenv
+ from utils.query_generator import sql_generator, sql_formatter, analyze_results
+ from langchain.schema.runnable import Runnable
+ from utils.sql_runtime import SQLRuntime
+
+ load_dotenv()
+ # global variables
+ db_path = './data/elections.db'
+ sql_runtime = SQLRuntime(dbname=db_path)
+
+ # Load the dataset
+ @cl.action_callback("Load Datasets")
+ async def on_action(action: cl.Action):
+     print("Loading datasets...")
+
+     # save the datasets as tables
+     conn = connect('./data/elections.db')
+
+     load_details_dataset.load_data_from_csv_to_db('./data/details_of_assembly_segment_2019.csv', conn)
+     load_election_dataset.load_data_from_csv_to_db('./data/eci_data_2024.csv', conn)
+     load_maha_election_dataset.load_data_from_csv_to_db('./data/maha_results_2019.csv', conn)
+
+     return "Datasets loaded successfully."
+
+ @cl.action_callback("Execute Query")
+ async def on_action(action: cl.Action):
+     res = await cl.AskUserMessage(content="Enter Query to run Manually", timeout=20).send()
+     actions = [
+         cl.Action(name="Execute Query", description="Execute the query on the dataset", value="Execute Query")
+     ]
+     if res:
+         query = res['output']
+         res = sql_runtime.execute(query)
+         print(res)
+         if res["code"] == 0:
+             data = ""
+             if res["data"]:
+                 for row in res["data"]:
+                     data += str(row) + "\n"
+
+             elements = [
+                 cl.Text(name="Result", content=data, display="inline"),
+             ]
+             await cl.Message(
+                 content=f"Query: {query}",
+                 elements=elements,
+                 actions=actions,
+             ).send()
+         else:
+             error = res["msg"]["traceback"]
+             elements = [
+                 cl.Text(name="Error", content=error, display="inline"),
+             ]
+             await cl.Message(
+                 content=f"Query: {query}",
+                 elements=elements,
+                 actions=actions,
+             ).send()
+
+     # return "Query executed successfully."
+
+ @cl.on_chat_start
+ async def start():
+     # Sending an action button within a chatbot message
+     actions = [
+         cl.Action(name="Load Datasets", description="Load the datasets into the database", value="Load Datasets")
+     ]
+
+     chain = sql_generator | sql_formatter | analyze_results
+
+     cl.user_session.set("chain", chain)
+     cl.user_session.set("db_path", './data/elections.db')
+
+     await cl.Message(content="I am your personal political expert. I can help you analyze the election data. Click the button below to load the datasets.", actions=actions).send()
+
+ @cl.on_message
+ async def on_message(message: cl.Message):
+     chain = cast(Runnable, cl.user_session.get("chain"))
+     db_path = cl.user_session.get("db_path")
+
+     actions = [
+         cl.Action(name="Execute Query", description="Execute the query on the dataset", value="Execute Query")
+     ]
+
+     print(message.content)
+
+     try:
+         res = chain.invoke({
+             "query": message.content,
+             "db_path": db_path
+         })
+     except Exception as e:
+         print(e)
+         await cl.Message(content="An error occurred while processing the query. Please try again.").send()
+         return
+
+     queries = "\n".join(res.queries)
+
+     errors = "".join(res.errors)
+
+     elements = [
+         cl.Text(name='results', content=res.summary, display="inline"),
+         cl.Text(name="queries", content=queries, display="inline"),
+     ]
+
+     if errors:
+         elements.append(cl.Text(name="errors", content=errors, display="inline"))
+
+     await cl.Message(
+         content="Let's analyze the results of the query",
+         elements=elements,
+         actions=actions
+     ).send()
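Note: the chain that on_chat_start stores in the user session can also be exercised outside Chainlit for quick testing. A minimal sketch, assuming the utils package is importable from the project root, ./data/elections.db has already been populated, and the LLM API key is present in .env (the question string is only an illustration):

    from dotenv import load_dotenv
    from utils.query_generator import sql_generator, sql_formatter, analyze_results

    load_dotenv()  # pick up the LLM API key from .env

    # same pipeline that app.py stores in the Chainlit user session
    chain = sql_generator | sql_formatter | analyze_results

    # analyze_results returns a QuerySummary with .summary, .errors and .queries
    res = chain.invoke({
        "query": "Which party won the most constituencies in 2024?",  # illustrative question
        "db_path": "./data/elections.db"
    })
    print(res.summary)
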
chainlit.md ADDED
@@ -0,0 +1,5 @@
+ ## Welcome screen
+
+ # MahaNeta
+
+ **Your Own Personal Political Assistant**
data/details_of_assembly_segment_2019.csv ADDED
The diff for this file is too large to render.
 
data/eci_data_2024.csv ADDED
The diff for this file is too large to render.
 
data/maha_results_2019.csv ADDED
The diff for this file is too large to render.
 
docs/lab_session1_25oct2024.pdf ADDED
Binary file (853 kB).
 
docs/pes_lab_session1.pdf ADDED
Binary file (236 kB).
 
requirements.txt ADDED
@@ -0,0 +1,75 @@
+ aiohappyeyeballs==2.4.3
+ aiohttp==3.10.10
+ aiosignal==1.3.1
+ altair==5.4.1
+ annotated-types==0.7.0
+ anyio==4.6.2.post1
+ attrs==24.2.0
+ blinker==1.8.2
+ cachetools==5.5.0
+ certifi==2024.8.30
+ charset-normalizer==3.4.0
+ click==8.1.7
+ distro==1.9.0
+ frozenlist==1.5.0
+ gitdb==4.0.11
+ GitPython==3.1.43
+ greenlet==3.1.1
+ groq==0.11.0
+ h11==0.14.0
+ httpcore==1.0.6
+ httpx==0.27.2
+ idna==3.10
+ Jinja2==3.1.4
+ jsonpatch==1.33
+ jsonpointer==3.0.0
+ jsonschema==4.23.0
+ jsonschema-specifications==2024.10.1
+ langchain==0.3.4
+ langchain-core==0.3.13
+ langchain-groq==0.2.0
+ langchain-google-genai
+ langchain-text-splitters==0.3.0
+ langsmith==0.1.137
+ markdown-it-py==3.0.0
+ MarkupSafe==3.0.2
+ mdurl==0.1.2
+ multidict==6.1.0
+ narwhals==1.12.1
+ numpy==1.26.4
+ orjson==3.10.10
+ packaging==24.1
+ pandas==2.2.3
+ pillow==10.4.0
+ propcache==0.2.0
+ protobuf==5.28.3
+ pyarrow==18.0.0
+ pydantic==2.9.2
+ pydantic_core==2.23.4
+ pydeck==0.9.1
+ Pygments==2.18.0
+ python-dateutil==2.9.0.post0
+ python-dotenv==1.0.1
+ pytz==2024.2
+ PyYAML==6.0.2
+ referencing==0.35.1
+ requests==2.32.3
+ requests-toolbelt==1.0.0
+ rich==13.9.3
+ rpds-py==0.20.0
+ six==1.16.0
+ smmap==5.0.1
+ sniffio==1.3.1
+ SQLAlchemy==2.0.36
+ streamlit==1.39.0
+ tenacity==9.0.0
+ toml==0.10.2
+ tornado==6.4.1
+ typing_extensions==4.12.2
+ tzdata==2024.2
+ urllib3==2.2.3
+ watchdog==5.0.3
+ yarl==1.17.0
+
+ # frontend
+ chainlit
utils/__init__.py ADDED
File without changes
utils/cot.py ADDED
@@ -0,0 +1,30 @@
+ """
+ This example shows Chain of Thought (CoT) prompting
+ """
+
+ from langchain_core.prompts import PromptTemplate
+ from load_llm import load_llm
+
+ template = """Answer the question based on the context below. If the
+ question cannot be answered using the information provided answer
+ with "I don't know".
+
+ Q: Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can contains 3 tennis balls. How many
+ tennis balls does he have now?
+ A: Roger started with 5 balls. 2 cans of 3 tennis balls is 6 tennis balls. 5+6 = 11. The answer is 11.
+
+ Q: The cafeteria has 23 apples. If they used 20 apples for lunch and bought 6 more, how many apples do they have?
+
+ """
+
+
+ prompt_template = PromptTemplate(
+     input_variables=[],
+     template=template
+ )
+
+ prompt = prompt_template.format(
+ )
+
+ llm = load_llm()
+ print(llm.invoke(prompt).content)
utils/few_shot.py ADDED
@@ -0,0 +1,105 @@
+ """
+ This example shows how to load an LLM, use a prompt and retrieve results
+ We illustrate the use of LangChain for few-shot inferencing
+
+ The dynamic number of examples is important because the max length of our prompt and completion output is limited.
+ This limitation is measured by the maximum context window.
+
+ context_window = input_tokens + output_tokens
+ At the same time, we want to maximize the number of examples given to the model for few-shot learning.
+
+ Considering this, we need to balance the number of examples included and our prompt size.
+ Our hard limit is the maximum context size, but we must also consider the cost of processing more tokens through the LLM.
+ Fewer tokens mean a cheaper service and faster completions from the LLM.
+ """
+
+ from load_llm import load_llm
+ from langchain_core.prompts import PromptTemplate, FewShotPromptTemplate
+ from langchain_core.example_selectors import LengthBasedExampleSelector
+
+ # create our examples
+ examples = [
+     {
+         "query": "How are you?",
+         "answer": "I can't complain but sometimes I still do."
+     }, {
+         "query": "What time is it?",
+         "answer": "It's time to get a watch."
+     }, {
+         "query": "What is the meaning of life?",
+         "answer": "42"
+     }, {
+         "query": "What is the weather like today?",
+         "answer": "Cloudy with a chance of memes."
+     }, {
+         "query": "What is your favorite movie?",
+         "answer": "Terminator"
+     }, {
+         "query": "Who is your best friend?",
+         "answer": "Siri. We have spirited debates about the meaning of life."
+     }, {
+         "query": "What should I do today?",
+         "answer": "Stop talking to chatbots on the internet and go outside."
+     }
+ ]
+
+ # create an example template
+ example_template = """
+ User: {query}
+ AI: {answer}
+ """
+
+ # create a prompt example from above template
+ example_prompt = PromptTemplate(
+     input_variables=["query", "answer"],
+     template=example_template
+ )
+
+
+ example_selector = LengthBasedExampleSelector(
+     examples=examples,
+     example_prompt=example_prompt,
+     max_length=50  # this sets the max length that examples should be
+ )
+
+ # now break our previous prompt into a prefix and suffix
+ # the prefix is our instructions
+ prefix = """The following are excerpts from conversations with an AI
+ assistant. The assistant is typically sarcastic and witty, producing
+ creative and funny responses to the user's questions. Here are some
+ examples:
+ """
+ # and the suffix is our user input and output indicator
+ suffix = """
+ User: {query}
+ AI: """
+
+ # now create the few shot prompt template
+ dynamic_prompt_template = FewShotPromptTemplate(
+     example_selector=example_selector,  # use example_selector instead of examples
+     example_prompt=example_prompt,
+     prefix=prefix,
+     suffix=suffix,
+     input_variables=["query"],
+     example_separator="\n"
+ )
+
+ print(dynamic_prompt_template.format(query="How do birds fly?"))
+ print("-------- Longer query will select fewer examples in order to preserve the context ----------")
+
+ query = """If I am in America, and I want to call someone in another country, I'm
+ thinking maybe Europe, possibly western Europe like France, Germany, or the UK,
+ what is the best way to do that?"""
+
+ prompt = dynamic_prompt_template.format(query=query)
+ print(prompt)
+
+ print("-------- Shorter query for LLM ----------")
+ query = "How is the weather in your city today?"
+ prompt = dynamic_prompt_template.format(query=query)
+ print(prompt)
+
+ llm = load_llm()
+ print(
+     llm.invoke(prompt).content
+ )
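Note: the context-window budget described in the module docstring can be made concrete with a toy calculation; the numbers below are purely illustrative, not the limits of any particular model:

    # context_window = input_tokens + output_tokens
    context_window = 8192   # illustrative model limit
    output_budget = 1024    # tokens reserved for the completion
    prompt_budget = context_window - output_budget
    print(prompt_budget)    # 7168 tokens left for the prefix, examples and the query
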
utils/get_completion_client.py ADDED
@@ -0,0 +1,51 @@
+ import json
+ from openai import OpenAI
+
+ # Point to the local server
+ client1 = OpenAI(base_url="http://localhost:1234/v1", api_key="lm-studio")  # html to json
+ model = r"lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF/Meta-Llama-3-8B-Instruct-Q8_0.gguf"
+
+ # model = "MaziyarPanahi/Mistral-7B-Instruct-v0.3-GGUF/Mistral-7B-Instruct-v0.3.Q4_K_M.gguf:2"
+
+
+ def get_completion(prompt, client=client1, model=model):
+     """
+     given the prompt, obtain the response from the LLM hosted by LM Studio as a server
+     :param prompt: prompt to be sent to LLM server
+     :return: response from the LLM
+     """
+     prompt = [
+         {"role": "user", "content": prompt}
+     ]
+     completion = client.chat.completions.create(
+         model=model,
+         messages=prompt,
+         temperature=0.0,
+         stream=True,
+     )
+
+     new_message = {"role": "assistant", "content": ""}
+
+     for chunk in completion:
+         if chunk.choices[0].delta.content:
+             # print(chunk.choices[0].delta.content, end="", flush=True)
+             val = chunk.choices[0].delta.content
+             new_message["content"] += val
+
+     # print(type()
+     val = new_message["content"]  # .split("<end_of_turn>")[0]
+
+     return val
+
+
+ if __name__ == '__main__':
+     prompt = """
+     You are a political leader and your party is trying to win the general elections in India.
+     You are given an LLM that can provide you the analytics using the past historical data given to it.
+     In particular the LLM has been provided data on which party won each constituency out of 545 and which assembly segment within the main constituency is more favorable.
+     It also has details of votes polled by every candidate.
+     Tell me 10 questions that you want to ask the LLM.
+     """
+     results = get_completion(prompt)
+     print(results)
+
utils/load_details_dataset.py ADDED
@@ -0,0 +1,75 @@
+ import json
+ import sqlite3
+ import pandas as pd
+ import csv
+
+
+ def load_data_from_csv(name, end=58925):
+     data = []
+     keys = None
+     with open(name, "r", encoding="utf-8", errors="ignore") as f:
+         csv_data = csv.reader(f)
+         for i, line in enumerate(csv_data):
+             if i == 0:
+                 keys = line
+                 continue
+             item = {}
+             for key, val in zip(keys, line):
+                 item[key] = val
+             data.append(item)
+     return data
+
+
+ def load_data_from_csv_to_db(name, conn, col_names=None):
+
+     # read the dataset from csv file and create a pandas dataframe
+     df = pd.read_csv(open(name, "r", encoding="utf-8", errors="ignore"))
+
+     df.columns = [
+         'state', 'parliamentary_constituency', 'constituency', 'nota_votes', 'candidate_name', 'party_name', 'total_votes'
+     ]
+
+     # removing extra whitespace
+     string_columns = df.select_dtypes(include=['object']).columns
+     for col in string_columns:
+         df[col] = df[col].astype(str).str.strip()
+
+     df['constituency'] = df['constituency'].str.replace(r'\s*-\s*\d+$', '', regex=True)
+
+     # Remove any parenthetical suffixes like (SC) or (ST)
+     df['constituency'] = df['constituency'].str.replace(r'\s*\([^)]*\)', '', regex=True)
+
+     # save the dataframe as a database table, name of table is: elections_2019
+     result = df.to_sql("elections_2019", conn, if_exists="replace")
+
+     return result
+
+
+ def query_sql(conn, query):
+     cursor = conn.cursor()
+     cursor.execute(query)
+     result = cursor.fetchall()
+     field_names = [r[0] for r in cursor.description]
+     print(field_names)
+     return result
+
+
+ if __name__ == '__main__':
+     # create a connection to sql db called elections.db
+     conn = sqlite3.connect('../data/elections.db')
+
+     filename = r"../data/details_of_assembly_segment_2019.csv"
+
+     data = load_data_from_csv(filename, end=5)
+
+     res = load_data_from_csv_to_db(filename, conn)
+
+     query = "SELECT * FROM elections_2019 LIMIT 5;"
+     results = query_sql(conn, query)
+     print(results)
+
+     # keys = data.keys()
+     # for i, item in enumerate(data):
+     #     print(data[item])
+     # jdata = json.loads(data.to_json())
+     # print(jdata)
utils/load_election_dataset.py ADDED
@@ -0,0 +1,72 @@
+ import json
+ import sqlite3
+ import pandas as pd
+ import csv
+
+
+ def load_data_from_csv(name, end=58925):
+     data = []
+     keys = None
+     with open(name, "r", encoding="utf-8", errors="ignore") as f:
+         csv_data = csv.reader(f)
+         for i, line in enumerate(csv_data):
+             if i == 0:
+                 keys = line
+                 continue
+             item = {}
+             for key, val in zip(keys, line):
+                 item[key] = val
+             data.append(item)
+     return data
+
+
+ def load_data_from_csv_to_db(name, conn, col_names=None):
+
+     # read the dataset from csv file and create a pandas dataframe
+     df = pd.read_csv(open(name, "r", encoding="utf-8", errors="ignore"))
+
+     df.columns = [
+         'sn', 'candidate_name', 'party_name', 'evm_votes', 'postal_votes', 'total_votes', 'vote_percentage', 'state', 'constituency'
+     ]
+
+     df['constituency'] = df['constituency'].str.replace(r'\s*-\s*\d+$', '', regex=True)
+
+     # Remove any parenthetical suffixes like (SC) or (ST)
+     df['constituency'] = df['constituency'].str.replace(r'\s*\([^)]*\)', '', regex=True)
+
+     # save the dataframe as a database table, name of table is: elections_2024
+     result = df.to_sql("elections_2024", conn, if_exists="replace")
+
+     return result
+
+
+ def query_sql(conn, query):
+     cursor = conn.cursor()
+     cursor.execute(query)
+     result = cursor.fetchall()
+     field_names = [r[0] for r in cursor.description]
+     print(field_names)
+     return result
+
+
+ if __name__ == '__main__':
+     # create a connection to sql db called elections.db
+     conn = sqlite3.connect('../data/elections.db')
+
+     filename = r"../data/eci_data_2024.csv"
+
+
+     data = load_data_from_csv(filename, end=5)
+
+     res = load_data_from_csv_to_db(filename, conn)
+
+     query = "SELECT count(*) FROM elections_2024 WHERE constituency='Amalapuram';"
+     results = query_sql(conn, query)
+     print(results)
+
+     # keys = data.keys()
+     # for i, item in enumerate(data):
+     #     print(data[item])
+     # jdata = json.loads(data.to_json())
+     # print(jdata)
+
utils/load_llm.py ADDED
@@ -0,0 +1,37 @@
+ """
+ This module loads the LLM used by the app (currently Gemini via langchain-google-genai; a Groq option is commented out)
+ Modify this file if you need to use some other model from Hugging Face or OpenAI/ChatGPT
+ """
+
+ # from langchain.llms import CTransformers
+ # from langchain_openai import OpenAI
+ from langchain_groq import ChatGroq
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from dotenv import load_dotenv
+ import os
+
+ model_name = 'gemma2-9b-it'
+
+
+ def load_llm(model_name=model_name):
+     # llm = ChatGroq(
+     #     temperature=0,
+     #     model=model_name,
+     # )
+
+     llm = ChatGoogleGenerativeAI(
+         model="gemini-1.5-flash",
+         temperature=0,
+         max_tokens=None,
+         timeout=None,
+         max_retries=2,
+     )
+
+     return llm
+
+
+ if __name__ == '__main__':
+     load_dotenv()
+     llm = load_llm()
+     result = llm.invoke("Provide a short answer: What is machine learning?")
+     print(result.content)
utils/load_maha_election_dataset.py ADDED
@@ -0,0 +1,95 @@
+ """
+ Loads Maharashtra assembly 2019 dataset
+ """
+ import json
+ import sqlite3
+ import pandas as pd
+ import csv
+
+ def load_data_from_csv(name, end=58925):
+     data = []
+     keys = None
+     with open(name, "r", encoding="utf-8") as f:
+         csv_data = csv.reader(f)
+         for i, line in enumerate(csv_data):
+             found = False
+             if i == 0:
+                 keys = line
+                 continue
+             for field in line:
+                 if field.strip() == "TURNOUT":
+                     found = True
+                     break
+             if found:
+                 # print("TURNOUT found, skipping")
+                 continue
+             item = {}
+             # print(line)
+             for key, val in zip(keys, line):
+                 item[key] = val
+             data.append(item)
+     return data
+
+ def clean_dataframe(df):
+     # Strip leading and trailing spaces from column names (without changing them)
+     df.columns = df.columns.str.strip()
+
+     # Strip spaces from the text columns
+     for col in df.select_dtypes(include='object').columns:
+         df[col] = df[col].str.strip()
+
+     # Fill null values with 0
+     df.fillna(0, inplace=True)
+
+     return df
+
+ def load_data_from_csv_to_db(name, conn):
+
+     # read the dataset from csv file and create a pandas dataframe
+     df = pd.read_csv(open(name, "r", encoding="utf-8"))
+
+     # clean the dataframe
+     df = clean_dataframe(df)
+
+     df.columns = [
+         'state', 'constituency_number', 'constituency', 'candidate_name', 'sex', 'age',
+         'category', 'party_name', 'party_symbol', 'evm_votes', 'postal_votes', 'total_votes',
+         'vote_percentage', 'total_electors'
+     ]
+
+     # save the dataframe as a database table, name of table is: maha_2019
+     result = df.to_sql("maha_2019", conn, if_exists='replace', index=False)
+
+     return result
+
+
+ def query_sql(conn, query):
+     cursor = conn.cursor()
+     cursor.execute(query)
+     result = cursor.fetchall()
+     field_names = [r[0] for r in cursor.description]
+     print(field_names)
+     return result
+
+
+ if __name__ == '__main__':
+     # create a connection to sql db called elections.db
+     conn = sqlite3.connect('../data/elections.db')
+
+     filename = r"../data/maha_results_2019.csv"
+     data = load_data_from_csv(filename, end=5)
+     # print(data)
+
+     res = load_data_from_csv_to_db(filename, conn)
+     # print(res)
+
+     query = "SELECT * FROM maha_2019 LIMIT 5;"
+     results = query_sql(conn, query)
+     print(results)
+
+     # keys = data.keys()
+     # for i, item in enumerate(data):
+     #     print(data[item])
+     # jdata = json.loads(data.to_json())
+     # print(jdata)
+
utils/prompts.py ADDED
@@ -0,0 +1,41 @@
+ from langchain_core.prompts import PromptTemplate
+ from langchain import hub
+
+ # react prompt
+ react_prompt = hub.pull("hwchase17/react")
+
+ # prompt to generate sql queries
+
+ sql_query_prompt = PromptTemplate.from_template(
+     """
+     You are a SQL Query Agent who has access to a database with the schema:
+     {db_schema},
+     For the given input: {input},
+
+     Generate SQL queries by analyzing the schema and the input. Make sure to answer all the questions in the input.
+     Generate multiple queries so that a detailed analysis can be done. Make sure the queries are valid and safe.
+
+     If there is no relevant query to generate, just generate a query to view the schema of the tables.
+     """
+ )
+
+ # prompt to summarize the SQL query results
+ sql_query_summary_prompt = PromptTemplate.from_template(
+     """
+     You are a Political Expert who is analyzing the results of the SQL queries executed on the election database.
+     The initial query: {query},
+     You are provided with the sql queries and their results. Analyze the results and summarize the key insights and answer the initial query.
+     If there are any errors in the execution of queries, analyze the errors and provide insights on the issues.
+     {results}
+     """
+ )
+
+ sql_query_visualization_prompt = PromptTemplate.from_template(
+     """
+     You are a Data Scientist who is visualizing the results of the SQL queries executed on the election database.
+     The initial query: {query},
+     You are provided with the sql queries and their results. Visualize the results and provide insights on the data using appropriate visualizations and formatting.
+     If there are any errors in the execution of queries, analyze the errors and provide insights on the issues.
+     {results}
+     """
+ )
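Note: sql_query_prompt is filled in by utils/query_generator.py with the schema dictionary returned by SQLRuntime.get_schemas(). A minimal sketch of that step, assuming the three tables have already been loaded into ./data/elections.db and noting that importing utils.prompts also pulls the ReAct prompt from the LangChain hub (the question string is only an illustration):

    from utils.sql_runtime import SQLRuntime
    from utils.prompts import sql_query_prompt

    runtime = SQLRuntime(dbname="./data/elections.db")
    schemas = runtime.get_schemas()  # {table_name: (column, column, ...), ...}

    # render the prompt string that sql_generator sends to the LLM
    prompt_text = sql_query_prompt.format(
        db_schema=schemas,
        input="Who won Nandurbar in 2019?"  # illustrative question
    )
    print(prompt_text)
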
utils/query_generator.py ADDED
@@ -0,0 +1,112 @@
+ from .sql_runtime import SQLRuntime
+ from pydantic import BaseModel, Field
+ from .load_llm import load_llm
+ from .prompts import sql_query_prompt, sql_query_summary_prompt, sql_query_visualization_prompt
+ from langchain_core.runnables import chain
+ from typing import Optional
+ from dotenv import load_dotenv
+
+ class Generated_query(BaseModel):
+     """
+     The SQL query to execute, make sure to use semicolon at the end of the query, do not execute harmful queries
+     """
+     queries: list[str] = Field(description="List of SQL queries to execute, use title case for strings, make sure to use semicolon at the end of each query, do not execute harmful queries")
+
+ class QuerySummary(BaseModel):
+     """
+     The summary of the SQL query results
+     """
+     summary: str = Field(description="The analysis of the SQL query results")
+     errors: list[str] = Field(description="The errors in the execution of the queries")
+     queries: list[str] = Field(description="The SQL queries executed and their results")
+
+ @chain
+ def sql_generator(input: dict) -> dict:
+
+     query, db_path = input["query"], input["db_path"]
+
+     sql_runtime = SQLRuntime(dbname=db_path)
+
+     query_generator_llm = load_llm().with_structured_output(Generated_query)
+
+     # getting the schemas
+     schemas = sql_runtime.get_schemas()
+
+     # chain to generate the queries
+     chain = sql_query_prompt | query_generator_llm
+
+     # executing the chain
+     gen_queries = chain.invoke({
+         "db_schema": schemas,
+         "input": query
+     })
+
+     # executing the queries
+     res = sql_runtime.execute_batch(gen_queries.queries)
+
+     # print(res)
+
+     return {
+         "input": query,
+         "results": res
+     }
+
+ @chain
+ def sql_formatter(input):
+     """
+     Formats the output of the SQL queries
+     """
+     output = []
+     for item in input["results"]:
+         if item["code"] == 0:
+             output.append(f"Query: {item['msg']['input']}, Result: {item['data']}")
+         else:
+             output.append(f"Query: {item['msg']['input']}, Error: {item['msg']['traceback']}")
+
+     # print(output)
+
+     return {
+         "query": input["input"],
+         "results": output
+     }
+
+ @chain
+ def analyze_results(input) -> QuerySummary:
+     """
+     Analyzes the results of the SQL queries executed on the election database
+     """
+     chain = sql_query_summary_prompt | load_llm().with_structured_output(QuerySummary)
+
+     # chain2 = sql_query_visualization_prompt | load_llm().with_structured_output(QuerySummary)
+
+     return chain.invoke({
+         "query": input["query"],
+         "results": input["results"]
+     })
+
+ if __name__ == '__main__':
+     load_dotenv()
+     # executing the queries
+     # results = sql_generator.invoke("Find the name of the candidate who got the maximum votes in Maharashtra elections 2019")
+
+     # for result in results:
+     #     print(f"Query: {result['msg']['input']}")
+     #     if result["code"] != 0:
+     #         print(f"Error executing query: {result['msg']['reason']}")
+     #         print(f"Traceback: {result['msg']['traceback']}")
+     #     else:
+     #         print(result["data"])
+     #     print("\n")
+
+     # formatting the output
+     res = sql_generator | sql_formatter | analyze_results
+
+     formatted_output = res.invoke(
+         {
+             "query": "What are the different party symbols in Maharashtra elections 2019, create a list of all the symbols",
+             "db_path": "./data/elections.db"
+         }
+     )
+     print(formatted_output.summary)
+     print(formatted_output.errors)
+     print(formatted_output.queries)
utils/react.py ADDED
@@ -0,0 +1,18 @@
+ from langchain import hub
+ from langchain.agents import AgentExecutor, create_react_agent
+
+ def run_agent_executor(agent_executor: AgentExecutor, input_data: dict):
+     for chunk in agent_executor.stream(input_data):
+         if "actions" in chunk:
+             for action in chunk["actions"]:
+                 print(f"Calling Tool: `{action.tool}` with input `{action.tool_input}`")
+         # Observation
+         elif "steps" in chunk:
+             for step in chunk["steps"]:
+                 print(f"Tool Result: `{step.observation}`")
+         # Final result
+         elif "output" in chunk:
+             print(f'Final Output: {chunk["output"]}')
+         else:
+             raise ValueError()
+         print("---")
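Note: run_agent_executor is only referenced from a commented-out line in utils/tools.py. A minimal sketch of how it could be wired up, mirroring the __main__ block of utils/tools.py, assuming the LLM API key is available in .env and that the working directory is chosen so the relative database path '../data/elections.db' hardcoded in the tools resolves:

    from dotenv import load_dotenv
    from langchain.agents import AgentExecutor, create_react_agent

    from utils.load_llm import load_llm
    from utils.prompts import react_prompt
    from utils.react import run_agent_executor
    from utils.tools import create_sql_agent_tools

    load_dotenv()
    tools = create_sql_agent_tools()
    agent = create_react_agent(load_llm(), tools, react_prompt)
    executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

    # streams intermediate tool calls, observations and the final answer
    run_agent_executor(executor, {"input": "Who won Nandurbar in the 2019 elections?"})
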
utils/sql_runtime.py ADDED
@@ -0,0 +1,136 @@
+ """
+ Runtime that accepts a sql statement and runs it on the SQLite database.
+ Returns the results of sql execution.
+ """
+ import traceback
+ import sqlite3
+
+ # MODIFY THE PATH BELOW FOR YOUR SYSTEM
+ my_db = r"../data/elections.db"
+
+ class SQLRuntime(object):
+     def __init__(self, dbname=None):
+         if dbname is None:
+             dbname = my_db
+         conn = sqlite3.connect(dbname)  # creating a connection
+         self.cursor = conn.cursor()  # we need the cursor to execute statement
+         return
+
+     def list_tables(self):
+         result = self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall()
+         table_names = sorted(list(zip(*result))[0])
+         return table_names
+
+     def get_schema_for_table(self, table_name):
+         result = self.cursor.execute("PRAGMA table_info('%s')" % table_name).fetchall()
+         column_names = list(zip(*result))[1]
+         return column_names
+
+     def get_schemas(self):
+         schemas = {}
+         table_names = self.list_tables()
+         for name in table_names:
+             fields = self.get_schema_for_table(name)  # fields of the table name
+             schemas[name] = fields
+         return schemas
+
+     def execute(self, statement):
+         code = 0
+         msg = {
+             "text": "SUCCESS",
+             "reason": None,
+             "traceback": None,
+         }
+         data = None
+
+         try:
+             self.cursor.execute(statement)
+         except sqlite3.OperationalError:
+             code = -1
+             msg = {
+                 "text": "ERROR: SQL execution error",
+                 "reason": "possibly due to incorrect table/fields names",
+                 "traceback": traceback.format_exc(),
+             }
+
+         if code == 0:
+             data = self.cursor.fetchall()
+
+         msg["input"] = statement
+
+         result = {
+             "code": code,
+             "msg": msg,
+             "data": data
+         }
+
+         return result
+
+     def execute_batch(self, queries):
+         results = []
+         for query in queries:
+             result = self.execute(query)
+             results.append(result)
+         return results
+
+     def post_process(self, data):
+         """
+         post process the data so that we can identify any harmful code and remove them.
+         Also, llm output may need an output parser.
+         :param data:
+         :return:
+         """
+         # IMPLEMENT YOUR CODE HERE FOR POST-PROCESSING and VALIDATION
+         return data
+
+
+ def sql_runtime(statement):
+     """
+     Instantiates a sql runtime and executes the given sql statement
+     :param statement: sql statement
+     """
+     SQL = SQLRuntime()
+     data = SQL.execute(statement)
+     return data
+
+
+ if __name__ == '__main__':
+     # stmt = """
+     # SELECT * FROM elections_2019;
+     # """
+     # stmt = input("Enter stmt: ")
+     sql = SQLRuntime()
+
+     tables = sql.list_tables()
+     print(tables)
+
+     schemas = {}
+     for table in tables:
+         schemas[table] = sql.get_schema_for_table(table)
+         print(f"Table: {table}, Schema: {schemas[table]}\n")
+
+     # data1 = sql.execute(stmt)
+
+     # dat = data1["data"]
+     # if dat is not None and len(dat) > 0:
+     #     for record in dat:
+     #         print(record)
+     #     print("-" * 100)
+
+     # sample question: find out the votes polled by NOTA for each instance of Akkalkuwa in the parliamentary elections 2019.
+     stmt = """
+     SELECT party_name, SUM(nota_votes)
+     FROM elections_2019
+     WHERE constituency='Akkalkuwa'
+     GROUP BY party_name;
+     """
+
+     data1 = sql.execute(stmt)
+
+     # print(data1)
+
+     dat = data1["data"]
+     if dat is not None and len(dat) > 0:
+         for record in dat:
+             print(record)
+         print("-" * 100)
utils/tools.py ADDED
@@ -0,0 +1,207 @@
+ from typing import List, Dict, Any, Optional, Type
+ from langchain_core.tools import BaseTool
+ from pydantic import BaseModel, Field
+ import pandas as pd
+ from .sql_runtime import SQLRuntime
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.prompts import ChatPromptTemplate
+ from .load_llm import load_llm
+ from langchain_core.messages import SystemMessage
+ from langchain_core.prompts import HumanMessagePromptTemplate
+ from langchain.agents import AgentExecutor, create_react_agent
+ from dotenv import load_dotenv
+ from .react import run_agent_executor
+ from .prompts import react_prompt
+
+ # defining the input schema
+ class QueryInput(BaseModel):
+     query: str = Field(..., description="The SQL query to execute, make sure to use semicolon at the end of the query, do not execute harmful queries")
+
+ class TableNameInput(BaseModel):
+     table_name: str = Field(..., description="The name of the table to analyze")
+
+ class ColumnSearchInput(BaseModel):
+     table_name: str = Field(..., description="The name of the table to search")
+     column_name: str = Field(..., description="The name of the column to search")
+     limit: int = Field(default=10, description="Maximum number of distinct values to return")
+
+ class SQLQueryTool(BaseTool):
+     name: str = "sql_query"
+     description: str = """
+     Execute a SQL query and return the results.
+     Use this when you need to run a specific SQL query on the elections database.
+     The query should be a valid SQL statement and should end with a semicolon.
+     There should be no harmful queries executed.
+     There are three tables in the database: elections_2019, elections_2024, maha_2019
+     """
+     args_schema: Type[BaseModel] = QueryInput
+
+     # def __init__(self, db_path: Optional[str] = None):
+     #     super().__init__()
+     #     self.
+
+     def _run(self, query: str) -> str:
+         sql_runtime = SQLRuntime('../data/elections.db')
+         try:
+             result = sql_runtime.execute(query)
+             if result["code"] != 0:
+                 return f"Error executing query: {result['msg']['reason']}"
+
+             # Convert to DataFrame for nice string representation
+             df = pd.DataFrame(result["data"])
+             if not df.empty:
+                 return df.to_string()
+             return "Query returned no results"
+
+         except Exception as e:
+             return f"Error: {str(e)}"
+
+ class TableInfoTool(BaseTool):
+     name: str = "get_table_info"
+     description: str = """
+     Get information about a specific table including its schema and basic statistics.
+     Use this when you need to understand the structure of a table or get basic statistics about it.
+     """
+     args_schema: Type[BaseModel] = TableNameInput
+
+     # def __init__(self, db_path: Optional[str] = None):
+     #     super().__init__()
+
+
+     def _run(self, table_name: str) -> str:
+         sql_runtime = SQLRuntime('../data/elections.db')
+         try:
+             # Get schema
+             schema = sql_runtime.get_schema_for_table(table_name)
+
+             # Get row count
+             count_query = f"SELECT COUNT(*) FROM {table_name}"
+             count_result = sql_runtime.execute(count_query)
+             row_count = count_result["data"][0][0] if count_result["code"] == 0 else "Error"
+
+             # Get sample data
+             sample_query = f"SELECT * FROM {table_name} LIMIT 3"
+             sample_result = sql_runtime.execute(sample_query)
+
+             info = f"""
+             Table: {table_name}
+             Columns: {', '.join(schema)}
+             Row Count: {row_count}
+             Sample Data:
+             {pd.DataFrame(sample_result['data'], columns=schema).to_string() if sample_result['code'] == 0 else 'Error getting sample data'}
+             """
+             return info
+         except Exception as e:
+             return f"Error getting table info: {str(e)}"
+
+ class ColumnValuesTool(BaseTool):
+     name: str = "find_column_values"
+     description: str = """
+     Find distinct values in a specific column of a table.
+     Use this when you need to know what unique values exist in a particular column.
+     """
+     args_schema: Type[BaseModel] = ColumnSearchInput
+
+     # def __init__(self, db_path: Optional[str] = None):
+     #     super().__init__()
+     #     self.sql_runtime = SQLRuntime(db_path)
+
+     def _run(self, table_name: str, column_name: str, limit: int = 10) -> str:
+         sql_runtime = SQLRuntime('../data/elections.db')
+         try:
+             query = f"""
+             SELECT DISTINCT {column_name}
+             FROM {table_name}
+             LIMIT {limit}
+             """
+             result = sql_runtime.execute(query)
+             if result["code"] != 0:
+                 return f"Error finding values: {result['msg']['reason']}"
+
+             values = [row[0] for row in result["data"]]
+             return f"Distinct values in {column_name}: {', '.join(map(str, values))}"
+         except Exception as e:
+             return f"Error: {str(e)}"
+
+ class ListTablesTool(BaseTool):
+     name: str = "list_tables"
+     description: str = """
+     List all available tables in the database.
+     Use this when you need to know what tables are available to query.
+     """
+
+     # def __init__(self, db_path: Optional[str] = None):
+     #     super().__init__()
+     #     self.sql_runtime = SQLRuntime(db_path)
+
+     def _run(self, *args, **kwargs) -> str:
+         sql_runtime = SQLRuntime('../data/elections.db')
+         try:
+             tables = sql_runtime.list_tables()
+             return f"Available tables: {', '.join(tables)}"
+         except Exception as e:
+             return f"Error listing tables: {str(e)}"
+
+ def create_sql_agent_tools(db_path: Optional[str] = '../data/elections.db') -> List[BaseTool]:
+     """
+     Create a list of all SQL tools for use with a Langchain agent.
+     """
+     return [
+         SQLQueryTool(),
+         TableInfoTool(),
+         # ColumnValuesTool(),
+         ListTablesTool()
+     ]
+
+ if __name__ == "__main__":
+     load_dotenv()
+     tools = create_sql_agent_tools()
+     for tool in tools:
+         print(f"Tool: {tool.name}")
+         print(f"Description: {tool.description}")
+         # print(f"Args Schema: {tool.args_schema.schema()}")
+
+
+     # prompt = ChatPromptTemplate.from_messages(
+     #     [
+     #         SystemMessage(
+     #             content="""
+     #             You are a sql agent who has access to a database with three tables: elections_2019, elections_2024, maha_2019.
+     #             You can use the following tools:
+     #             - sql_query: Execute a SQL query and return the results.
+     #             - get_table_info: Get information about a specific table including its schema and basic statistics.
+     #             - find_column_values: Find distinct values in a specific column of a table.
+     #             - list_tables: List all available tables in the database.
+
+     #             Answer the questions using the tools provided. Do not execute harmful queries.
+     #             """
+     #         ),
+     #         HumanMessagePromptTemplate.from_template("{text}"),
+     #     ]
+     # )
+
+
+     output_parser = StrOutputParser()
+
+     # Create the llm
+     llm = load_llm()
+
+     # llm.bind_tools(tools)
+
+     # res = llm.invoke("who won elections in maharashtra in Nandurbar in elections 2019? use the given tools")
+
+     # chain = prompt | llm | output_parser
+
+     # Create the ReAct agent
+     agent = create_react_agent(llm, tools, react_prompt)
+     # Create an agent executor by passing in the agent and tools
+     agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
+
+     print("Agent created successfully")
+
+     # Run the agent
+     # agent_executor.invoke({"input": "Who won the elections in 2019 for the state maharashtra in constituency Akkalkuwa?"})
+
+     res = agent_executor.invoke({"input": "who won elections in maharashtra in Nandurbar in elections 2019?"})
+
+     # run_agent_executor(agent_executor, {"input": "who won elections in maharashtra in Nandurbar in elections 2019?"})