timeki committed on
Commit
3ca8396
·
verified ·
1 Parent(s): f080183

feature/add_talk_to_data (#23)

Browse files

- Add Talk to Drias (abafbcc3ab2b3f4e5429cfc802caa3e9cdada081)
- Update app.py (a7802dbc4bf7d5916dcc7155319601b30e0716eb)
- Add step by step notebooks for drias (eb90d11b1777bc9e4f2d77b931e82b17a10d16af)

app.py CHANGED
@@ -12,9 +12,11 @@ from climateqa.engine.reranker import get_reranker
12
  from climateqa.engine.graph import make_graph_agent,make_graph_agent_poc
13
  from climateqa.engine.chains.retrieve_papers import find_papers
14
  from climateqa.chat import start_chat, chat_stream, finish_chat
 
15
 
16
  from front.tabs import (create_config_modal, create_examples_tab, create_papers_tab, create_figures_tab, create_chat_interface, create_about_tab)
17
  from front.utils import process_figures
 
18
 
19
 
20
  from utils import create_user_id
@@ -67,9 +69,9 @@ vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name=os
67
  vectorstore_region = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_LOCAL_V2"))
68
 
69
  llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
70
- if os.getenv("ENV")=="GRADIO_ENV":
71
  reranker = get_reranker("nano")
72
- else:
73
  reranker = get_reranker("large")
74
 
75
  agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0.2)
