timeki commited on
Commit
bbfd1ce
·
1 Parent(s): 7287f1d
climateqa/engine/talk_to_data/main.py CHANGED
@@ -24,29 +24,7 @@ vn.connect_to_sqlite(db_vanna_path)
24
 
25
  llm = get_llm(provider="openai")
26
 
27
- # def ask_vanna(query):
28
- # location = detect_location_with_openai(OPENAI_API_KEY, query)
29
- # if location:
30
- # coords = loc2coords(location)
31
- # user_input = query.replace(location, f"lat, long : {coords}")
32
- # answer = vn.ask(user_input, print_results=False, allow_llm_to_see_data=True)
33
- # table = detectTable(answer[0])
34
- # coords2 = nearestNeighbourSQL(db_vanna_path, coords, table[0])
35
 
36
- # query = answer[0].replace(f"{coords[0]}", f"{coords2[0]}")
37
- # sql_query = query.replace(f"{coords[1]}", f"{coords2[1]}")
38
-
39
- # db = sqlite3.connect(db_vanna_path)
40
- # result = db.cursor().execute(sql_query).fetchall()
41
- # print(result)
42
- # df = pd.DataFrame(result, columns=answer[1].columns)
43
-
44
- # else:
45
- # answer = vn.ask(query, visualize=True, print_results=False, allow_llm_to_see_data=True)
46
- # sql_query = answer[0]
47
- # df = answer[1]
48
-
49
- # return (sql_query, df)
50
  def replace_coordonates(coords, sql_query, coords_tables):
51
  n = sql_query.count(str(coords[0]))
52
  sql_query_new_coords = sql_query
@@ -57,34 +35,40 @@ def replace_coordonates(coords, sql_query, coords_tables):
57
  return sql_query_new_coords
58
 
59
  def ask_vanna(query):
60
- location = detect_location_with_openai(OPENAI_API_KEY, query)
61
- if location:
 
62
 
63
- coords = loc2coords(location)
64
- user_input = query.replace(location, f"lat, long : {coords}")
65
- sql_query, result_dataframe, figure = vn.ask(user_input, print_results=False, allow_llm_to_see_data=True)
66
- table = detectTable(sql_query)
67
- coords_tables = [nearestNeighbourSQL(db_vanna_path, coords, table[i]) for i in range(len(table))]
68
- sql_query_new_coords = replace_coordonates(coords, sql_query, coords_tables)
69
- sql_with_table_names = llm.invoke(f"Make the following sql query display the source table in the rows {sql_query_new_coords}. Just answer the query. The answer should not include ```sql\n").content
70
- print("execute sql query : ", sql_with_table_names)
71
- db = sqlite3.connect(db_vanna_path)
72
- result = db.cursor().execute(sql_query_new_coords).fetchall()
73
- columns = llm.invoke(f"From the given sql query, list the columns that are being selected. The answer should only be a python list. Just answer the list. The SQL query : {sql_query_new_coords}").content
74
- columns_list = ast.literal_eval(columns.strip("```python\n").strip())
75
- print("column list : ",columns_list)
76
- df = pd.DataFrame(result, columns=columns_list)
77
-
78
- plotly_code = vn.generate_plotly_code(
79
- question="query",
80
- sql="sql_with_table_names",
81
- df_metadata=f"Running df.dtypes gives:\n {df.dtypes}",
82
- )
83
 
84
- fig = vn.get_plotly_figure(plotly_code=plotly_code, df=df)
85
 
86
- return df, fig
87
- else :
 
 
 
 
 
88
  empty_df = pd.DataFrame()
89
  empty_fig = {}
90
- return empty_df, empty_fig
 
24
 
25
  llm = get_llm(provider="openai")
26
 
 
 
 
 
 
 
 
 
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def replace_coordonates(coords, sql_query, coords_tables):
29
  n = sql_query.count(str(coords[0]))
30
  sql_query_new_coords = sql_query
 
35
  return sql_query_new_coords
36
 
37
  def ask_vanna(query):
38
+ try :
39
+ location = detect_location_with_openai(OPENAI_API_KEY, query)
40
+ if location:
41
 
