wip
Browse files
climateqa/engine/talk_to_data/main.py
CHANGED
@@ -24,29 +24,7 @@ vn.connect_to_sqlite(db_vanna_path)
|
|
24 |
|
25 |
llm = get_llm(provider="openai")
|
26 |
|
27 |
-
# def ask_vanna(query):
|
28 |
-
# location = detect_location_with_openai(OPENAI_API_KEY, query)
|
29 |
-
# if location:
|
30 |
-
# coords = loc2coords(location)
|
31 |
-
# user_input = query.replace(location, f"lat, long : {coords}")
|
32 |
-
# answer = vn.ask(user_input, print_results=False, allow_llm_to_see_data=True)
|
33 |
-
# table = detectTable(answer[0])
|
34 |
-
# coords2 = nearestNeighbourSQL(db_vanna_path, coords, table[0])
|
35 |
|
36 |
-
# query = answer[0].replace(f"{coords[0]}", f"{coords2[0]}")
|
37 |
-
# sql_query = query.replace(f"{coords[1]}", f"{coords2[1]}")
|
38 |
-
|
39 |
-
# db = sqlite3.connect(db_vanna_path)
|
40 |
-
# result = db.cursor().execute(sql_query).fetchall()
|
41 |
-
# print(result)
|
42 |
-
# df = pd.DataFrame(result, columns=answer[1].columns)
|
43 |
-
|
44 |
-
# else:
|
45 |
-
# answer = vn.ask(query, visualize=True, print_results=False, allow_llm_to_see_data=True)
|
46 |
-
# sql_query = answer[0]
|
47 |
-
# df = answer[1]
|
48 |
-
|
49 |
-
# return (sql_query, df)
|
50 |
def replace_coordonates(coords, sql_query, coords_tables):
|
51 |
n = sql_query.count(str(coords[0]))
|
52 |
sql_query_new_coords = sql_query
|
@@ -57,34 +35,40 @@ def replace_coordonates(coords, sql_query, coords_tables):
|
|
57 |
return sql_query_new_coords
|
58 |
|
59 |
def ask_vanna(query):
|
60 |
-
|
61 |
-
|
|
|
62 |
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
|
84 |
-
|
85 |
|
86 |
-
|
87 |
-
|
|
|
|
|
|
|
|
|
|
|
88 |
empty_df = pd.DataFrame()
|
89 |
empty_fig = {}
|
90 |
-
return empty_df, empty_fig
|
|
|
24 |
|
25 |
llm = get_llm(provider="openai")
|
26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
def replace_coordonates(coords, sql_query, coords_tables):
|
29 |
n = sql_query.count(str(coords[0]))
|
30 |
sql_query_new_coords = sql_query
|
|
|
35 |
return sql_query_new_coords
|
36 |
|
37 |
def ask_vanna(query):
    """Answer a location-based question with a DataFrame and a Plotly figure.

    Steps, in order:
      1. Ask ``detect_location_with_openai`` for a location mentioned in
         ``query``. If none is found, return an empty result.
      2. Convert the location with ``loc2coords`` and substitute the
         coordinates into the question text.
      3. Let ``vn.ask`` generate a SQL query for the rewritten question.
      4. For every table reported by ``detectTable``, look up coordinates with
         ``nearestNeighbourSQL`` and rewrite the query through
         ``replace_coordonates``.
      5. Execute the rewritten query on the SQLite database at
         ``db_vanna_path`` and load the rows into a DataFrame.
      6. Ask Vanna for Plotly code and build the figure from it.

    Args:
        query: The user's natural-language question.

    Returns:
        A ``(df, fig)`` tuple. On success ``df`` holds the query result and
        ``fig`` is whatever ``vn.get_plotly_figure`` returns. When no location
        is detected, or when any step raises, the tuple is
        ``(pd.DataFrame(), {})``.
    """
    try:
        location = detect_location_with_openai(OPENAI_API_KEY, query)
        if not location:
            return pd.DataFrame(), {}

        coords = loc2coords(location)
        user_input = query.replace(location, f"lat, long : {coords}")

        # Only the generated SQL is used here; the DataFrame and figure that
        # vn.ask also returns are discarded because the query is re-executed
        # below with adjusted coordinates.
        sql_query, _, _ = vn.ask(user_input, print_results=False, allow_llm_to_see_data=True)

        tables = detectTable(sql_query)
        # presumably the closest coordinates actually stored in each table --
        # TODO confirm against nearestNeighbourSQL.
        coords_tables = [nearestNeighbourSQL(db_vanna_path, coords, table) for table in tables]
        sql_query_new_coords = replace_coordonates(coords, sql_query, coords_tables)

        sql_with_table_names = llm.invoke(f"Make the following sql query display the source table in the rows {sql_query_new_coords}. Just answer the query. The answer should not include ```sql\n").content
        print("execute sql query : ", sql_with_table_names)

        db = sqlite3.connect(db_vanna_path)
        try:
            cursor = db.cursor().execute(sql_query_new_coords)
            result = cursor.fetchall()
            # Take the column names from the executed statement itself rather
            # than asking the LLM to guess them: they are exact and always
            # match the width of the fetched rows. ``description`` is None for
            # statements that return no result set.
            columns_list = [column[0] for column in (cursor.description or ())]
        finally:
            # Always release the connection, even if the query fails.
            db.close()

        print("column list : ", columns_list)
        df = pd.DataFrame(result, columns=columns_list)

        # Pass the actual question and SQL text (previously the literal
        # strings "query" and "sql_with_table_names" were sent).
        # NOTE(review): df is produced by sql_query_new_coords, while the SQL
        # described to Vanna is the LLM-rewritten sql_with_table_names --
        # confirm which of the two should be executed/described.
        plotly_code = vn.generate_plotly_code(
            question=query,
            sql=sql_with_table_names,
            df_metadata=f"Running df.dtypes gives:\n {df.dtypes}",
        )

        fig = vn.get_plotly_figure(plotly_code=plotly_code, df=df)

        return df, fig
    except Exception as e:
        # Top-level boundary for the UI: report the failure and fall back to
        # an empty result instead of propagating.
        print(f"Error: {e}")
        return pd.DataFrame(), {}
