Spaces:
Runtime error
Runtime error
ankush-003
commited on
Commit
·
10757ec
1
Parent(s):
cff415c
init
Browse files- .Dockerfile +36 -0
- .Dockerignore +19 -0
- app.py +118 -0
- chainlit.md +5 -0
- data/details_of_assembly_segment_2019.csv +0 -0
- data/eci_data_2024.csv +0 -0
- data/maha_results_2019.csv +0 -0
- docs/lab_session1_25oct2024.pdf +0 -0
- docs/pes_lab_session1.pdf +0 -0
- requirements.txt +75 -0
- utils/__init__.py +0 -0
- utils/cot.py +30 -0
- utils/few_shot.py +105 -0
- utils/get_completion_client.py +51 -0
- utils/load_details_dataset.py +75 -0
- utils/load_election_dataset.py +72 -0
- utils/load_llm.py +37 -0
- utils/load_maha_election_dataset.py +95 -0
- utils/prompts.py +41 -0
- utils/query_generator.py +118 -0
- utils/react.py +18 -0
- utils/sql_runtime.py +136 -0
- utils/tools.py +207 -0
.Dockerfile
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Use Python 3.9 as the base image
|
2 |
+
FROM python:3.11-slim
|
3 |
+
|
4 |
+
# Set working directory
|
5 |
+
WORKDIR /app
|
6 |
+
|
7 |
+
# Install system dependencies
|
8 |
+
RUN apt-get update && \
|
9 |
+
apt-get install -y --no-install-recommends \
|
10 |
+
build-essential \
|
11 |
+
python3-dev \
|
12 |
+
&& rm -rf /var/lib/apt/lists/*
|
13 |
+
|
14 |
+
# Create and activate virtual environment
|
15 |
+
ENV VIRTUAL_ENV=/opt/venv
|
16 |
+
RUN python3 -m venv $VIRTUAL_ENV
|
17 |
+
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
|
18 |
+
|
19 |
+
# Copy requirements first to leverage Docker cache
|
20 |
+
COPY requirements.txt /app/requirements.txt
|
21 |
+
|
22 |
+
# Install dependencies in virtual environment
|
23 |
+
RUN pip install --no-cache-dir --upgrade -r /app/requirements.txt
|
24 |
+
|
25 |
+
# Copy the rest of the application
|
26 |
+
COPY . .
|
27 |
+
|
28 |
+
ENV PYTHONPATH=/app
|
29 |
+
ENV PYTHONUNBUFFERED=1
|
30 |
+
|
31 |
+
# Expose the default Chainlit port
|
32 |
+
EXPOSE 8000
|
33 |
+
|
34 |
+
# Command to run the application
|
35 |
+
# CMD . /opt/venv/bin/activate && exec chainlit run app.py --port 8000
|
36 |
+
CMD ["chainlit", "run", "app.py", "--port", "7860"]
|
.Dockerignore
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# python venv
|
2 |
+
llm
|
3 |
+
|
4 |
+
# envs
|
5 |
+
.env
|
6 |
+
|
7 |
+
# idx
|
8 |
+
.idx
|
9 |
+
.vscode
|
10 |
+
|
11 |
+
# database
|
12 |
+
elections.db
|
13 |
+
|
14 |
+
# __pycache__
|
15 |
+
__pycache__
|
16 |
+
|
17 |
+
# Chainlit files
|
18 |
+
.chainlit
|
19 |
+
.files
|
app.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import chainlit as cl
|
2 |
+
from utils import load_details_dataset, load_election_dataset, load_maha_election_dataset
|
3 |
+
from sqlite3 import connect
|
4 |
+
from typing import cast
|
5 |
+
from utils.load_llm import load_llm
|
6 |
+
from dotenv import load_dotenv
|
7 |
+
from utils.query_generator import sql_generator, sql_formatter, analyze_results
|
8 |
+
from langchain.schema.runnable import Runnable
|
9 |
+
from utils.sql_runtime import SQLRuntime
|
10 |
+
|
11 |
+
load_dotenv()
|
12 |
+
# global variables
|
13 |
+
db_path = './data/elections.db'
|
14 |
+
sql_runtime = SQLRuntime(dbname=db_path)
|
15 |
+
|
16 |
+
# Load the dataset
|
17 |
+
@cl.action_callback("Load Datasets")
|
18 |
+
async def on_action(action: cl.Action):
|
19 |
+
print("Loading datasets...")
|
20 |
+
|
21 |
+
# save the datasets as tables
|
22 |
+
conn = connect('./data/elections.db')
|
23 |
+
|
24 |
+
load_details_dataset.load_data_from_csv_to_db('./data/details_of_assembly_segment_2019.csv', conn)
|
25 |
+
load_election_dataset.load_data_from_csv_to_db('./data/eci_data_2024.csv', conn)
|
26 |
+
load_maha_election_dataset.load_data_from_csv_to_db('./data/maha_results_2019.csv', conn)
|
27 |
+
|
28 |
+
return "Datasets loaded successfully."
|
29 |
+
|
30 |
+
@cl.action_callback("Execute Query")
|
31 |
+
async def on_action(action: cl.Action):
|
32 |
+
res = await cl.AskUserMessage(content="Enter Query to run Manually", timeout=20).send()
|
33 |
+
actions = [
|
34 |
+
cl.Action(name="Execute Query", description="Execute the query on the dataset", value="Execute Query")
|
35 |
+
]
|
36 |
+
if res:
|
37 |
+
query = res['output']
|
38 |
+
res = sql_runtime.execute(query)
|
39 |
+
print(res)
|
40 |
+
if res["code"] == 0:
|
41 |
+
data = ""
|
42 |
+
if res["data"]:
|
43 |
+
for row in res["data"]:
|
44 |
+
data += str(row) + "\n"
|
45 |
+
|
46 |
+
elements = [
|
47 |
+
cl.Text(name="Result", content=data, display="inline"),
|
48 |
+
]
|
49 |
+
await cl.Message(
|
50 |
+
content=f"Query: {query}",
|
51 |
+
elements=elements,
|
52 |
+
actions=actions,
|
53 |
+
).send()
|
54 |
+
else:
|
55 |
+
error = res["msg"]["traceback"]
|
56 |
+
elements = [
|
57 |
+
cl.Text(name="Error", content=error, display="inline"),
|
58 |
+
]
|
59 |
+
await cl.Message(
|
60 |
+
content=f"Query: {query}",
|
61 |
+
elements=elements,
|
62 |
+
actions=actions,
|
63 |
+
).send()
|
64 |
+
|
65 |
+
# return "Query executed successfully."
|
66 |
+
|
67 |
+
@cl.on_chat_start
|
68 |
+
async def start():
|
69 |
+
# Sending an action button within a chatbot message
|
70 |
+
actions = [
|
71 |
+
cl.Action(name="Load Datasets", description="Load the datasets into the database", value="Load Datasets")
|
72 |
+
]
|
73 |
+
|
74 |
+
chain = sql_generator | sql_formatter | analyze_results
|
75 |
+
|
76 |
+
cl.user_session.set("chain", chain)
|
77 |
+
cl.user_session.set("db_path", './data/elections.db')
|
78 |
+
|
79 |
+
await cl.Message(content="I am your personal political expert. I can help you analyze the election data. Click the button below to load the datasets.", actions=actions).send()
|
80 |
+
|
81 |
+
@cl.on_message
|
82 |
+
async def on_message(message: cl.Message):
|
83 |
+
chain = cast(Runnable, cl.user_session.get("chain"))
|
84 |
+
db_path = cl.user_session.get("db_path")
|
85 |
+
|
86 |
+
actions = [
|
87 |
+
cl.Action(name="Execute Query", description="Execute the query on the dataset", value="Execute Query")
|
88 |
+
]
|
89 |
+
|
90 |
+
print(message.content)
|
91 |
+
|
92 |
+
try:
|
93 |
+
res = chain.invoke({
|
94 |
+
"query": message.content,
|
95 |
+
"db_path": db_path
|
96 |
+
})
|
97 |
+
except Exception as e:
|
98 |
+
print(e)
|
99 |
+
await cl.Message(content="An error occurred while processing the query. Please try again.").send()
|
100 |
+
return
|
101 |
+
|
102 |
+
queries = "\n".join(res.queries)
|
103 |
+
|
104 |
+
errors = "".join(res.errors)
|
105 |
+
|
106 |
+
elements = [
|
107 |
+
cl.Text(name='results', content=res.summary, display="inline"),
|
108 |
+
cl.Text(name="queries", content=queries, display="inline"),
|
109 |
+
]
|
110 |
+
|
111 |
+
if errors:
|
112 |
+
elements.append(cl.Text(name="errors", content=errors, display="inline"))
|
113 |
+
|
114 |
+
await cl.Message(
|
115 |
+
content="Let's analyze the results of the query",
|
116 |
+
elements=elements,
|
117 |
+
actions=actions
|
118 |
+
).send()
|
chainlit.md
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Welcome screen
|
2 |
+
|
3 |
+
# MahaNeta
|
4 |
+
|
5 |
+
**Your Own Personal Political Assistant**
|
data/details_of_assembly_segment_2019.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/eci_data_2024.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/maha_results_2019.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
docs/lab_session1_25oct2024.pdf
ADDED
Binary file (853 kB). View file
|
|
docs/pes_lab_session1.pdf
ADDED
Binary file (236 kB). View file
|
|
requirements.txt
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
aiohappyeyeballs==2.4.3
|
2 |
+
aiohttp==3.10.10
|
3 |
+
aiosignal==1.3.1
|
4 |
+
altair==5.4.1
|
5 |
+
annotated-types==0.7.0
|
6 |
+
anyio==4.6.2.post1
|
7 |
+
attrs==24.2.0
|
8 |
+
blinker==1.8.2
|
9 |
+
cachetools==5.5.0
|
10 |
+
certifi==2024.8.30
|
11 |
+
charset-normalizer==3.4.0
|
12 |
+
click==8.1.7
|
13 |
+
distro==1.9.0
|
14 |
+
frozenlist==1.5.0
|
15 |
+
gitdb==4.0.11
|
16 |
+
GitPython==3.1.43
|
17 |
+
greenlet==3.1.1
|
18 |
+
groq==0.11.0
|
19 |
+
h11==0.14.0
|
20 |
+
httpcore==1.0.6
|
21 |
+
httpx==0.27.2
|
22 |
+
idna==3.10
|
23 |
+
Jinja2==3.1.4
|
24 |
+
jsonpatch==1.33
|
25 |
+
jsonpointer==3.0.0
|
26 |
+
jsonschema==4.23.0
|
27 |
+
jsonschema-specifications==2024.10.1
|
28 |
+
langchain==0.3.4
|
29 |
+
langchain-core==0.3.13
|
30 |
+
langchain-groq==0.2.0
|
31 |
+
langchain-google-genai
|
32 |
+
langchain-text-splitters==0.3.0
|
33 |
+
langsmith==0.1.137
|
34 |
+
markdown-it-py==3.0.0
|
35 |
+
MarkupSafe==3.0.2
|
36 |
+
mdurl==0.1.2
|
37 |
+
multidict==6.1.0
|
38 |
+
narwhals==1.12.1
|
39 |
+
numpy==1.26.4
|
40 |
+
orjson==3.10.10
|
41 |
+
packaging==24.1
|
42 |
+
pandas==2.2.3
|
43 |
+
pillow==10.4.0
|
44 |
+
propcache==0.2.0
|
45 |
+
protobuf==5.28.3
|
46 |
+
pyarrow==18.0.0
|
47 |
+
pydantic==2.9.2
|
48 |
+
pydantic_core==2.23.4
|
49 |
+
pydeck==0.9.1
|
50 |
+
Pygments==2.18.0
|
51 |
+
python-dateutil==2.9.0.post0
|
52 |
+
python-dotenv==1.0.1
|
53 |
+
pytz==2024.2
|
54 |
+
PyYAML==6.0.2
|
55 |
+
referencing==0.35.1
|
56 |
+
requests==2.32.3
|
57 |
+
requests-toolbelt==1.0.0
|
58 |
+
rich==13.9.3
|
59 |
+
rpds-py==0.20.0
|
60 |
+
six==1.16.0
|
61 |
+
smmap==5.0.1
|
62 |
+
sniffio==1.3.1
|
63 |
+
SQLAlchemy==2.0.36
|
64 |
+
streamlit==1.39.0
|
65 |
+
tenacity==9.0.0
|
66 |
+
toml==0.10.2
|
67 |
+
tornado==6.4.1
|
68 |
+
typing_extensions==4.12.2
|
69 |
+
tzdata==2024.2
|
70 |
+
urllib3==2.2.3
|
71 |
+
watchdog==5.0.3
|
72 |
+
yarl==1.17.0
|
73 |
+
|
74 |
+
# frontent
|
75 |
+
chainlit
|
utils/__init__.py
ADDED
File without changes
|
utils/cot.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This example is to show Chain Of Thought
|
3 |
+
"""
|
4 |
+
|
5 |
+
from langchain import PromptTemplate
|
6 |
+
from load_llm import load_llm
|
7 |
+
|
8 |
+
template = """Answer the question based on the context below. If the
|
9 |
+
question cannot be answered using the information provided answer
|
10 |
+
with "I don't know".
|
11 |
+
|
12 |
+
Q: Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can contains 3 tennis balls. How many
|
13 |
+
tennis balls does he have now?
|
14 |
+
A: Roger started with 5 balls. 2 cans of 3 tennis balls is 6 tennis balls. 5+6 = 11.The answer is 11.
|
15 |
+
|
16 |
+
Q: The cafetaria has 23 apples. If they used 20 apples for lunch and bought 6 more, how many apples do they have?
|
17 |
+
|
18 |
+
"""
|
19 |
+
|
20 |
+
|
21 |
+
prompt_template = PromptTemplate(
|
22 |
+
input_variables=[],
|
23 |
+
template=template
|
24 |
+
)
|
25 |
+
|
26 |
+
prompt = prompt_template.format(
|
27 |
+
)
|
28 |
+
|
29 |
+
llm = load_llm()
|
30 |
+
print(llm(prompt))
|
utils/few_shot.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This example is to show how to load an LLM, use a prompt and retrieve results
|
3 |
+
We illustrate the use of LangChain for a few shot inferencing
|
4 |
+
|
5 |
+
The dynamic number of examples is important because the max length of our prompt and completion output is limited.
|
6 |
+
This limitation is measured by maximum context window.
|
7 |
+
|
8 |
+
context_window = input_tokens + output_tokens
|
9 |
+
At the same time, we can maximize the number of examples given to the model for few-shot learning.
|
10 |
+
|
11 |
+
Considering this, we need to balance the number of examples included and our prompt size.
|
12 |
+
Our hard limit is the maximum context size, but we must also consider the cost of processing more tokens through LLM.
|
13 |
+
Fewer tokens mean a cheaper service and faster completions from the LLM.
|
14 |
+
"""
|
15 |
+
|
16 |
+
from load_llm import load_llm
|
17 |
+
from langchain import PromptTemplate, FewShotPromptTemplate
|
18 |
+
from langchain.prompts.example_selector import LengthBasedExampleSelector
|
19 |
+
|
20 |
+
# create our examples
|
21 |
+
examples = [
|
22 |
+
{
|
23 |
+
"query": "How are you?",
|
24 |
+
"answer": "I can't complain but sometimes I still do."
|
25 |
+
}, {
|
26 |
+
"query": "What time is it?",
|
27 |
+
"answer": "It's time to get a watch."
|
28 |
+
}, {
|
29 |
+
"query": "What is the meaning of life?",
|
30 |
+
"answer": "42"
|
31 |
+
}, {
|
32 |
+
"query": "What is the weather like today?",
|
33 |
+
"answer": "Cloudy with a chance of memes."
|
34 |
+
}, {
|
35 |
+
"query": "What is your favorite movie?",
|
36 |
+
"answer": "Terminator"
|
37 |
+
}, {
|
38 |
+
"query": "Who is your best friend?",
|
39 |
+
"answer": "Siri. We have spirited debates about the meaning of life."
|
40 |
+
}, {
|
41 |
+
"query": "What should I do today?",
|
42 |
+
"answer": "Stop talking to chatbots on the internet and go outside."
|
43 |
+
}
|
44 |
+
]
|
45 |
+
|
46 |
+
# create a example template
|
47 |
+
example_template = """
|
48 |
+
User: {query}
|
49 |
+
AI: {answer}
|
50 |
+
"""
|
51 |
+
|
52 |
+
# create a prompt example from above template
|
53 |
+
example_prompt = PromptTemplate(
|
54 |
+
input_variables=["query", "answer"],
|
55 |
+
template=example_template
|
56 |
+
)
|
57 |
+
|
58 |
+
|
59 |
+
example_selector = LengthBasedExampleSelector(
|
60 |
+
examples=examples,
|
61 |
+
example_prompt=example_prompt,
|
62 |
+
max_length=50 # this sets the max length that examples should be
|
63 |
+
)
|
64 |
+
|
65 |
+
# now break our previous prompt into a prefix and suffix
|
66 |
+
# the prefix is our instructions
|
67 |
+
prefix = """The following are exerpts from conversations with an AI
|
68 |
+
assistant. The assistant is typically sarcastic and witty, producing
|
69 |
+
creative and funny responses to the users questions. Here are some
|
70 |
+
examples:
|
71 |
+
"""
|
72 |
+
# and the suffix our user input and output indicator
|
73 |
+
suffix = """
|
74 |
+
User: {query}
|
75 |
+
AI: """
|
76 |
+
|
77 |
+
# now create the few shot prompt template
|
78 |
+
dynamic_prompt_template = FewShotPromptTemplate(
|
79 |
+
example_selector=example_selector, # use example_selector instead of examples
|
80 |
+
example_prompt=example_prompt,
|
81 |
+
prefix=prefix,
|
82 |
+
suffix=suffix,
|
83 |
+
input_variables=["query"],
|
84 |
+
example_separator="\n"
|
85 |
+
)
|
86 |
+
|
87 |
+
print(dynamic_prompt_template.format(query="How do birds fly?"))
|
88 |
+
print("-------- Longer query will select fewer examples in order to preserve the context ----------")
|
89 |
+
|
90 |
+
query = """If I am in America, and I want to call someone in another country, I'm
|
91 |
+
thinking maybe Europe, possibly western Europe like France, Germany, or the UK,
|
92 |
+
what is the best way to do that?"""
|
93 |
+
|
94 |
+
prompt = dynamic_prompt_template.format(query=query)
|
95 |
+
print(prompt)
|
96 |
+
|
97 |
+
print("-------- Shorter query for LLM ----------")
|
98 |
+
query = "How is the weather in your city today?"
|
99 |
+
prompt = dynamic_prompt_template.format(query=query)
|
100 |
+
print(prompt)
|
101 |
+
|
102 |
+
llm = load_llm()
|
103 |
+
print(
|
104 |
+
llm(prompt)
|
105 |
+
)
|
utils/get_completion_client.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from openai import OpenAI
|
3 |
+
|
4 |
+
# Point to the local server
|
5 |
+
client1 = OpenAI(base_url="http://localhost:1234/v1", api_key="lm-studio") # html to json
|
6 |
+
model = r"lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF/Meta-Llama-3-8B-Instruct-Q8_0.gguf"
|
7 |
+
|
8 |
+
# model = "MaziyarPanahi/Mistral-7B-Instruct-v0.3-GGUF/Mistral-7B-Instruct-v0.3.Q4_K_M.gguf:2"
|
9 |
+
|
10 |
+
|
11 |
+
def get_completion(prompt, client=client1, model=model):
|
12 |
+
"""
|
13 |
+
given the prompt, obtain the response from LLM hosted by LM Studio as a server
|
14 |
+
:param prompt: prompt to be sent to LLM server
|
15 |
+
:return: response from the LLM
|
16 |
+
"""
|
17 |
+
prompt = [
|
18 |
+
{"role": "user", "content": prompt}
|
19 |
+
]
|
20 |
+
completion = client.chat.completions.create(
|
21 |
+
model=model,
|
22 |
+
messages=prompt,
|
23 |
+
temperature=0.0,
|
24 |
+
stream=True,
|
25 |
+
)
|
26 |
+
|
27 |
+
new_message = {"role": "assistant", "content": ""}
|
28 |
+
|
29 |
+
for chunk in completion:
|
30 |
+
if chunk.choices[0].delta.content:
|
31 |
+
# print(chunk.choices[0].delta.content, end="", flush=True)
|
32 |
+
val = chunk.choices[0].delta.content
|
33 |
+
new_message["content"] += val
|
34 |
+
|
35 |
+
# print(type()
|
36 |
+
val = new_message["content"] # .split("<end_of_turn>")[0]
|
37 |
+
|
38 |
+
return val
|
39 |
+
|
40 |
+
|
41 |
+
if __name__ == '__main__':
|
42 |
+
prompt = """
|
43 |
+
You are a political leader and your party is trying to win the general elections in India.
|
44 |
+
You are given an LLM that can provide you the analytics using the past historical data given to it.
|
45 |
+
In particular the LLM has been provided data on which party won each constituency out of 545 and which assembly segment within the main constituency is more favorable.
|
46 |
+
It also has details of votes polled by every candidate.
|
47 |
+
Tell me 10 questions that you want to ask the LLM.
|
48 |
+
"""
|
49 |
+
results = get_completion(prompt)
|
50 |
+
print(results)
|
51 |
+
|
utils/load_details_dataset.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import sqlite3
|
3 |
+
import pandas as pd
|
4 |
+
import csv
|
5 |
+
|
6 |
+
|
7 |
+
def load_data_from_csv(name, end=58925):
|
8 |
+
data = []
|
9 |
+
keys = None
|
10 |
+
with open(name, "r", encoding="utf-8", errors="ignore") as f:
|
11 |
+
csv_data = csv.reader(f)
|
12 |
+
for i, line in enumerate(csv_data):
|
13 |
+
if i == 0:
|
14 |
+
keys = line
|
15 |
+
continue
|
16 |
+
item = {}
|
17 |
+
for key, val in zip(keys, line):
|
18 |
+
item[key] = val
|
19 |
+
data.append(item)
|
20 |
+
return data
|
21 |
+
|
22 |
+
|
23 |
+
def load_data_from_csv_to_db(name, conn, col_names=None):
|
24 |
+
|
25 |
+
# read the dataset from csv file and create a pandas dataframe
|
26 |
+
df = pd.read_csv(open(name, "r", encoding="utf-8", errors="ignore"))
|
27 |
+
|
28 |
+
df.columns = [
|
29 |
+
'state', 'parliamentary_constituency', 'constituency', 'nota_votes', 'candidate_name', 'party_name', 'total_votes'
|
30 |
+
]
|
31 |
+
|
32 |
+
# removing extra whitespace
|
33 |
+
string_columns = df.select_dtypes(include=['object']).columns
|
34 |
+
for col in string_columns:
|
35 |
+
df[col] = df[col].astype(str).str.strip()
|
36 |
+
|
37 |
+
df['constituency'] = df['constituency'].str.replace(r'\s*-\s*\d+$', '', regex=True)
|
38 |
+
|
39 |
+
# Remove any parenthetical suffixes like (SC) or (ST)
|
40 |
+
df['constituency'] = df['constituency'].str.replace(r'\s*\([^)]*\)', '', regex=True)
|
41 |
+
|
42 |
+
# save the dataframe as a database table, name of table is: elections_2019
|
43 |
+
result = df.to_sql("elections_2019", conn, if_exists="replace")
|
44 |
+
|
45 |
+
return result
|
46 |
+
|
47 |
+
|
48 |
+
def query_sql(conn, query):
|
49 |
+
cursor = conn.cursor()
|
50 |
+
cursor.execute(query)
|
51 |
+
result = cursor.fetchall()
|
52 |
+
field_names = [r[0] for r in cursor.description]
|
53 |
+
print(field_names)
|
54 |
+
return result
|
55 |
+
|
56 |
+
|
57 |
+
if __name__ == '__main__':
|
58 |
+
# create a connection to sql db called elections.db
|
59 |
+
conn = sqlite3.connect('../data/elections.db')
|
60 |
+
|
61 |
+
filename = r"../data/details_of_assembly_segment_2019.csv"
|
62 |
+
|
63 |
+
data = load_data_from_csv(filename, end=5)
|
64 |
+
|
65 |
+
res = load_data_from_csv_to_db(filename, conn)
|
66 |
+
|
67 |
+
query = "SELECT * FROM elections_2019 LIMIT 5;"
|
68 |
+
results = query_sql(conn, query)
|
69 |
+
print(results)
|
70 |
+
|
71 |
+
# keys = data.keys()
|
72 |
+
# for i, item in enumerate(data):
|
73 |
+
# print(data[item])
|
74 |
+
# jdata = json.loads(data.to_json())
|
75 |
+
# print(jdata)
|
utils/load_election_dataset.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import sqlite3
|
3 |
+
import pandas as pd
|
4 |
+
import csv
|
5 |
+
|
6 |
+
|
7 |
+
def load_data_from_csv(name, end=58925):
|
8 |
+
data = []
|
9 |
+
keys = None
|
10 |
+
with open(name, "r", encoding="utf-8", errors="ignore") as f:
|
11 |
+
csv_data = csv.reader(f)
|
12 |
+
for i, line in enumerate(csv_data):
|
13 |
+
if i == 0:
|
14 |
+
keys = line
|
15 |
+
continue
|
16 |
+
item = {}
|
17 |
+
for key, val in zip(keys, line):
|
18 |
+
item[key] = val
|
19 |
+
data.append(item)
|
20 |
+
return data
|
21 |
+
|
22 |
+
|
23 |
+
def load_data_from_csv_to_db(name, conn, col_names=None):
|
24 |
+
|
25 |
+
# read the dataset from csv file and create a pandas dataframe
|
26 |
+
df = pd.read_csv(open(name, "r", encoding="utf-8", errors="ignore"))
|
27 |
+
|
28 |
+
df.columns = [
|
29 |
+
'sn', 'candidate_name', 'party_name', 'evm_votes', 'postal_votes', 'total_votes', 'vote_percentage', 'state', 'constituency'
|
30 |
+
]
|
31 |
+
|
32 |
+
df['constituency'] = df['constituency'].str.replace(r'\s*-\s*\d+$', '', regex=True)
|
33 |
+
|
34 |
+
# Remove any parenthetical suffixes like (SC) or (ST)
|
35 |
+
df['constituency'] = df['constituency'].str.replace(r'\s*\([^)]*\)', '', regex=True)
|
36 |
+
|
37 |
+
# save the dataframe as a database table, name of table is: elections_2019
|
38 |
+
result = df.to_sql("elections_2024", conn, if_exists="replace")
|
39 |
+
|
40 |
+
return result
|
41 |
+
|
42 |
+
|
43 |
+
def query_sql(conn, query):
|
44 |
+
cursor = conn.cursor()
|
45 |
+
cursor.execute(query)
|
46 |
+
result = cursor.fetchall()
|
47 |
+
field_names = [r[0] for r in cursor.description]
|
48 |
+
print(field_names)
|
49 |
+
return result
|
50 |
+
|
51 |
+
|
52 |
+
if __name__ == '__main__':
|
53 |
+
# create a connection to sql db called elections.db
|
54 |
+
conn = sqlite3.connect('../data/elections.db')
|
55 |
+
|
56 |
+
filename = r"../data/eci_data_2024.csv"
|
57 |
+
|
58 |
+
|
59 |
+
data = load_data_from_csv(filename, end=5)
|
60 |
+
|
61 |
+
res = load_data_from_csv_to_db(filename, conn)
|
62 |
+
|
63 |
+
query = "SELECT count(*) FROM elections_2024 WHERE constituency='Amalapuram';"
|
64 |
+
results = query_sql(conn, query)
|
65 |
+
print(results)
|
66 |
+
|
67 |
+
# keys = data.keys()
|
68 |
+
# for i, item in enumerate(data):
|
69 |
+
# print(data[item])
|
70 |
+
# jdata = json.loads(data.to_json())
|
71 |
+
# print(jdata)
|
72 |
+
|
utils/load_llm.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This module loads the LLM from the local file system
|
3 |
+
Modify this file if you need to download some other model from Hugging Face or OpenAI/ChatGPT
|
4 |
+
"""
|
5 |
+
|
6 |
+
# from langchain.llms import CTransformers
|
7 |
+
# from langchain_openai import OpenAI
|
8 |
+
from langchain_groq import ChatGroq
|
9 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
10 |
+
from dotenv import load_dotenv
|
11 |
+
import os
|
12 |
+
|
13 |
+
model_name = 'gemma2-9b-it'
|
14 |
+
|
15 |
+
|
16 |
+
def load_llm(model_name=model_name):
|
17 |
+
# llm = ChatGroq(
|
18 |
+
# temperature=0,
|
19 |
+
# model=model_name,
|
20 |
+
# )
|
21 |
+
|
22 |
+
llm = ChatGoogleGenerativeAI(
|
23 |
+
model="gemini-1.5-flash",
|
24 |
+
temperature=0,
|
25 |
+
max_tokens=None,
|
26 |
+
timeout=None,
|
27 |
+
max_retries=2,
|
28 |
+
)
|
29 |
+
|
30 |
+
return llm
|
31 |
+
|
32 |
+
|
33 |
+
if __name__ == '__main__':
|
34 |
+
load_dotenv()
|
35 |
+
llm = load_llm()
|
36 |
+
result = llm.invoke("Provide a short answer: What is machine learning?")
|
37 |
+
print(result.content)
|
utils/load_maha_election_dataset.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Loads Maharashtra assembly 2019 dataset
|
3 |
+
"""
|
4 |
+
import json
|
5 |
+
import sqlite3
|
6 |
+
import pandas as pd
|
7 |
+
import csv
|
8 |
+
|
9 |
+
def load_data_from_csv(name, end=58925):
|
10 |
+
data = []
|
11 |
+
keys = None
|
12 |
+
with open(name, "r", encoding="utf-8") as f:
|
13 |
+
csv_data = csv.reader(f)
|
14 |
+
for i, line in enumerate(csv_data):
|
15 |
+
found = False
|
16 |
+
if i == 0:
|
17 |
+
keys = line
|
18 |
+
continue
|
19 |
+
for field in line:
|
20 |
+
if field.strip() == "TURNOUT":
|
21 |
+
found = True
|
22 |
+
break
|
23 |
+
if found:
|
24 |
+
# print("TURNOUT found, skipping")
|
25 |
+
continue
|
26 |
+
item = {}
|
27 |
+
# print(line)
|
28 |
+
for key, val in zip(keys, line):
|
29 |
+
item[key] = val
|
30 |
+
data.append(item)
|
31 |
+
return data
|
32 |
+
|
33 |
+
def clean_dataframe(df):
|
34 |
+
# Strip leading and trailing spaces from column names (without changing them)
|
35 |
+
df.columns = df.columns.str.strip()
|
36 |
+
|
37 |
+
# Strip spaces and convert text columns to lowercase
|
38 |
+
for col in df.select_dtypes(include='object').columns:
|
39 |
+
df[col] = df[col].str.strip()
|
40 |
+
|
41 |
+
# Fill null values with 0
|
42 |
+
df.fillna(0, inplace=True)
|
43 |
+
|
44 |
+
return df
|
45 |
+
|
46 |
+
def load_data_from_csv_to_db(name, conn):
|
47 |
+
|
48 |
+
# read the dataset from csv file and create a pandas dataframe
|
49 |
+
df = pd.read_csv(open(name, "r", encoding="utf-8"))
|
50 |
+
|
51 |
+
# clean the dataframe
|
52 |
+
df = clean_dataframe(df)
|
53 |
+
|
54 |
+
df.columns = [
|
55 |
+
'state', 'constituency_number', 'constituency', 'candidate_name', 'sex', 'age',
|
56 |
+
'category', 'party_name', 'party_symbol', 'evm_votes', 'postal_votes', 'total_votes',
|
57 |
+
'vote_percentage', 'total_electors'
|
58 |
+
]
|
59 |
+
|
60 |
+
# save the dataframe as a database table, name of table is: elections_2019
|
61 |
+
result = df.to_sql("maha_2019", conn, if_exists='replace', index=False)
|
62 |
+
|
63 |
+
return result
|
64 |
+
|
65 |
+
|
66 |
+
def query_sql(conn, query):
|
67 |
+
cursor = conn.cursor()
|
68 |
+
cursor.execute(query)
|
69 |
+
result = cursor.fetchall()
|
70 |
+
field_names = [r[0] for r in cursor.description]
|
71 |
+
print(field_names)
|
72 |
+
return result
|
73 |
+
|
74 |
+
|
75 |
+
if __name__ == '__main__':
|
76 |
+
# create a connection to sql db called elections.db
|
77 |
+
conn = sqlite3.connect('../data/elections.db')
|
78 |
+
|
79 |
+
filename = r"../data/maha_results_2019.csv"
|
80 |
+
data = load_data_from_csv(filename, end=5)
|
81 |
+
# print(data)
|
82 |
+
|
83 |
+
res = load_data_from_csv_to_db(filename, conn)
|
84 |
+
# print(res)
|
85 |
+
|
86 |
+
query = "SELECT * FROM maha_2019 LIMIT 5;"
|
87 |
+
results = query_sql(conn, query)
|
88 |
+
print(results)
|
89 |
+
|
90 |
+
# keys = data.keys()
|
91 |
+
# for i, item in enumerate(data):
|
92 |
+
# print(data[item])
|
93 |
+
# jdata = json.loads(data.to_json())
|
94 |
+
# print(jdata)
|
95 |
+
|
utils/prompts.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_core.prompts import PromptTemplate
|
2 |
+
from langchain import hub
|
3 |
+
|
4 |
+
# react prompt
|
5 |
+
react_prompt = hub.pull("hwchase17/react")
|
6 |
+
|
7 |
+
# prompt to generate sql queries
|
8 |
+
|
9 |
+
sql_query_prompt = PromptTemplate.from_template(
|
10 |
+
"""
|
11 |
+
You are a SQL Query Agent who has access to a database with the schema:
|
12 |
+
{db_schema},
|
13 |
+
For the given input: {input},
|
14 |
+
|
15 |
+
Generate SQL queries by analyzing the schema and the input. Make sure to answer all the questions in the input.
|
16 |
+
Generate more number of queries so that a detailed analysis can be done. Make sure the queries are valid and safe.
|
17 |
+
|
18 |
+
If there is no relevant query to generate, just generate a query to view the schema of the tables.
|
19 |
+
"""
|
20 |
+
)
|
21 |
+
|
22 |
+
# prompt to summarize the SQL query results
|
23 |
+
sql_query_summary_prompt = PromptTemplate.from_template(
|
24 |
+
"""
|
25 |
+
You are a Political Expert who is analyzing the results of the SQL queries executed on the election database.
|
26 |
+
The initial query: {query},
|
27 |
+
You are provided with the sql queries and their results. Analyze the results and summarize the key insights and answer the initial query.
|
28 |
+
If there are any errors in the execution of queries, analyze the errors and provide insights on the issues.
|
29 |
+
{results}
|
30 |
+
"""
|
31 |
+
)
|
32 |
+
|
33 |
+
sql_query_visualization_prompt = PromptTemplate.from_template(
|
34 |
+
"""
|
35 |
+
You are a Data Scientist who is visualizing the results of the SQL queries executed on the election database.
|
36 |
+
The initial query: {query},
|
37 |
+
You are provided with the sql queries and their results. Visualize the results and provide insights on the data using appropriate visualizations and formatting.
|
38 |
+
If there are any errors in the execution of queries, analyze the errors and provide insights on the issues.
|
39 |
+
{results}
|
40 |
+
"""
|
41 |
+
)
|
utils/query_generator.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .sql_runtime import SQLRuntime
|
2 |
+
from pydantic import BaseModel, Field
|
3 |
+
from .load_llm import load_llm
|
4 |
+
from .prompts import sql_query_prompt, sql_query_summary_prompt, sql_query_visualization_prompt
|
5 |
+
from langchain_core.runnables import chain
|
6 |
+
from typing import Optional
|
7 |
+
from dotenv import load_dotenv
|
8 |
+
|
9 |
+
class Generated_query(BaseModel):
|
10 |
+
"""
|
11 |
+
The SQL query to execute, make sure to use semicolon at the end of the query, do not execute harmful queries
|
12 |
+
"""
|
13 |
+
queries: list[str] = Field(description="List of SQL queries to execute, use title case for strings, make sure to use semicolon at the end of each query, do not execute harmful queries")
|
14 |
+
|
15 |
+
class QuerySummary(BaseModel):
|
16 |
+
"""
|
17 |
+
The summary of the SQL query results
|
18 |
+
"""
|
19 |
+
summary: str = Field(description="The analysis of the SQL query results")
|
20 |
+
errors: list[str] = Field(description="The errors in the execution of the queries")
|
21 |
+
queries: list[str] = Field(description="The SQL queries executed and their results")
|
22 |
+
|
23 |
+
@chain
|
24 |
+
def sql_generator(input: dict) -> Generated_query:
|
25 |
+
|
26 |
+
query, db_path = input["query"], input["db_path"]
|
27 |
+
|
28 |
+
sql_runtime = SQLRuntime(dbname=db_path)
|
29 |
+
|
30 |
+
query_generator_llm = load_llm().with_structured_output(Generated_query)
|
31 |
+
|
32 |
+
# getting the schemas
|
33 |
+
schemas = sql_runtime.get_schemas()
|
34 |
+
|
35 |
+
# chain to generate the queries
|
36 |
+
chain = sql_query_prompt | query_generator_llm
|
37 |
+
|
38 |
+
# executing the chain
|
39 |
+
gen_queries = chain.invoke({
|
40 |
+
"db_schema": schemas,
|
41 |
+
"input": query
|
42 |
+
})
|
43 |
+
|
44 |
+
# executing the queries
|
45 |
+
res = sql_runtime.execute_batch(gen_queries.queries)
|
46 |
+
|
47 |
+
# print(res)
|
48 |
+
|
49 |
+
return {
|
50 |
+
"input": query,
|
51 |
+
"results": res
|
52 |
+
}
|
53 |
+
|
54 |
+
@chain
|
55 |
+
def sql_formatter(input):
|
56 |
+
"""
|
57 |
+
Formats the output of the SQL queries
|
58 |
+
"""
|
59 |
+
output = []
|
60 |
+
for item in input["results"]:
|
61 |
+
if item["code"] == 0:
|
62 |
+
output.append(f"Query: {item['msg']['input']}, Result: {item['data']}")
|
63 |
+
else:
|
64 |
+
output.append(f"Query: {item['msg']['input']}, Error: {item['msg']['traceback']}")
|
65 |
+
|
66 |
+
# print(output)
|
67 |
+
|
68 |
+
return {
|
69 |
+
"query": input["input"],
|
70 |
+
"results": output
|
71 |
+
}
|
72 |
+
|
73 |
+
@chain
|
74 |
+
def analyze_results(input) -> QuerySummary:
|
75 |
+
"""
|
76 |
+
Analyzes the results of the SQL queries executed on the election database
|
77 |
+
"""
|
78 |
+
chain = sql_query_summary_prompt | load_llm().with_structured_output(QuerySummary)
|
79 |
+
|
80 |
+
# chain2 = sql_query_visualization_prompt | load_llm().with_structured_output(QuerySummary)
|
81 |
+
|
82 |
+
return chain.invoke({
|
83 |
+
"query": input["query"],
|
84 |
+
"results": input["results"]
|
85 |
+
})
|
86 |
+
|
87 |
+
if __name__ == '__main__':
|
88 |
+
load_dotenv()
|
89 |
+
# executing the queries
|
90 |
+
# results = sql_generator.invoke("Find the name of the candidate who got the maximum votes in Maharashtra elections 2019")
|
91 |
+
|
92 |
+
# for result in results:
|
93 |
+
# print(f"Query: {result['msg']['input']}")
|
94 |
+
# if result["code"] != 0:
|
95 |
+
# print(f"Error executing query: {result['msg']['reason']}")
|
96 |
+
# print(f"Traceback: {result['msg']['traceback']}")
|
97 |
+
# else:
|
98 |
+
# print(result["data"])
|
99 |
+
# print("\n")
|
100 |
+
|
101 |
+
# formatting the output
|
102 |
+
res = sql_generator | sql_formatter | analyze_results
|
103 |
+
|
104 |
+
formatted_output, formatted_output2 = res.invoke(
|
105 |
+
{
|
106 |
+
"query": "What are the different party symbols in Maharashtra elections 2019, create a list of all the symbols",
|
107 |
+
"db_path": "./data/elections.db"
|
108 |
+
}
|
109 |
+
)
|
110 |
+
print(formatted_output.summary)
|
111 |
+
print(formatted_output.errors)
|
112 |
+
print(formatted_output.queries)
|
113 |
+
|
114 |
+
print("\n")
|
115 |
+
|
116 |
+
print(formatted_output2.summary)
|
117 |
+
print(formatted_output2.errors)
|
118 |
+
print(formatted_output2.queries)
|
utils/react.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain import hub
|
2 |
+
from langchain.agents import AgentExecutor, create_react_agent
|
3 |
+
|
4 |
+
def run_agent_executor(agent_executor: AgentExecutor, input_data: dict):
|
5 |
+
for chunk in agent_executor.stream(input_data):
|
6 |
+
if "actions" in chunk:
|
7 |
+
for action in chunk["actions"]:
|
8 |
+
print(f"Calling Tool: `{action.tool}` with input `{action.tool_input}`")
|
9 |
+
# Observation
|
10 |
+
elif "steps" in chunk:
|
11 |
+
for step in chunk["steps"]:
|
12 |
+
print(f"Tool Result: `{step.observation}`")
|
13 |
+
# Final result
|
14 |
+
elif "output" in chunk:
|
15 |
+
print(f'Final Output: {chunk["output"]}')
|
16 |
+
else:
|
17 |
+
raise ValueError()
|
18 |
+
print("---")
|
utils/sql_runtime.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Runtime that accepts a sql statement and runs it on sql server.
|
3 |
+
Returns the results of sql execution.
|
4 |
+
"""
|
5 |
+
import traceback
|
6 |
+
import sqlite3
|
7 |
+
|
8 |
+
# MODIFY THE PATH BELOW FOR YOUR SYSTEM
|
9 |
+
my_db = r"../data/elections.db"
|
10 |
+
|
11 |
+
class SQLRuntime(object):
|
12 |
+
def __init__(self, dbname=None):
|
13 |
+
if dbname is None:
|
14 |
+
dbname = my_db
|
15 |
+
conn = sqlite3.connect(dbname) # creating a connection
|
16 |
+
self.cursor = conn.cursor() # we need the cursor to execute statement
|
17 |
+
return
|
18 |
+
|
19 |
+
def list_tables(self):
|
20 |
+
result = self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall()
|
21 |
+
table_names = sorted(list(zip(*result))[0])
|
22 |
+
return table_names
|
23 |
+
|
24 |
+
def get_schema_for_table(self, table_name):
|
25 |
+
result = self.cursor.execute("PRAGMA table_info('%s')" % table_name).fetchall()
|
26 |
+
column_names = list(zip(*result))[1]
|
27 |
+
return column_names
|
28 |
+
|
29 |
+
def get_schemas(self):
|
30 |
+
schemas = {}
|
31 |
+
table_names = self.list_tables()
|
32 |
+
for name in table_names:
|
33 |
+
fields = self.get_schema_for_table(name) # fields of the table name
|
34 |
+
schemas[name] = fields
|
35 |
+
return schemas
|
36 |
+
|
37 |
+
def execute(self, statement):
|
38 |
+
code = 0
|
39 |
+
msg = {
|
40 |
+
"text": "SUCCESS",
|
41 |
+
"reason": None,
|
42 |
+
"traceback": None,
|
43 |
+
}
|
44 |
+
data = None
|
45 |
+
|
46 |
+
try:
|
47 |
+
self.cursor.execute(statement)
|
48 |
+
except sqlite3.OperationalError:
|
49 |
+
code = -1
|
50 |
+
msg = {
|
51 |
+
"text": "ERROR: SQL execution error",
|
52 |
+
"reason": "possibly due to incorrect table/fields names",
|
53 |
+
"traceback": traceback.format_exc(),
|
54 |
+
}
|
55 |
+
|
56 |
+
if code == 0:
|
57 |
+
data = self.cursor.fetchall()
|
58 |
+
|
59 |
+
msg["input"] = statement
|
60 |
+
|
61 |
+
result = {
|
62 |
+
"code": code,
|
63 |
+
"msg": msg,
|
64 |
+
"data": data
|
65 |
+
}
|
66 |
+
|
67 |
+
return result
|
68 |
+
|
69 |
+
def execute_batch(self, queries):
|
70 |
+
results = []
|
71 |
+
for query in queries:
|
72 |
+
result = self.execute(query)
|
73 |
+
results.append(result)
|
74 |
+
return results
|
75 |
+
|
76 |
+
def post_process(self, data):
|
77 |
+
"""
|
78 |
+
post process the data so that we can identify any harmful code and remove them.
|
79 |
+
Also, llm output may need an output parser.
|
80 |
+
:param data:
|
81 |
+
:return:
|
82 |
+
"""
|
83 |
+
# IMPLEMENT YOUR CODE HERE FOR POST-PROCESSING and VALIDATION
|
84 |
+
return data
|
85 |
+
|
86 |
+
|
87 |
+
def sql_runtime(statement):
|
88 |
+
"""
|
89 |
+
Instantiates a sql runtime and executes the given sql statement
|
90 |
+
:param statement: sql statement
|
91 |
+
"""
|
92 |
+
SQL = SQLRuntime()
|
93 |
+
data = SQL.execute(statement)
|
94 |
+
return data
|
95 |
+
|
96 |
+
|
97 |
+
if __name__ == '__main__':
|
98 |
+
# stmt = """
|
99 |
+
# SELECT * FROM elections_2019;
|
100 |
+
# """
|
101 |
+
# stmt = input("Enter stmt: ")
|
102 |
+
sql = SQLRuntime()
|
103 |
+
|
104 |
+
tables = sql.list_tables()
|
105 |
+
print(tables)
|
106 |
+
|
107 |
+
schemas = {}
|
108 |
+
for table in tables:
|
109 |
+
schemas[table] = sql.get_schema_for_table(table)
|
110 |
+
print(f"Table: {table}, Schema: {schemas[table]}\n")
|
111 |
+
|
112 |
+
# data1 = sql.execute(stmt)
|
113 |
+
|
114 |
+
# dat = data1["data"]
|
115 |
+
# if dat is not None and len(dat) > 0:
|
116 |
+
# for record in dat:
|
117 |
+
# print(record)
|
118 |
+
# print("-" * 100)
|
119 |
+
|
120 |
+
# sample question: find out the votes polled by NOTA for each instance of Akkalkuwa in the parliamentary elections 2019.
|
121 |
+
stmt = """
|
122 |
+
SELECT party_name, SUM(nota_votes)
|
123 |
+
FROM elections_2019
|
124 |
+
WHERE constituency='Akkalkuwa'
|
125 |
+
GROUP BY party_name;
|
126 |
+
"""
|
127 |
+
|
128 |
+
data1 = sql.execute(stmt)
|
129 |
+
|
130 |
+
# print(data1)
|
131 |
+
|
132 |
+
dat = data1["data"]
|
133 |
+
if dat is not None and len(dat) > 0:
|
134 |
+
for record in dat:
|
135 |
+
print(record)
|
136 |
+
print("-" * 100)
|
utils/tools.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Dict, Any, Optional, Type
|
2 |
+
from langchain_core.tools import BaseTool
|
3 |
+
from pydantic import BaseModel, Field
|
4 |
+
import pandas as pd
|
5 |
+
from .sql_runtime import SQLRuntime
|
6 |
+
from langchain_core.output_parsers import StrOutputParser
|
7 |
+
from langchain_core.prompts import ChatPromptTemplate
|
8 |
+
from .load_llm import load_llm
|
9 |
+
from langchain_core.messages import SystemMessage
|
10 |
+
from langchain_core.prompts import HumanMessagePromptTemplate
|
11 |
+
from langchain.agents import AgentExecutor, create_react_agent
|
12 |
+
from dotenv import load_dotenv
|
13 |
+
from react import run_agent_executor
|
14 |
+
from prompts import react_prompt
|
15 |
+
|
16 |
+
# definig the input schema
|
17 |
+
class QueryInput(BaseModel):
|
18 |
+
query: str = Field(..., description="The SQL query to execute, make sure to use semicolon at the end of the query, do not execute harmful queries")
|
19 |
+
|
20 |
+
class TableNameInput(BaseModel):
|
21 |
+
table_name: str = Field(..., description="The name of the table to analyze")
|
22 |
+
|
23 |
+
class ColumnSearchInput(BaseModel):
|
24 |
+
table_name: str = Field(..., description="The name of the table to search")
|
25 |
+
column_name: str = Field(..., description="The name of the column to search")
|
26 |
+
limit: int = Field(default=10, description="Maximum number of distinct values to return")
|
27 |
+
|
28 |
+
class SQLQueryTool(BaseTool):
|
29 |
+
name: str = "sql_query"
|
30 |
+
description: str = """
|
31 |
+
Execute a SQL query and return the results.
|
32 |
+
Use this when you need to run a specific SQL query on the elections database.
|
33 |
+
The query should be a valid SQL statement and should end with a semicolon.
|
34 |
+
There should be no harmful queries executed.
|
35 |
+
There are three tables in the database: elections_2019, elections_2024, maha_2019
|
36 |
+
"""
|
37 |
+
args_schema: Type[BaseModel] = QueryInput
|
38 |
+
|
39 |
+
# def __init__(self, db_path: Optional[str] = None):
|
40 |
+
# super().__init__()
|
41 |
+
# self.
|
42 |
+
|
43 |
+
def _run(self, query: str) -> str:
|
44 |
+
sql_runtime = SQLRuntime('../data/elections.db')
|
45 |
+
try:
|
46 |
+
result = sql_runtime.execute(query)
|
47 |
+
if result["code"] != 0:
|
48 |
+
return f"Error executing query: {result['msg']['reason']}"
|
49 |
+
|
50 |
+
# Convert to DataFrame for nice string representation
|
51 |
+
df = pd.DataFrame(result["data"])
|
52 |
+
if not df.empty:
|
53 |
+
return df.to_string()
|
54 |
+
return "Query returned no results"
|
55 |
+
|
56 |
+
except Exception as e:
|
57 |
+
return f"Error: {str(e)}"
|
58 |
+
|
59 |
+
class TableInfoTool(BaseTool):
|
60 |
+
name: str = "get_table_info"
|
61 |
+
description: str = """
|
62 |
+
Get information about a specific table including its schema and basic statistics.
|
63 |
+
Use this when you need to understand the structure of a table or get basic statistics about it.
|
64 |
+
"""
|
65 |
+
args_schema: Type[BaseModel] = TableNameInput
|
66 |
+
|
67 |
+
# def __init__(self, db_path: Optional[str] = None):
|
68 |
+
# super().__init__()
|
69 |
+
|
70 |
+
|
71 |
+
def _run(self, table_name: str) -> str:
|
72 |
+
sql_runtime = SQLRuntime('../data/elections.db')
|
73 |
+
try:
|
74 |
+
# Get schema
|
75 |
+
schema = sql_runtime.get_schema_for_table(table_name)
|
76 |
+
|
77 |
+
# Get row count
|
78 |
+
count_query = f"SELECT COUNT(*) FROM {table_name}"
|
79 |
+
count_result = sql_runtime.execute(count_query)
|
80 |
+
row_count = count_result["data"][0][0] if count_result["code"] == 0 else "Error"
|
81 |
+
|
82 |
+
# Get sample data
|
83 |
+
sample_query = f"SELECT * FROM {table_name} LIMIT 3"
|
84 |
+
sample_result = sql_runtime.execute(sample_query)
|
85 |
+
|
86 |
+
info = f"""
|
87 |
+
Table: {table_name}
|
88 |
+
Columns: {', '.join(schema)}
|
89 |
+
Row Count: {row_count}
|
90 |
+
Sample Data:
|
91 |
+
{pd.DataFrame(sample_result['data'], columns=schema).to_string() if sample_result['code'] == 0 else 'Error getting sample data'}
|
92 |
+
"""
|
93 |
+
return info
|
94 |
+
except Exception as e:
|
95 |
+
return f"Error getting table info: {str(e)}"
|
96 |
+
|
97 |
+
class ColumnValuesTool(BaseTool):
|
98 |
+
name: str = "find_column_values"
|
99 |
+
description: str = """
|
100 |
+
Find distinct values in a specific column of a table.
|
101 |
+
Use this when you need to know what unique values exist in a particular column.
|
102 |
+
"""
|
103 |
+
args_schema: Type[BaseModel] = ColumnSearchInput
|
104 |
+
|
105 |
+
# def __init__(self, db_path: Optional[str] = None):
|
106 |
+
# super().__init__()
|
107 |
+
# self.sql_runtime = SQLRuntime(db_path)
|
108 |
+
|
109 |
+
def _run(self, table_name: str, column_name: str, limit: int = 10) -> str:
|
110 |
+
sql_runtime = SQLRuntime('../data/elections.db')
|
111 |
+
try:
|
112 |
+
query = f"""
|
113 |
+
SELECT DISTINCT {column_name}
|
114 |
+
FROM {table_name}
|
115 |
+
LIMIT {limit}
|
116 |
+
"""
|
117 |
+
result = sql_runtime.execute(query)
|
118 |
+
if result["code"] != 0:
|
119 |
+
return f"Error finding values: {result['msg']['reason']}"
|
120 |
+
|
121 |
+
values = [row[0] for row in result["data"]]
|
122 |
+
return f"Distinct values in {column_name}: {', '.join(map(str, values))}"
|
123 |
+
except Exception as e:
|
124 |
+
return f"Error: {str(e)}"
|
125 |
+
|
126 |
+
class ListTablesTool(BaseTool):
|
127 |
+
name: str = "list_tables"
|
128 |
+
description: str = """
|
129 |
+
List all available tables in the database.
|
130 |
+
Use this when you need to know what tables are available to query.
|
131 |
+
"""
|
132 |
+
|
133 |
+
# def __init__(self, db_path: Optional[str] = None):
|
134 |
+
# super().__init__()
|
135 |
+
# self.sql_runtime = SQLRuntime(db_path)
|
136 |
+
|
137 |
+
def _run(self, *args, **kwargs) -> str:
|
138 |
+
sql_runtime = SQLRuntime('../data/elections.db')
|
139 |
+
try:
|
140 |
+
tables = sql_runtime.list_tables()
|
141 |
+
return f"Available tables: {', '.join(tables)}"
|
142 |
+
except Exception as e:
|
143 |
+
return f"Error listing tables: {str(e)}"
|
144 |
+
|
145 |
+
def create_sql_agent_tools(db_path: Optional[str] = '../data/elections.db') -> List[BaseTool]:
|
146 |
+
"""
|
147 |
+
Create a list of all SQL tools for use with a Langchain agent.
|
148 |
+
"""
|
149 |
+
return [
|
150 |
+
SQLQueryTool(),
|
151 |
+
TableInfoTool(),
|
152 |
+
# ColumnValuesTool(),
|
153 |
+
ListTablesTool()
|
154 |
+
]
|
155 |
+
|
156 |
+
if __name__ == "__main__":
|
157 |
+
load_dotenv()
|
158 |
+
tools = create_sql_agent_tools()
|
159 |
+
for tool in tools:
|
160 |
+
print(f"Tool: {tool.name}")
|
161 |
+
print(f"Description: {tool.description}")
|
162 |
+
# print(f"Args Schema: {tool.args_schema.schema()}")
|
163 |
+
|
164 |
+
|
165 |
+
# prompt = prompt = ChatPromptTemplate.from_messages(
|
166 |
+
# [
|
167 |
+
# SystemMessage(
|
168 |
+
# content="""
|
169 |
+
# You are a sql agent who has access to a database with three tables: elections_2019, elections_2024, maha_2019.
|
170 |
+
# You can use the following tools:
|
171 |
+
# - sql_query: Execute a SQL query and return the results.
|
172 |
+
# - get_table_info: Get information about a specific table including its schema and basic statistics.
|
173 |
+
# - find_column_values: Find distinct values in a specific column of a table.
|
174 |
+
# - list_tables: List all available tables in the database.
|
175 |
+
|
176 |
+
# Answer the questions using the tools provided. Do not execute harmful queries.
|
177 |
+
# """
|
178 |
+
# ),
|
179 |
+
# HumanMessagePromptTemplate.from_template("{text}"),
|
180 |
+
# ]
|
181 |
+
# )
|
182 |
+
|
183 |
+
|
184 |
+
output_parser = StrOutputParser()
|
185 |
+
|
186 |
+
# Create the llm
|
187 |
+
llm = load_llm()
|
188 |
+
|
189 |
+
# llm.bind_tools(tools)
|
190 |
+
|
191 |
+
# res = llm.invoke("who won elections in maharashtra in Nandurbar in elections 2019? use the given tools")
|
192 |
+
|
193 |
+
# chain = prompt | llm | output_parser
|
194 |
+
|
195 |
+
# Run the chain
|
196 |
+
agent = create_react_agent(llm, tools, react_prompt)
|
197 |
+
# Create an agent executor by passing in the agent and tools
|
198 |
+
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
|
199 |
+
|
200 |
+
print("Agent created successfully")
|
201 |
+
|
202 |
+
# Run the agent
|
203 |
+
# agent_executor.invoke({"input": "Who won the elections in 2019 for the state maharashtra in constituency Akkalkuwa?"})
|
204 |
+
|
205 |
+
res = agent_executor.invoke({"input": "who won elections in maharashtra in Nandurbar in elections 2019?"})
|
206 |
+
|
207 |
+
# run_agent_executor(agent_executor, {"input": "who won elections in maharashtra in Nandurbar in elections 2019?"})
|