42
+ coords = loc2coords(location)
43
+ user_input = query.replace(location, f"lat, long : {coords}")
44
+ sql_query, result_dataframe, figure = vn.ask(user_input, print_results=False, allow_llm_to_see_data=True)
45
+ table = detectTable(sql_query)
46
+ coords_tables = [nearestNeighbourSQL(db_vanna_path, coords, table[i]) for i in range(len(table))]
47
+ sql_query_new_coords = replace_coordonates(coords, sql_query, coords_tables)
48
+ sql_with_table_names = llm.invoke(f"Make the following sql query display the source table in the rows {sql_query_new_coords}. Just answer the query. The answer should not include ```sql\n").content
49
+ print("execute sql query : ", sql_with_table_names)
50
+ db = sqlite3.connect(db_vanna_path)
51
+ result = db.cursor().execute(sql_query_new_coords).fetchall()
52
+ columns = llm.invoke(f"From the given sql query, list the columns that are being selected. The answer should only be a python list. Just answer the list. The SQL query : {sql_query_new_coords}").content
53
+ columns_list = ast.literal_eval(columns.strip("```python\n").strip())
54
+ print("column list : ",columns_list)
55
+ df = pd.DataFrame(result, columns=columns_list)
56
+
57
+ plotly_code = vn.generate_plotly_code(
58
+ question="query",
59
+ sql="sql_with_table_names",
60
+ df_metadata=f"Running df.dtypes gives:\n {df.dtypes}",
61
+ )
62
 
63
+ fig = vn.get_plotly_figure(plotly_code=plotly_code, df=df)
64
 
65
+ return df, fig
66
+ else :
67
+ empty_df = pd.DataFrame()
68
+ empty_fig = {}
69
+ return empty_df, empty_fig
70
+ except Exception as e:
71
+ print(f"Error: {e}")
72
  empty_df = pd.DataFrame()
73
  empty_fig = {}