|
climateqa/engine/talk_to_data/test_vanna.ipynb
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import sys\n",
|
10 |
+
"import os\n",
|
11 |
+
"sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.getcwd()))))\n",
|
12 |
+
"\n",
|
13 |
+
"%load_ext autoreload\n",
|
14 |
+
"%autoreload 2\n",
|
15 |
+
"\n",
|
16 |
+
"from main import ask_vanna\n",
|
17 |
+
"import sqlite3\n",
|
18 |
+
"import os\n",
|
19 |
+
"import pandas as pd"
|
20 |
+
]
|
21 |
+
},
|
22 |
+
{
|
23 |
+
"cell_type": "code",
|
24 |
+
"execution_count": null,
|
25 |
+
"metadata": {},
|
26 |
+
"outputs": [],
|
27 |
+
"source": [
|
28 |
+
"table_names_list = [\n",
|
29 |
+
" \"Frequency_of_rainy_days_index\",\n",
|
30 |
+
" \"Winter_precipitation_total\",\n",
|
31 |
+
" \"Summer_precipitation_total\",\n",
|
32 |
+
" \"Annual_precipitation_total\",\n",
|
33 |
+
" \"Remarkable_daily_precipitation_total_(Q99)\",\n",
|
34 |
+
" \"Frequency_of_remarkable_daily_precipitation\",\n",
|
35 |
+
" \"Extreme_precipitation_intensity\",\n",
|
36 |
+
" \"Mean_winter_temperature\",\n",
|
37 |
+
" \"Mean_summer_temperature\",\n",
|
38 |
+
" \"Number_of_tropical_nights\",\n",
|
39 |
+
" \"Maximum_summer_temperature\",\n",
|
40 |
+
" \"Number_of_days_with_Tx_above_30C\",\n",
|
41 |
+
" \"Number_of_days_with_Tx_above_35C\",\n",
|
42 |
+
" \"Drought_index\"\n",
|
43 |
+
"]"
|
44 |
+
]
|
45 |
+
},
|
46 |
+
{
|
47 |
+
"cell_type": "code",
|
48 |
+
"execution_count": null,
|
49 |
+
"metadata": {},
|
50 |
+
"outputs": [],
|
51 |
+
"source": [
|
52 |
+
"from climateqa.engine.llm import get_llm\n",
|
53 |
+
"\n",
|
54 |
+
"llm = get_llm(provider=\"openai\")\n",
|
55 |
+
"user_question = \"Quel sera la température à Marseille dans les prochaines années ?\"\n",
|
56 |
+
"prompt = f\"You are helping to build a sql query to retrieve relevant data for a user question. The different tables are {table_names_list}. The user question is {user_question}. Write the relevant table to query. Answer only the table name.\"\n",
|
57 |
+
"table_name = llm.invoke(prompt).content\n",
|
58 |
+
"# llm.invoke(f\"Make the following sql query display the source table in the rows {sql_query_new_coords}. Just answer the query. The answer should not include ```sql\\n\").content\n"
|
59 |
+
]
|
60 |
+
},
|
61 |
+
{
|
62 |
+
"cell_type": "code",
|
63 |
+
"execution_count": null,
|
64 |
+
"metadata": {},
|
65 |
+
"outputs": [],
|
66 |
+
"source": [
|
67 |
+
"docs = {\"Mean_summer_temperature\": {\n",
|
68 |
+
" \"description\": (\n",
|
69 |
+
" \"The Mean summer temperature table contains information on the average summer temperature in the past and the future. \"\n",
|
70 |
+
" \"The variables are as follows:\\n\"\n",
|
71 |
+
" \"- 'y' and 'x': Lambert Paris II coordinates for the location.\\n\"\n",
|
72 |
+
" \"- year: Year of the observation.\\n\"\n",
|
73 |
+
" \"- month : Month of the observation.\\n\"\n",
|
74 |
+
" \"- day: Day of the observation.\\n\"\n",
|
75 |
+
" \"- 'LambertParisII': Indicates that the x and y coordinates are in Lambert Paris II projection.\\n\"\n",
|
76 |
+
" \"- 'lat' and 'lon': Latitude and longitude of the location.\\n\"\n",
|
77 |
+
" \"- 'TMm': Average summer temperature.\\n\"\n",
|
78 |
+
" ),\n",
|
79 |
+
" \"sql_query\": \"\"\"\n",
|
80 |
+
" CREATE TABLE Mean_summer_temperature (\n",
|
81 |
+
" y FLOAT,\n",
|
82 |
+
" x FLOAT,\n",
|
83 |
+
" year INT,\n",
|
84 |
+
" month INT, \n",
|
85 |
+
" day INT,\n",
|
86 |
+
" LambertParisII VARCHAR(255),\n",
|
87 |
+
" lat FLOAT,\n",
|
88 |
+
" lon FLOAT,\n",
|
89 |
+
" TMm FLOAT, -- Température moyenne en été\n",
|
90 |
+
" );\n",
|
91 |
+
" \"\"\"}\n",
|
92 |
+
"}"
|
93 |
+
]
|
94 |
+
},
|
95 |
+
{
|
96 |
+
"cell_type": "code",
|
97 |
+
"execution_count": null,
|
98 |
+
"metadata": {},
|
99 |
+
"outputs": [],
|
100 |
+
"source": [
|
101 |
+
"from climateqa.engine.talk_to_data.utils import loc2coords\n",
|
102 |
+
"location = \"Marseille\"\n",
|
103 |
+
"coords = loc2coords(location)\n",
|
104 |
+
"user_input = user_question.replace(location, f\"lat, long : {coords}\")\n"
|
105 |
+
]
|
106 |
+
},
|
107 |
+
{
|
108 |
+
"cell_type": "code",
|
109 |
+
"execution_count": null,
|
110 |
+
"metadata": {},
|
111 |
+
"outputs": [],
|
112 |
+
"source": [
|
113 |
+
"initial_prompt = f\"You are a mysql expert. \" + \\\n",
|
114 |
+
" \"Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. \"\n",
|
115 |
+
"initial_prompt += f\"\\n===Tables \\n + {docs[table_name]['sql_query']}\"\n",
|
116 |
+
"initial_prompt += f\"\\n===Additional Context \\n\\n {docs[table_name]['description']}\"\n",
|
117 |
+
"initial_prompt += (\n",
|
118 |
+
" \"===Response Guidelines \\n\"\n",
|
119 |
+
" \"1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \\n\"\n",
|
120 |
+
" \"2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \\n\"\n",
|
121 |
+
" \"3. If the provided context is insufficient, please give a sql query based on your knowledge and the context provided. \\n\"\n",
|
122 |
+
" \"4. Please use the most relevant table(s). \\n\"\n",
|
123 |
+
" \"5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \\n\"\n",
|
124 |
+
" f\"6. Ensure that the output SQL is mysql-compliant and executable, and free of syntax errors. \\n\"\n",
|
125 |
+
" )\n",
|
126 |
+
"initial_prompt += f\"\\n===Question \\n\\n {user_input}\"\n"
|
127 |
+
]
|
128 |
+
},
|
129 |
+
{
|
130 |
+
"cell_type": "code",
|
131 |
+
"execution_count": null,
|
132 |
+
"metadata": {},
|
133 |
+
"outputs": [],
|
134 |
+
"source": [
|
135 |
+
"sql_query = llm.invoke(initial_prompt).content\n",
|
136 |
+
"sql_query"
|
137 |
+
]
|
138 |
+
},
|
139 |
+
{
|
140 |
+
"cell_type": "markdown",
|
141 |
+
"metadata": {},
|
142 |
+
"source": [
|
143 |
+
"# Vanna ask"
|
144 |
+
]
|
145 |
+
},
|
146 |
+
{
|
147 |
+
"cell_type": "code",
|
148 |
+
"execution_count": null,
|
149 |
+
"metadata": {},
|
150 |
+
"outputs": [],
|
151 |
+
"source": [
|
152 |
+
"from climateqa.engine.llm import get_llm\n",
|
153 |
+
"import ast\n",
|
154 |
+
"\n",
|
155 |
+
"llm = get_llm(provider=\"openai\")\n",
|
156 |
+
"columns = llm.invoke(f\"From the given sql query, list the columns that are being selected. The answer should only be a python list. Just answer the list. The SQL query SELECT 'Mean_winter_temperature' AS source_table, year, month, day, TMm FROM Mean_winter_temperature WHERE lat = 43.166954040527344 AND lon = 5.430534839630127;\").content\n",
|
157 |
+
"columns_list = ast.literal_eval(columns.strip(\"```python\\n\").strip())\n"
|
158 |
+
]
|
159 |
+
}
|
160 |
+
],
|
161 |
+
"metadata": {
|
162 |
+
"kernelspec": {
|
163 |
+
"display_name": "climateqa",
|
164 |
+
"language": "python",
|
165 |
+
"name": "python3"
|
166 |
+
},
|
167 |
+
"language_info": {
|
168 |
+
"name": "python",
|
169 |
+
"version": "3.11.9"
|
170 |
+
}
|
171 |
+
},
|
172 |
+
"nbformat": 4,
|
173 |
+
"nbformat_minor": 2
|
174 |
+
}
|