@@ -93,8 +95,9 @@ async def chat_poc(query, history, audience, sources, reports, relevant_content_
93
 
94
  # Function to update modal visibility
95
  def update_config_modal_visibility(config_open):
 
96
  new_config_visibility_status = not config_open
97
- return gr.update(visible=new_config_visibility_status), new_config_visibility_status
98
 
99
 
100
  def update_sources_number_display(sources_textbox, figures_cards, current_graphs, papers_html):
@@ -110,7 +113,21 @@ def update_sources_number_display(sources_textbox, figures_cards, current_graphs
110
 
111
  return gr.update(label=recommended_content_notif_label), gr.update(label=sources_notif_label), gr.update(label=figures_notif_label), gr.update(label=graphs_notif_label), gr.update(label=papers_notif_label)
112
 
113
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  # # UI Layout Components
115
  def cqa_tab(tab_name):
116
  # State variables
@@ -142,7 +159,7 @@ def cqa_tab(tab_name):
142
 
143
  # Papers subtab
144
  with gr.Tab("Papers", elem_id="tab-citations", id=4) as tab_papers:
145
- papers_summary, papers_html, citations_network, papers_modal = create_papers_tab()
146
 
147
  # Graphs subtab
148
  with gr.Tab("Graphs", elem_id="tab-graphs", id=5) as tab_graphs:
@@ -150,6 +167,8 @@ def cqa_tab(tab_name):
150
  "<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>",
151
  elem_id="graphs-container"
152
  )
 
 
153
  return {
154
  "chatbot": chatbot,
155
  "textbox": textbox,
@@ -162,6 +181,7 @@ def cqa_tab(tab_name):
162
  "figures_cards": figures_cards,
163
  "gallery_component": gallery_component,
164
  "config_button": config_button,
 
165
  "papers_html": papers_html,
166
  "citations_network": citations_network,
167
  "papers_summary": papers_summary,
@@ -170,10 +190,23 @@ def cqa_tab(tab_name):
170
  "tab_figures": tab_figures,
171
  "tab_graphs": tab_graphs,
172
  "tab_papers": tab_papers,
173
- "graph_container": graphs_container
 
 
 
174
  }
175
 
176
-
 
 
 
 
 
 
 
 
 
 
177
 
178
  def event_handling(
179
  main_tab_components,
@@ -190,7 +223,8 @@ def event_handling(
190
  sources_textbox = main_tab_components["sources_textbox"]
191
  figures_cards = main_tab_components["figures_cards"]
192
  gallery_component = main_tab_components["gallery_component"]
193
- config_button = main_tab_components["config_button"]
 
194
  papers_html = main_tab_components["papers_html"]
195
  citations_network = main_tab_components["citations_network"]
196
  papers_summary = main_tab_components["papers_summary"]
@@ -200,9 +234,13 @@ def event_handling(
200
  tab_graphs = main_tab_components["tab_graphs"]
201
  tab_papers = main_tab_components["tab_papers"]
202
  graphs_container = main_tab_components["graph_container"]
 
 
 
203
 
204
- config_open = config_components["config_open"]
205
- config_modal = config_components["config_modal"]
 
206
  dropdown_sources = config_components["dropdown_sources"]
207
  dropdown_reports = config_components["dropdown_reports"]
208
  dropdown_external_sources = config_components["dropdown_external_sources"]
@@ -211,18 +249,18 @@ def event_handling(
211
  after = config_components["after"]
212
  output_query = config_components["output_query"]
213
  output_language = config_components["output_language"]
214
- close_config_modal = config_components["close_config_modal_button"]
215
 
216
  new_sources_hmtl = gr.State([])
217
-
218
- print("textbox id : ", textbox.elem_id)
219
 
220
- for button in [config_button, close_config_modal]:
221
- button.click(
222
- fn=update_config_modal_visibility,
223
- inputs=[config_open],
224
- outputs=[config_modal, config_open]
225
- )
226
 
227
  if tab_name == "ClimateQ&A":
228
  print("chat cqa - message sent")
@@ -265,10 +303,13 @@ def event_handling(
265
  component.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs, papers_html], [tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
266
 
267
  # Search for papers
268
- for component in [textbox, examples_hidden]:
269
  component.submit(find_papers, [component, after, dropdown_external_sources], [papers_html, citations_network, papers_summary])
270
 
271
 
 
 
 
272
 
273
  def main_ui():
274
  # config_open = gr.State(True)
@@ -278,11 +319,14 @@ def main_ui():
278
  with gr.Tabs():
279
  cqa_components = cqa_tab(tab_name = "ClimateQ&A")
280
  local_cqa_components = cqa_tab(tab_name = "Beta - POC Adapt'Action")
 
281
 
282
  create_about_tab()
283
 
284
  event_handling(cqa_components, config_components, tab_name = 'ClimateQ&A')
285
- event_handling(local_cqa_components, config_components, tab_name = 'Beta - POC Adapt\'Action')
 
 
286
 
287
  demo.queue()
288
 
 
12
  from climateqa.engine.graph import make_graph_agent,make_graph_agent_poc
13
  from climateqa.engine.chains.retrieve_papers import find_papers
14
  from climateqa.chat import start_chat, chat_stream, finish_chat
15
+ from climateqa.engine.talk_to_data.main import ask_vanna
16
 
17
  from front.tabs import (create_config_modal, create_examples_tab, create_papers_tab, create_figures_tab, create_chat_interface, create_about_tab)
18
  from front.utils import process_figures
19
+ from gradio_modal import Modal
20
 
21
 
22
  from utils import create_user_id
 
69
  vectorstore_region = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_LOCAL_V2"))
70
 
71
  llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
72
# Use the lightweight reranker only for local runs. `.get()` is used instead of
# indexing so that an unset GRADIO_ENV falls through to the "large" reranker
# rather than raising KeyError while the module is being imported.
if os.environ.get("GRADIO_ENV") == "local":
    reranker = get_reranker("nano")
else:
    reranker = get_reranker("large")
76
 
77
  agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0.2)
 
95
 
96
  # Function to update modal visibility
97
def update_config_modal_visibility(config_open):
    """Toggle the configuration modal.

    Args:
        config_open: Current open/closed flag of the modal.

    Returns:
        tuple: A ``Modal`` update carrying the flipped visibility, followed by
        the flipped flag itself so the caller can store it back in its state.
    """
    # The stray debug print of `config_open` was removed: it wrote to stdout on
    # every button click.
    new_config_visibility_status = not config_open
    return Modal(visible=new_config_visibility_status), new_config_visibility_status
101
 
102
 
103
  def update_sources_number_display(sources_textbox, figures_cards, current_graphs, papers_html):
 
113
 
114
  return gr.update(label=recommended_content_notif_label), gr.update(label=sources_notif_label), gr.update(label=figures_notif_label), gr.update(label=graphs_notif_label), gr.update(label=papers_notif_label)
115
 
116
def create_drias_tab():
    """Build the "Beta - Talk to DRIAS" tab and wire its events.

    The tab holds a question textbox, a collapsed "Details" accordion (SQL
    query used, plus a button opening a modal with the result table) and a
    plot. Submitting the question calls ``ask_vanna`` and fills the SQL
    textbox, the table and the plot.
    """
    # NOTE(review): the nesting of the `with` blocks below was reconstructed
    # from a flattened view of the file — confirm it matches the intended layout.
    with gr.Tab("Beta - Talk to DRIAS", elem_id="tab-vanna", id=6) as tab_vanna:
        vanna_direct_question = gr.Textbox(label="Direct Question", placeholder="You can write direct question here",elem_id="direct-question", interactive=True)
        with gr.Accordion("Details",elem_id = 'vanna-details', open=False) as vanna_details :
            vanna_sql_query = gr.Textbox(label="SQL Query Used", elem_id="sql-query", interactive=False)
            show_vanna_table = gr.Button("Show Table", elem_id="show-table")
            # The result table lives in a modal that starts hidden.
            with Modal(visible=False) as vanna_table_modal:
                vanna_table = gr.DataFrame([], elem_id="vanna-table")
                close_vanna_modal = gr.Button("Close", elem_id="close-vanna-modal")
                close_vanna_modal.click(lambda: Modal(visible=False),None, [vanna_table_modal])
            show_vanna_table.click(lambda: Modal(visible=True),None ,[vanna_table_modal])

        vanna_display = gr.Plot()
        # ask_vanna's three outputs map to: SQL textbox, result table, plot.
        vanna_direct_question.submit(ask_vanna, [vanna_direct_question], [vanna_sql_query ,vanna_table, vanna_display])
130
+
131
  # # UI Layout Components
132
  def cqa_tab(tab_name):
133
  # State variables
 
159
 
160
  # Papers subtab
161
  with gr.Tab("Papers", elem_id="tab-citations", id=4) as tab_papers:
162
+ papers_direct_search, papers_summary, papers_html, citations_network, papers_modal = create_papers_tab()
163
 
164
  # Graphs subtab
165
  with gr.Tab("Graphs", elem_id="tab-graphs", id=5) as tab_graphs:
 
167
  "<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>",
168
  elem_id="graphs-container"
169
  )
170
+
171
+
172
  return {
173
  "chatbot": chatbot,
174
  "textbox": textbox,
 
181
  "figures_cards": figures_cards,
182
  "gallery_component": gallery_component,
183
  "config_button": config_button,
184
+ "papers_direct_search" : papers_direct_search,
185
  "papers_html": papers_html,
186
  "citations_network": citations_network,
187
  "papers_summary": papers_summary,
 
190
  "tab_figures": tab_figures,
191
  "tab_graphs": tab_graphs,
192
  "tab_papers": tab_papers,
193
+ "graph_container": graphs_container,
194
+ # "vanna_sql_query": vanna_sql_query,
195
+ # "vanna_table" : vanna_table,
196
+ # "vanna_display": vanna_display
197
  }
198
 
199
def config_event_handling(main_tabs_components : list[dict], config_componenets : dict):
    """Connect every button that opens or closes the config modal.

    Args:
        main_tabs_components: One component dict per main tab; each must hold
            a ``"config_button"`` entry.
        config_componenets: Component dict of the config modal; the entries
            ``"config_open"``, ``"config_modal"`` and
            ``"close_config_modal_button"`` are read.
    """
    visibility_state = config_componenets["config_open"]
    modal = config_componenets["config_modal"]

    # The modal's own close button comes first, then one config button per tab.
    toggle_buttons = [config_componenets["close_config_modal_button"]]
    toggle_buttons.extend(tab["config_button"] for tab in main_tabs_components)

    for toggle_button in toggle_buttons:
        toggle_button.click(
            fn=update_config_modal_visibility,
            inputs=[visibility_state],
            outputs=[modal, visibility_state],
        )
210
 
211
  def event_handling(
212
  main_tab_components,
 
223
  sources_textbox = main_tab_components["sources_textbox"]
224
  figures_cards = main_tab_components["figures_cards"]
225
  gallery_component = main_tab_components["gallery_component"]
226
+ # config_button = main_tab_components["config_button"]
227
+ papers_direct_search = main_tab_components["papers_direct_search"]
228
  papers_html = main_tab_components["papers_html"]
229
  citations_network = main_tab_components["citations_network"]
230
  papers_summary = main_tab_components["papers_summary"]
 
234
  tab_graphs = main_tab_components["tab_graphs"]
235
  tab_papers = main_tab_components["tab_papers"]
236
  graphs_container = main_tab_components["graph_container"]
237
+ # vanna_sql_query = main_tab_components["vanna_sql_query"]
238
+ # vanna_table = main_tab_components["vanna_table"]
239
+ # vanna_display = main_tab_components["vanna_display"]
240
 
241
+
242
+ # config_open = config_components["config_open"]
243
+ # config_modal = config_components["config_modal"]
244
  dropdown_sources = config_components["dropdown_sources"]
245
  dropdown_reports = config_components["dropdown_reports"]
246
  dropdown_external_sources = config_components["dropdown_external_sources"]
 
249
  after = config_components["after"]
250
  output_query = config_components["output_query"]
251
  output_language = config_components["output_language"]
252
+ # close_config_modal = config_components["close_config_modal_button"]
253
 
254
  new_sources_hmtl = gr.State([])
255
+ ttd_data = gr.State([])
256
+
257
 
258
+ # for button in [config_button, close_config_modal]:
259
+ # button.click(
260
+ # fn=update_config_modal_visibility,
261
+ # inputs=[config_open],
262
+ # outputs=[config_modal, config_open]
263
+ # )
264
 
265
  if tab_name == "ClimateQ&A":
266
  print("chat cqa - message sent")
 
303
  component.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs, papers_html], [tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
304
 
305
  # Search for papers
306
+ for component in [textbox, examples_hidden, papers_direct_search]:
307
  component.submit(find_papers, [component, after, dropdown_external_sources], [papers_html, citations_network, papers_summary])
308
 
309
 
310
+ # if tab_name == "Beta - POC Adapt'Action": # Not until results are good enough
311
+ # # Drias search
312
+ # textbox.submit(ask_vanna, [textbox], [vanna_sql_query ,vanna_table, vanna_display])
313
 
314
  def main_ui():
315
  # config_open = gr.State(True)
 
319
  with gr.Tabs():
320
  cqa_components = cqa_tab(tab_name = "ClimateQ&A")
321
  local_cqa_components = cqa_tab(tab_name = "Beta - POC Adapt'Action")
322
+ create_drias_tab()
323
 
324
  create_about_tab()
325
 
326
  event_handling(cqa_components, config_components, tab_name = 'ClimateQ&A')
327
+ event_handling(local_cqa_components, config_components, tab_name = "Beta - POC Adapt'Action")
328
+
329
+ config_event_handling([cqa_components,local_cqa_components] ,config_components)
330
 
331
  demo.queue()
332
 
climateqa/chat.py CHANGED
@@ -53,6 +53,13 @@ def log_interaction_to_azure(history, output_query, sources, docs, share_client,
53
  print(f"Error logging on Azure Blob Storage: {e}")
54
  error_msg = f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)"
55
  raise gr.Error(error_msg)
 
 
 
 
 
 
 
56
 
57
  # Main chat function
58
  async def chat_stream(
@@ -121,6 +128,7 @@ async def chat_stream(
121
  used_documents = []
122
  retrieved_contents = []
123
  answer_message_content = ""
 
124
 
125
  # Define processing steps
126
  steps_display = {
@@ -142,6 +150,14 @@ async def chat_stream(
142
  history, used_documents, retrieved_contents = handle_retrieved_documents(
143
  event, history, used_documents, retrieved_contents
144
  )
 
 
 
 
 
 
 
 
145
  if event["event"] == "on_chain_end" and event["name"] == "answer_search" :
146
  docs = event["data"]["input"]["documents"]
147
  docs_html = convert_to_docs_to_html(docs)
@@ -184,7 +200,7 @@ async def chat_stream(
184
  sub_questions = [q["question"] + "-> relevant sources : " + str(q["sources"]) for q in event["data"]["output"]["questions_list"]]
185
  history[-1].content += "Decompose question into sub-questions:\n\n - " + "\n - ".join(sub_questions)
186
 
187
- yield history, docs_html, output_query, output_language, related_contents, graphs_html
188
 
189
  except Exception as e:
190
  print(f"Event {event} has failed")
@@ -195,4 +211,4 @@ async def chat_stream(
195
  # Call the function to log interaction
196
  log_interaction_to_azure(history, output_query, sources, docs, share_client, user_id)
197
 
198
- yield history, docs_html, output_query, output_language, related_contents, graphs_html
 
53
  print(f"Error logging on Azure Blob Storage: {e}")
54
  error_msg = f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)"
55
  raise gr.Error(error_msg)
56
+
57
def handle_numerical_data(event):
    """Pull the DRIAS result out of a finished ``retrieve_drias_data`` event.

    Args:
        event: Event dict with at least a ``"name"`` key; ``"event"`` and
            ``"data"`` are read only when the name matches.

    Returns:
        tuple: ``(drias_data, drias_sql_query)`` taken from
        ``event["data"]["output"]`` when the event is the end of the
        ``retrieve_drias_data`` step, otherwise ``(None, None)``.
    """
    is_drias_result = (
        event["name"] == "retrieve_drias_data"
        and event["event"] == "on_chain_end"
    )
    if not is_drias_result:
        return None, None

    output = event["data"]["output"]
    return output["drias_data"], output["drias_sql_query"]
63
 
64
  # Main chat function
65
  async def chat_stream(
 
128
  used_documents = []
129
  retrieved_contents = []
130
  answer_message_content = ""
131
+ vanna_data = {}
132
 
133
  # Define processing steps
134
  steps_display = {
 
150
  history, used_documents, retrieved_contents = handle_retrieved_documents(
151
  event, history, used_documents, retrieved_contents
152
  )
153
+ # Handle Vanna retrieval
154
+ # if event["event"] == "on_chain_end" and event["name"] in ["retrieve_documents","retrieve_local_data"] and event["data"]["output"] != None:
155
+ # df_output_vanna, sql_query = handle_numerical_data(
156
+ # event
157
+ # )
158
+ # vanna_data = {"df_output": df_output_vanna, "sql_query": sql_query}
159
+
160
+
161
  if event["event"] == "on_chain_end" and event["name"] == "answer_search" :
162
  docs = event["data"]["input"]["documents"]
163
  docs_html = convert_to_docs_to_html(docs)
 
200
  sub_questions = [q["question"] + "-> relevant sources : " + str(q["sources"]) for q in event["data"]["output"]["questions_list"]]
201
  history[-1].content += "Decompose question into sub-questions:\n\n - " + "\n - ".join(sub_questions)
202
 
203
+ yield history, docs_html, output_query, output_language, related_contents, graphs_html#, vanna_data
204
 
205
  except Exception as e:
206
  print(f"Event {event} has failed")
 
211
  # Call the function to log interaction
212
  log_interaction_to_azure(history, output_query, sources, docs, share_client, user_id)
213
 
214
+ yield history, docs_html, output_query, output_language, related_contents, graphs_html#, vanna_data
climateqa/engine/talk_to_data/main.py CHANGED
@@ -19,7 +19,7 @@ VANNA_MODEL = os.getenv('VANNA_MODEL')
19
 
20
  #Vanna object
21
  vn = MyVanna(config = {"temperature": 0, "api_key": OPENAI_API_KEY, 'model': VANNA_MODEL, 'pc_api_key': PC_API_KEY, 'index_name': INDEX_NAME, "top_k" : 4})
22
- db_vanna_path = os.path.join(os.path.dirname(__file__), "database/drias.db")
23
  vn.connect_to_sqlite(db_vanna_path)
24
 
25
  llm = get_llm(provider="openai")
 
19
 
20
  #Vanna object
21
  vn = MyVanna(config = {"temperature": 0, "api_key": OPENAI_API_KEY, 'model': VANNA_MODEL, 'pc_api_key': PC_API_KEY, 'index_name': INDEX_NAME, "top_k" : 4})
22
+ db_vanna_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), "data/drias/drias.db")
23
  vn.connect_to_sqlite(db_vanna_path)
24
 
25
  llm = get_llm(provider="openai")
climateqa/engine/talk_to_data/myVanna.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dotenv import load_dotenv
2
+ from climateqa.engine.talk_to_data.vanna_class import MyCustomVectorDB
3
+ from vanna.openai import OpenAI_Chat
4
+ import os
5
+
6
+ load_dotenv()
7
+
8
+ OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
9
+
10
class MyVanna(MyCustomVectorDB, OpenAI_Chat):
    """Vanna client combining ``MyCustomVectorDB`` storage with ``OpenAI_Chat``."""

    def __init__(self, config=None):
        # Each base is initialised explicitly with the same config object.
        MyCustomVectorDB.__init__(self, config=config)
        OpenAI_Chat.__init__(self, config=config)
climateqa/engine/talk_to_data/utils.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import openai
3
+ import pandas as pd
4
+ from geopy.geocoders import Nominatim
5
+ import sqlite3
6
+ import ast
7
+
8
+
9
def detect_location_with_openai(api_key, sentence):
    """Ask an OpenAI chat model for the locations mentioned in ``sentence``.

    Args:
        api_key: OpenAI API key; it is assigned to the module-level
            ``openai.api_key`` before the request is sent.
        sentence: Free text to scan for locations.

    Returns:
        str: The line of the reply holding the list, with its first two and
        last two characters removed (``["Marseille"]`` becomes ``Marseille``).
        An empty string when the model returns no text.
    """
    openai.api_key = api_key

    prompt = f"""
    Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
    Return the result as a Python list. If no locations are mentioned, return an empty list.

    Sentence: "{sentence}"
    """

    response = openai.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": "You are a helpful assistant skilled in identifying locations in text."},
            {"role": "user", "content": prompt}
        ],
        max_tokens=100,
        temperature=0
    )

    content = response.choices[0].message.content or ""
    lines = content.strip().split("\n")
    # A fenced reply (```python ... ```) carries the list on its second line.
    # An unfenced reply has a single line, which previously raised IndexError.
    list_line = lines[1] if len(lines) > 1 else lines[0]
    return list_line[2:-2]
33
+
34
+
35
def detectTable(sql_query):
    """Return every table reference that directly follows a ``FROM`` keyword.

    The keyword is matched case-insensitively. A reference may be a bare word
    or a backtick-, double- or single-quoted name, optionally dot-qualified;
    quotes are kept in the returned strings.

    Args:
        sql_query: SQL text to scan.

    Returns:
        list[str]: The references in order of appearance (empty if none).
    """
    from_clause = re.compile(
        r'(?i)\bFROM\s+((?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)(?:\.(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+))*)'
    )
    return from_clause.findall(sql_query)
39
+
40
+
41
+
42
def loc2coords(location : str):
    """Geocode a place name with Nominatim.

    Args:
        location: Place name to look up.

    Returns:
        tuple: ``(latitude, longitude)`` of the geocoder's result.

    Raises:
        ValueError: If the geocoder finds no match for ``location``.
    """
    geolocator = Nominatim(user_agent="city_to_latlong")
    match = geolocator.geocode(location)
    # geocode() yields None when nothing matches; fail with a clear message
    # instead of an AttributeError on `None.latitude`.
    if match is None:
        raise ValueError(f"Could not geocode location: {location!r}")
    return (match.latitude, match.longitude)
46
+
47
+
48
def coords2loc(coords : tuple):
    """Reverse-geocode coordinates into an address string with Nominatim.

    Args:
        coords: Coordinates handed unchanged to ``Nominatim.reverse``.

    Returns:
        str: The address of the result, or ``"Unknown Location"`` when the
        lookup raises (the error is printed).
    """
    reverse_geocoder = Nominatim(user_agent="coords_to_city")
    try:
        return reverse_geocoder.reverse(coords).address
    except Exception as e:
        # Best effort: any failure is reported and mapped to a placeholder.
        print(f"Error: {e}")
        return "Unknown Location"
56
+
57
+
58
def nearestNeighbourSQL(db: str, location: tuple, table : str):
    """Return the grid point stored in ``table`` that is closest to ``location``.

    Args:
        db: Path of the SQLite database file.
        location: ``(latitude, longitude)`` pair; both values are rounded to
            three decimals before the lookup.
        table: Name of the table to search; it must have ``lat`` and ``lon``
            columns.

    Returns:
        tuple: The ``(lat, lon)`` row nearest to ``location`` among the rows
        lying within 0.3 of it on both axes.

    Raises:
        IndexError: If no row lies inside that window.
    """
    lat = round(location[0], 3)
    long = round(location[1], 3)

    # NOTE(review): a table name cannot be bound as a SQL parameter, so it is
    # still interpolated — callers must only pass trusted table names.
    sql = (
        f"SELECT lat, lon FROM {table} "
        "WHERE lat BETWEEN ? AND ? AND lon BETWEEN ? AND ? "
        # Order by squared distance so the closest point is returned rather
        # than whichever row the table happens to yield first.
        "ORDER BY (lat - ?) * (lat - ?) + (lon - ?) * (lon - ?) "
        "LIMIT 1"
    )
    params = (lat - 0.3, lat + 0.3, long - 0.3, long + 0.3, lat, lat, long, long)

    conn = sqlite3.connect(db)
    try:
        row = conn.execute(sql, params).fetchone()
    finally:
        # Always release the connection, even when the query fails.
        conn.close()

    if row is None:
        raise IndexError(
            f"No grid point found in {table} within 0.3 of ({lat}, {long})"
        )
    return row
66
+
67
def detect_relevant_tables(user_question, llm):
    """Ask ``llm`` which known tables are relevant to ``user_question``.

    Args:
        user_question: The user's question, inserted into the prompt.
        llm: Object whose ``invoke(prompt)`` returns something with a
            ``content`` string.

    Returns:
        list[str]: The table names from the reply that belong to the known
        table list (matched case-insensitively, returned with the known
        spelling). Empty when the reply is not a list.

    Raises:
        ValueError, SyntaxError: If the reply cannot be parsed as a Python
            literal.
    """
    table_names_list = [
        "Frequency_of_rainy_days_index",
        "Winter_precipitation_total",
        "Summer_precipitation_total",
        "Annual_precipitation_total",
        # "Remarkable_daily_precipitation_total_(Q99)",
        "Frequency_of_remarkable_daily_precipitation",
        "Extreme_precipitation_intensity",
        "Mean_winter_temperature",
        "Mean_summer_temperature",
        "Number_of_tropical_nights",
        "Maximum_summer_temperature",
        "Number_of_days_with_Tx_above_30C",
        "Number_of_days_with_Tx_above_35C",
        "Drought_index"
    ]
    prompt = (
        f"You are helping to build a sql query to retrieve relevant data for a user question."
        f"The different tables are {table_names_list}."
        f"The user question is {user_question}. Write the relevant tables to use. Answer only a python list of table name."
    )
    reply = llm.invoke(prompt).content

    # Parse the outermost [...] span. str.strip("```python\n") strips a set of
    # characters rather than a prefix, so it only worked by accident.
    start, end = reply.find("["), reply.rfind("]")
    literal = reply[start:end + 1] if 0 <= start < end else reply.strip()
    parsed = ast.literal_eval(literal)
    if not isinstance(parsed, (list, tuple)):
        return []

    # Keep only tables that really exist: these names are later interpolated
    # into SQL, so an invented name must not get through.
    known = {name.lower(): name for name in table_names_list}
    return [
        known[candidate.lower()]
        for candidate in parsed
        if isinstance(candidate, str) and candidate.lower() in known
    ]
91
+
92
def replace_coordonates(coords, query, coords_tables):
    """Swap the coordinates written in ``query`` for per-table coordinates.

    The i-th occurrence of ``str(coords[0])`` is replaced by
    ``coords_tables[i][0]`` and the i-th occurrence of ``str(coords[1])`` by
    ``coords_tables[i][1]``.

    Args:
        coords: ``(latitude, longitude)`` as they appear in ``query``.
        query: SQL text containing the coordinates.
        coords_tables: One ``(latitude, longitude)`` pair per occurrence.

    Returns:
        str: The query with the coordinates substituted. Occurrences beyond
        ``len(coords_tables)`` are left unchanged.
    """
    lat_text, lon_text = str(coords[0]), str(coords[1])
    # Same bound as before (number of latitude occurrences), additionally
    # capped so a short coords_tables no longer raises IndexError.
    limit = min(query.count(lat_text), len(coords_tables))

    column = {lon_text: 1}
    column[lat_text] = 0
    seen = dict.fromkeys(column, 0)

    # One left-to-right pass: substituted text is never rescanned, so a new
    # value that contains the old one (43.3 -> 43.35) cannot be hit again.
    tokens = sorted(column, key=len, reverse=True)
    pattern = re.compile("|".join(re.escape(token) for token in tokens))

    def _swap(match):
        token = match.group(0)
        occurrence = seen[token]
        if occurrence >= limit:
            return token
        seen[token] = occurrence + 1
        return str(coords_tables[occurrence][column[token]])

    return pattern.sub(_swap, query)
climateqa/engine/talk_to_data/vanna_class.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from vanna.base import VannaBase
2
+ from pinecone import Pinecone
3
+ from climateqa.engine.embeddings import get_embeddings_function
4
+ import pandas as pd
5
+ import hashlib
6
+
7
class MyCustomVectorDB(VannaBase):
    """Vanna vector store that keeps its training data in a Pinecone index.

    Vectors are stored in three namespaces of the index: ``ddl``,
    ``documentation`` and ``question_sql``.

    Args:
        config (dict): Configuration dictionary:
            - pc_api_key (str): Pinecone API key
            - index_name (str): Pinecone index name
            - top_k (int): Number of results returned by the retrieval
              methods (default = 2)
    """

    # Suffix appended to generated ids -> namespace the vector is stored in.
    _NAMESPACE_BY_ID_SUFFIX = {
        "_ddl": "ddl",
        "_sql": "question_sql",
        "_doc": "documentation",
    }

    def __init__(self, config):
        """Connect to the Pinecone index named in ``config``.

        Raises:
            ValueError: If ``config`` is not a dict or has no ``index_name``.
        """
        super().__init__(config=config)
        # dict.get never raises on a missing key, so the former try/except
        # could not detect an incomplete config; check it explicitly.
        if not isinstance(config, dict) or not config.get('index_name'):
            raise ValueError("Please provide the Pinecone API key and the index name")

        # NOTE(review): a missing pc_api_key is forwarded as None; presumably
        # the Pinecone client then resolves credentials on its own — confirm.
        self.api_key = config.get('pc_api_key')
        self.index_name = config.get('index_name')

        self.pc = Pinecone(api_key=self.api_key)
        self.index = self.pc.Index(self.index_name)
        self.top_k = config.get('top_k', 2)
        self.embeddings = get_embeddings_function()

    def check_embedding(self, id, namespace):
        """Return True when a vector with ``id`` exists in ``namespace``."""
        fetched = self.index.fetch(ids=[id], namespace=namespace)
        return fetched['vectors'] != {}

    def generate_hash_id(self, data: str) -> str:
        """Return the SHA-256 hex digest of ``data`` (UTF-8 encoded).

        Args:
            data (str): The text to hash.

        Returns:
            str: The digest as a hexadecimal string.
        """
        return hashlib.sha256(data.encode('utf-8')).hexdigest()

    def _store(self, vector_id, namespace, text, metadata, label):
        """Upsert ``text``'s embedding under ``vector_id`` unless it already exists."""
        if self.check_embedding(vector_id, namespace):
            print(f"{label} having id {vector_id} already exists")
            return vector_id

        self.index.upsert(
            vectors=[(vector_id, self.embeddings.embed_query(text), metadata)],
            namespace=namespace
        )
        return vector_id

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """Store a DDL statement in the ``ddl`` namespace and return its id."""
        vector_id = self.generate_hash_id(ddl) + '_ddl'
        return self._store(vector_id, 'ddl', ddl, {'ddl': ddl}, "DDL")

    def add_documentation(self, doc: str, **kwargs) -> str:
        """Store a documentation text in the ``documentation`` namespace and return its id."""
        vector_id = self.generate_hash_id(doc) + '_doc'
        return self._store(vector_id, 'documentation', doc, {'doc': doc}, "Documentation")

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Store a question/SQL pair in the ``question_sql`` namespace and return its id.

        The id is derived from the question only, so a second pair with the
        same question is reported as already existing and is not stored.
        """
        vector_id = self.generate_hash_id(question) + '_sql'
        return self._store(
            vector_id,
            'question_sql',
            question + sql,
            {'question': question, 'sql': sql},
            "Question-SQL pair",
        )

    def _nearest_metadata(self, question, namespace):
        """Return the metadata of the ``top_k`` vectors closest to ``question``."""
        res = self.index.query(
            vector=self.embeddings.embed_query(question),
            top_k=self.top_k,
            namespace=namespace,
            include_metadata=True
        )
        return [match['metadata'] for match in res['matches']]

    def get_related_ddl(self, question: str, **kwargs) -> list:
        """Return the DDL statements closest to ``question``."""
        return [meta['ddl'] for meta in self._nearest_metadata(question, 'ddl')]

    def get_related_documentation(self, question: str, **kwargs) -> list:
        """Return the documentation texts closest to ``question``."""
        return [meta['doc'] for meta in self._nearest_metadata(question, 'documentation')]

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        """Return ``(question, sql)`` tuples for the stored pairs closest to ``question``."""
        return [
            (meta['question'], meta['sql'])
            for meta in self._nearest_metadata(question, 'question_sql')
        ]

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        """Return the metadata of the stored vectors of all three namespaces.

        Returns:
            pd.DataFrame: One row per match returned by the index.
        """
        list_of_data = []

        for namespace in ['ddl', 'documentation', 'question_sql']:
            # NOTE(review): no query vector or id is passed here — confirm the
            # Pinecone client accepts such a query.
            data = self.index.query(
                top_k=10000,
                namespace=namespace,
                include_metadata=True,
                include_values=False
            )
            list_of_data.extend(match['metadata'] for match in data['matches'])

        return pd.DataFrame(list_of_data)

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Delete the vector ``id`` from the namespace implied by its suffix.

        Args:
            id (str): An id returned by ``add_ddl``, ``add_documentation`` or
                ``add_question_sql``.

        Returns:
            bool: True when the suffix was recognised and a delete was issued,
            False otherwise.
        """
        # Fixes: `self.Index` did not exist (the attribute is `self.index`),
        # and the namespaces "_ddl"/"_sql"/"_doc" are not the ones the add_*
        # methods write to.
        for suffix, namespace in self._NAMESPACE_BY_ID_SUFFIX.items():
            if id.endswith(suffix):
                self.index.delete(ids=[id], namespace=namespace)
                return True

        return False

    def generate_embedding(self, text, **kwargs):
        """Placeholder; not implemented, returns None."""
        # Embeddings are computed through self.embeddings in the other methods.
        pass

    def get_sql_prompt(
        self,
        initial_prompt : str,
        question: str,
        question_sql_list: list,
        ddl_list: list,
        doc_list: list,
        **kwargs,
    ):
        """Build the message list asking the LLM to write SQL for ``question``.

        Args:
            initial_prompt (str): System prompt prefix; a default one naming
                the SQL dialect is used when None.
            question (str): The question to generate SQL for.
            question_sql_list (list): Past examples, each either a dict with
                ``"question"`` and ``"sql"`` keys or a ``(question, sql)``
                pair as returned by ``get_similar_question_sql``.
            ddl_list (list): DDL statements added to the system prompt.
            doc_list (list): Documentation added to the system prompt; the
                list itself is not modified.

        Returns:
            list: The system message, one user/assistant message pair per
            usable example, then the user message holding ``question``.
        """
        if initial_prompt is None:
            initial_prompt = f"You are a {self.dialect} expert. " + \
            "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "

        initial_prompt = self.add_ddl_to_prompt(
            initial_prompt, ddl_list, max_tokens=self.max_tokens
        )

        if self.static_documentation != "":
            # Build a new list so the caller's doc_list is not mutated.
            doc_list = [*doc_list, self.static_documentation]

        initial_prompt = self.add_documentation_to_prompt(
            initial_prompt, doc_list, max_tokens=self.max_tokens
        )

        # Past question/SQL pairs are supplied as chat turns below rather than
        # through add_sql_to_prompt.
        initial_prompt += (
            "===Response Guidelines \n"
            "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
            "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
            "3. If the provided context is insufficient, please give a sql query based on your knowledge and the context provided. \n"
            "4. Please use the most relevant table(s). \n"
            "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
            f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
            f"7. Add a description of the table in the result of the sql query, if relevant. \n"
            "8 Make sure to include the relevant KPI in the SQL query. The query should return impactfull data \n"
        )

        message_log = [self.system_message(initial_prompt)]

        for example in question_sql_list:
            if example is None:
                print("example is None")
                continue
            if isinstance(example, (tuple, list)):
                # Pairs from get_similar_question_sql were silently skipped
                # before, because only the dict form was recognised.
                if len(example) != 2:
                    continue
                past_question, past_sql = example
            elif "question" in example and "sql" in example:
                past_question, past_sql = example["question"], example["sql"]
            else:
                continue
            message_log.append(self.user_message(past_question))
            message_log.append(self.assistant_message(past_sql))

        message_log.append(self.user_message(question))

        return message_log
253
+
254
+
255
+ # def get_sql_prompt(
256
+ # self,
257
+ # initial_prompt : str,
258
+ # question: str,
259
+ # question_sql_list: list,
260
+ # ddl_list: list,
261
+ # doc_list: list,
262
+ # **kwargs,
263
+ # ):
264
+ # """
265
+ # Example:
266
+ # ```python
267
+ # vn.get_sql_prompt(
268
+ # question="What are the top 10 customers by sales?",
269
+ # question_sql_list=[{"question": "What are the top 10 customers by sales?", "sql": "SELECT * FROM customers ORDER BY sales DESC LIMIT 10"}],
270
+ # ddl_list=["CREATE TABLE customers (id INT, name TEXT, sales DECIMAL)"],
271
+ # doc_list=["The customers table contains information about customers and their sales."],
272
+ # )
273
+
274
+ # ```
275
+
276
+ # This method is used to generate a prompt for the LLM to generate SQL.
277
+
278
+ # Args:
279
+ # question (str): The question to generate SQL for.
280
+ # question_sql_list (list): A list of questions and their corresponding SQL statements.
281
+ # ddl_list (list): A list of DDL statements.
282
+ # doc_list (list): A list of documentation.
283
+
284
+ # Returns:
285
+ # any: The prompt for the LLM to generate SQL.
286
+ # """
287
+
288
+ # if initial_prompt is None:
289
+ # initial_prompt = f"You are a {self.dialect} expert. " + \
290
+ # "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
291
+
292
+ # initial_prompt = self.add_ddl_to_prompt(
293
+ # initial_prompt, ddl_list, max_tokens=self.max_tokens
294
+ # )
295
+
296
+ # if self.static_documentation != "":
297
+ # doc_list.append(self.static_documentation)
298
+
299
+ # initial_prompt = self.add_documentation_to_prompt(
300
+ # initial_prompt, doc_list, max_tokens=self.max_tokens
301
+ # )
302
+
303
+ # initial_prompt += (
304
+ # "===Response Guidelines \n"
305
+ # "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
306
+ # "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
307
+ # "3. If the provided context is insufficient, please explain why it can't be generated. \n"
308
+ # "4. Please use the most relevant table(s). \n"
309
+ # "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
310
+ # f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
311
+ # )
312
+
313
+ # message_log = [self.system_message(initial_prompt)]
314
+
315
+ # for example in question_sql_list:
316
+ # if example is None:
317
+ # print("example is None")
318
+ # else:
319
+ # if example is not None and "question" in example and "sql" in example:
320
+ # message_log.append(self.user_message(example["question"]))
321
+ # message_log.append(self.assistant_message(example["sql"]))
322
+
323
+ # message_log.append(self.user_message(question))
324
+
325
+ # return message_log
front/tabs/tab_papers.py CHANGED
@@ -3,6 +3,8 @@ from gradio_modal import Modal
3
 
4
 
5
  def create_papers_tab():
 
 
6
  with gr.Accordion(
7
  visible=True,
8
  elem_id="papers-summary-popup",
@@ -32,5 +34,5 @@ def create_papers_tab():
32
  papers_modal
33
  )
34
 
35
- return papers_summary, papers_html, citations_network, papers_modal
36
 
 
3
 
4
 
5
  def create_papers_tab():
6
+ direct_search_textbox = gr.Textbox(label="Direct search for papers", placeholder= "What is climate change ?", elem_id="papers-search")
7
+
8
  with gr.Accordion(
9
  visible=True,
10
  elem_id="papers-summary-popup",
 
34
  papers_modal
35
  )
36
 
37
+ return direct_search_textbox, papers_summary, papers_html, citations_network, papers_modal
38
 
requirements.txt CHANGED
@@ -19,3 +19,5 @@ langchain-community==0.2
19
  msal==1.31
20
  matplotlib==3.9.2
21
  gradio-modal==0.0.4
 
 
 
19
  msal==1.31
20
  matplotlib==3.9.2
21
  gradio-modal==0.0.4
22
+ vanna==0.7.5
23
+ geopy==2.4.1
sandbox/talk_to_data/20250306 - CQA - Drias.ipynb ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "## Import the function from main.py"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": null,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": [
16
+ "import sys\n",
17
+ "import os\n",
18
+ "sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))\n",
19
+ "\n",
20
+ "%load_ext autoreload\n",
21
+ "%autoreload 2\n",
22
+ "\n",
23
+ "from climateqa.engine.talk_to_data.main import ask_vanna\n"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "markdown",
28
+ "metadata": {},
29
+ "source": [
30
+ "## Create a human query"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": null,
36
+ "metadata": {},
37
+ "outputs": [],
38
+ "source": [
39
+ "query = \"Comment vont évoluer les températures à marseille ?\""
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "markdown",
44
+ "metadata": {},
45
+ "source": [
46
+ "## Call the function ask_vanna: it returns a tuple containing the SQL query, the result dataframe and a figure"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": null,
52
+ "metadata": {},
53
+ "outputs": [],
54
+ "source": [
55
+ "sql_query, df, fig = ask_vanna(query)\n",
56
+ "print(df.head())\n",
57
+ "fig.show()"
58
+ ]
59
+ }
60
+ ],
61
+ "metadata": {
62
+ "kernelspec": {
63
+ "display_name": "climateqa",
64
+ "language": "python",
65
+ "name": "python3"
66
+ },
67
+ "language_info": {
68
+ "codemirror_mode": {
69
+ "name": "ipython",
70
+ "version": 3
71
+ },
72
+ "file_extension": ".py",
73
+ "mimetype": "text/x-python",
74
+ "name": "python",
75
+ "nbconvert_exporter": "python",
76
+ "pygments_lexer": "ipython3",
77
+ "version": "3.11.9"
78
+ }
79
+ },
80
+ "nbformat": 4,
81
+ "nbformat_minor": 2
82
+ }
sandbox/talk_to_data/20250306 - CQA - Step_by_step_vanna.ipynb ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import sys\n",
10
+ "import os\n",
11
+ "sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))\n",
12
+ "\n",
13
+ "%load_ext autoreload\n",
14
+ "%autoreload 2\n",
15
+ "\n",
16
+ "from climateqa.engine.talk_to_data.main import ask_vanna\n",
17
+ "\n",
18
+ "import sqlite3\n",
19
+ "import os\n",
20
+ "import pandas as pd"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "markdown",
25
+ "metadata": {},
26
+ "source": [
27
+ "# Imports"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": null,
33
+ "metadata": {},
34
+ "outputs": [],
35
+ "source": [
36
+ "from climateqa.engine.talk_to_data.myVanna import MyVanna\n",
37
+ "from climateqa.engine.talk_to_data.utils import loc2coords, detect_location_with_openai, detectTable, nearestNeighbourSQL, detect_relevant_tables, replace_coordonates#,nearestNeighbourPostgres\n",
38
+ "\n",
39
+ "from climateqa.engine.llm import get_llm"
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "markdown",
44
+ "metadata": {},
45
+ "source": [
46
+ "# Vanna Ask\n"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": null,
52
+ "metadata": {},
53
+ "outputs": [],
54
+ "source": [
55
+ "from dotenv import load_dotenv\n",
56
+ "\n",
57
+ "load_dotenv()\n",
58
+ "\n",
59
+ "llm = get_llm(provider=\"openai\")\n",
60
+ "\n",
61
+ "OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')\n",
62
+ "PC_API_KEY = os.getenv('VANNA_PINECONE_API_KEY')\n",
63
+ "INDEX_NAME = os.getenv('VANNA_INDEX_NAME')\n",
64
+ "VANNA_MODEL = os.getenv('VANNA_MODEL')\n",
65
+ "\n",
66
+ "ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))\n",
67
+ "\n",
68
+ "#Vanna object\n",
69
+ "vn = MyVanna(config = {\"temperature\": 0, \"api_key\": OPENAI_API_KEY, 'model': VANNA_MODEL, 'pc_api_key': PC_API_KEY, 'index_name': INDEX_NAME, \"top_k\" : 4})\n",
70
+ "\n",
71
+ "db_vanna_path = ROOT_PATH + \"/data/drias/drias.db\"\n",
72
+ "vn.connect_to_sqlite(db_vanna_path)\n"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "markdown",
77
+ "metadata": {},
78
+ "source": [
79
+ "# User query"
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "code",
84
+ "execution_count": null,
85
+ "metadata": {},
86
+ "outputs": [],
87
+ "source": [
88
+ "query = \"Quelle sera la température à Marseille sur les prochaines années ?\""
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "markdown",
93
+ "metadata": {},
94
+ "source": [
95
+ "## Detect location"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "execution_count": null,
101
+ "metadata": {},
102
+ "outputs": [],
103
+ "source": [
104
+ "location = detect_location_with_openai(OPENAI_API_KEY, query)\n",
105
+ "print(location)"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "markdown",
110
+ "metadata": {},
111
+ "source": [
112
+ "## Convert location to longitude, latitude coordinates"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": null,
118
+ "metadata": {},
119
+ "outputs": [],
120
+ "source": [
121
+ "coords = loc2coords(location)\n",
122
+ "user_input = query.lower().replace(location.lower(), f\"lat, long : {coords}\")\n",
123
+ "print(user_input)"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "markdown",
128
+ "metadata": {},
129
+ "source": [
130
+ "# Find closest coordinates and replace lat,lon\n"
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "code",
135
+ "execution_count": null,
136
+ "metadata": {},
137
+ "outputs": [],
138
+ "source": [
139
+ "relevant_tables = detect_relevant_tables(user_input, llm) \n",
140
+ "coords_tables = [nearestNeighbourSQL(db_vanna_path, coords, relevant_tables[i]) for i in range(len(relevant_tables))]\n",
141
+ "user_input_with_coords = replace_coordonates(coords, user_input, coords_tables)\n",
142
+ "print(user_input_with_coords)"
143
+ ]
144
+ },
145
+ {
146
+ "cell_type": "markdown",
147
+ "metadata": {},
148
+ "source": [
149
+ "# Ask Vanna with correct coordinates"
150
+ ]
151
+ },
152
+ {
153
+ "cell_type": "code",
154
+ "execution_count": null,
155
+ "metadata": {},
156
+ "outputs": [],
157
+ "source": [
158
+ "user_input_with_coords"
159
+ ]
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "execution_count": null,
164
+ "metadata": {},
165
+ "outputs": [],
166
+ "source": [
167
+ "sql_query, result_dataframe, figure = vn.ask(user_input_with_coords, print_results=False, allow_llm_to_see_data=True, auto_train=False)\n",
168
+ "print(result_dataframe.head())"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "execution_count": null,
174
+ "metadata": {},
175
+ "outputs": [],
176
+ "source": [
177
+ "result_dataframe"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "code",
182
+ "execution_count": null,
183
+ "metadata": {},
184
+ "outputs": [],
185
+ "source": [
186
+ "figure"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "code",
191
+ "execution_count": null,
192
+ "metadata": {},
193
+ "outputs": [],
194
+ "source": []
195
+ }
196
+ ],
197
+ "metadata": {
198
+ "kernelspec": {
199
+ "display_name": "climateqa",
200
+ "language": "python",
201
+ "name": "python3"
202
+ },
203
+ "language_info": {
204
+ "codemirror_mode": {
205
+ "name": "ipython",
206
+ "version": 3
207
+ },
208
+ "file_extension": ".py",
209
+ "mimetype": "text/x-python",
210
+ "name": "python",
211
+ "nbconvert_exporter": "python",
212
+ "pygments_lexer": "ipython3",
213
+ "version": "3.11.9"
214
+ }
215
+ },
216
+ "nbformat": 4,
217
+ "nbformat_minor": 2
218
+ }
style.css CHANGED
@@ -481,14 +481,13 @@ a {
481
  max-height: calc(100vh - 190px) !important;
482
  overflow: hidden;
483
  }
484
-
485
  div#tab-examples,
486
  div#sources-textbox,
487
  div#tab-config {
488
  height: calc(100vh - 190px) !important;
489
  overflow-y: scroll !important;
490
  }
491
-
492
  div#sources-figures,
493
  div#graphs-container,
494
  div#tab-citations {
@@ -606,3 +605,16 @@ a {
606
  #checkbox-config:checked {
607
  display: block;
608
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
481
  max-height: calc(100vh - 190px) !important;
482
  overflow: hidden;
483
  }
 
484
  div#tab-examples,
485
  div#sources-textbox,
486
  div#tab-config {
487
  height: calc(100vh - 190px) !important;
488
  overflow-y: scroll !important;
489
  }
490
+ div#tab-vanna,
491
  div#sources-figures,
492
  div#graphs-container,
493
  div#tab-citations {
 
605
  #checkbox-config:checked {
606
  display: block;
607
  }
608
+
609
+ #vanna-display {
610
+ max-height: 300px;
611
+ /* overflow-y: scroll; */
612
+ }
613
+ #sql-query{
614
+ max-height: 100px;
615
+ overflow-y:scroll;
616
+ }
617
+ #vanna-details{
618
+ max-height: 500px;
619
+ overflow-y:scroll;
620
+ }