74
+ return empty_df, empty_fig
climateqa/engine/talk_to_data/test_vanna.ipynb ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import sys\n",
10
+ "import os\n",
11
+ "sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.getcwd()))))\n",
12
+ "\n",
13
+ "%load_ext autoreload\n",
14
+ "%autoreload 2\n",
15
+ "\n",
16
+ "from main import ask_vanna\n",
17
+ "import sqlite3\n",
18
+ "import os\n",
19
+ "import pandas as pd"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": null,
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "table_names_list = [\n",
29
+ " \"Frequency_of_rainy_days_index\",\n",
30
+ " \"Winter_precipitation_total\",\n",
31
+ " \"Summer_precipitation_total\",\n",
32
+ " \"Annual_precipitation_total\",\n",
33
+ " \"Remarkable_daily_precipitation_total_(Q99)\",\n",
34
+ " \"Frequency_of_remarkable_daily_precipitation\",\n",
35
+ " \"Extreme_precipitation_intensity\",\n",
36
+ " \"Mean_winter_temperature\",\n",
37
+ " \"Mean_summer_temperature\",\n",
38
+ " \"Number_of_tropical_nights\",\n",
39
+ " \"Maximum_summer_temperature\",\n",
40
+ " \"Number_of_days_with_Tx_above_30C\",\n",
41
+ " \"Number_of_days_with_Tx_above_35C\",\n",
42
+ " \"Drought_index\"\n",
43
+ "]"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "execution_count": null,
49
+ "metadata": {},
50
+ "outputs": [],
51
+ "source": [
52
+ "from climateqa.engine.llm import get_llm\n",
53
+ "\n",
54
+ "llm = get_llm(provider=\"openai\")\n",
55
+ "user_question = \"Quel sera la température à Marseille dans les prochaines années ?\"\n",
56
+ "prompt = f\"You are helping to build a sql query to retrieve relevant data for a user question. The different tables are {table_names_list}. The user question is {user_question}. Write the relevant table to query. Answer only the table name.\"\n",
57
+ "table_name = llm.invoke(prompt).content\n",
58
+ "# llm.invoke(f\"Make the following sql query display the source table in the rows {sql_query_new_coords}. Just answer the query. The answer should not include ```sql\\n\").content\n"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": null,
64
+ "metadata": {},
65
+ "outputs": [],
66
+ "source": [
67
+ "docs = {\"Mean_summer_temperature\": {\n",
68
+ " \"description\": (\n",
69
+ " \"The Mean summer temperature table contains information on the average summer temperature in the past and the future. \"\n",
70
+ " \"The variables are as follows:\\n\"\n",
71
+ " \"- 'y' and 'x': Lambert Paris II coordinates for the location.\\n\"\n",
72
+ " \"- year: Year of the observation.\\n\"\n",
73
+ " \"- month : Month of the observation.\\n\"\n",
74
+ " \"- day: Day of the observation.\\n\"\n",
75
+ " \"- 'LambertParisII': Indicates that the x and y coordinates are in Lambert Paris II projection.\\n\"\n",
76
+ " \"- 'lat' and 'lon': Latitude and longitude of the location.\\n\"\n",
77
+ " \"- 'TMm': Average summer temperature.\\n\"\n",
78
+ " ),\n",
79
+ " \"sql_query\": \"\"\"\n",
80
+ " CREATE TABLE Mean_summer_temperature (\n",
81
+ " y FLOAT,\n",
82
+ " x FLOAT,\n",
83
+ " year INT,\n",
84
+ " month INT, \n",
85
+ " day INT,\n",
86
+ " LambertParisII VARCHAR(255),\n",
87
+ " lat FLOAT,\n",
88
+ " lon FLOAT,\n",
89
+ " TMm FLOAT, -- Température moyenne en été\n",
90
+ " );\n",
91
+ " \"\"\"}\n",
92
+ "}"
93
+ ]
94
+ },
95
+ {
96
+ "cell_type": "code",
97
+ "execution_count": null,
98
+ "metadata": {},
99
+ "outputs": [],
100
+ "source": [
101
+ "from climateqa.engine.talk_to_data.utils import loc2coords\n",
102
+ "location = \"Marseille\"\n",
103
+ "coords = loc2coords(location)\n",
104
+ "user_input = user_question.replace(location, f\"lat, long : {coords}\")\n"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": null,
110
+ "metadata": {},
111
+ "outputs": [],
112
+ "source": [
113
+ "initial_prompt = f\"You are a mysql expert. \" + \\\n",
114
+ " \"Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. \"\n",
115
+ "initial_prompt += f\"\\n===Tables \\n + {docs[table_name]['sql_query']}\"\n",
116
+ "initial_prompt += f\"\\n===Additional Context \\n\\n {docs[table_name]['description']}\"\n",
117
+ "initial_prompt += (\n",
118
+ " \"===Response Guidelines \\n\"\n",
119
+ " \"1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \\n\"\n",
120
+ " \"2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \\n\"\n",
121
+ " \"3. If the provided context is insufficient, please give a sql query based on your knowledge and the context provided. \\n\"\n",
122
+ " \"4. Please use the most relevant table(s). \\n\"\n",
123
+ " \"5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \\n\"\n",
124
+ " f\"6. Ensure that the output SQL is mysql-compliant and executable, and free of syntax errors. \\n\"\n",
125
+ " )\n",
126
+ "initial_prompt += f\"\\n===Question \\n\\n {user_input}\"\n"
127
+ ]
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "execution_count": null,
132
+ "metadata": {},
133
+ "outputs": [],
134
+ "source": [
135
+ "sql_query = llm.invoke(initial_prompt).content\n",
136
+ "sql_query"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "markdown",
141
+ "metadata": {},
142
+ "source": [
143
+ "# Vanna ask"
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": null,
149
+ "metadata": {},
150
+ "outputs": [],
151
+ "source": [
152
+ "from climateqa.engine.llm import get_llm\n",
153
+ "import ast\n",
154
+ "\n",
155
+ "llm = get_llm(provider=\"openai\")\n",
156
+ "columns = llm.invoke(f\"From the given sql query, list the columns that are being selected. The answer should only be a python list. Just answer the list. The SQL query SELECT 'Mean_winter_temperature' AS source_table, year, month, day, TMm FROM Mean_winter_temperature WHERE lat = 43.166954040527344 AND lon = 5.430534839630127;\").content\n",
157
+ "columns_list = ast.literal_eval(columns.strip(\"```python\\n\").strip())\n"
158
+ ]
159
+ }
160
+ ],
161
+ "metadata": {
162
+ "kernelspec": {
163
+ "display_name": "climateqa",
164
+ "language": "python",
165
+ "name": "python3"
166
+ },
167
+ "language_info": {
168
+ "name": "python",
169
+ "version": "3.11.9"
170
+ }
171
+ },
172
+ "nbformat": 4,
173
+ "nbformat_minor": 2
174
+ }