feature/add_talk_to_data (#23)
Browse files- Add Talk to Drias (abafbcc3ab2b3f4e5429cfc802caa3e9cdada081)
- Update app.py (a7802dbc4bf7d5916dcc7155319601b30e0716eb)
- Add step by step notebooks for drias (eb90d11b1777bc9e4f2d77b931e82b17a10d16af)
- app.py +65 -21
- climateqa/chat.py +18 -2
- climateqa/engine/talk_to_data/main.py +1 -1
- climateqa/engine/talk_to_data/myVanna.py +13 -0
- climateqa/engine/talk_to_data/utils.py +98 -0
- climateqa/engine/talk_to_data/vanna_class.py +325 -0
- front/tabs/tab_papers.py +3 -1
- requirements.txt +2 -0
- sandbox/talk_to_data/20250306 - CQA - Drias.ipynb +82 -0
- sandbox/talk_to_data/20250306 - CQA - Step_by_step_vanna.ipynb +218 -0
- style.css +14 -2
app.py
CHANGED
@@ -12,9 +12,11 @@ from climateqa.engine.reranker import get_reranker
|
|
12 |
from climateqa.engine.graph import make_graph_agent,make_graph_agent_poc
|
13 |
from climateqa.engine.chains.retrieve_papers import find_papers
|
14 |
from climateqa.chat import start_chat, chat_stream, finish_chat
|
|
|
15 |
|
16 |
from front.tabs import (create_config_modal, create_examples_tab, create_papers_tab, create_figures_tab, create_chat_interface, create_about_tab)
|
17 |
from front.utils import process_figures
|
|
|
18 |
|
19 |
|
20 |
from utils import create_user_id
|
@@ -67,9 +69,9 @@ vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name=os
|
|
67 |
vectorstore_region = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_LOCAL_V2"))
|
68 |
|
69 |
llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
|
70 |
-
if os.
|
71 |
reranker = get_reranker("nano")
|
72 |
-
else:
|
73 |
reranker = get_reranker("large")
|
74 |
|
75 |
agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0.2)
|
@@ -93,8 +95,9 @@ async def chat_poc(query, history, audience, sources, reports, relevant_content_
|
|
93 |
|
94 |
# Function to update modal visibility
|
95 |
def update_config_modal_visibility(config_open):
|
|
|
96 |
new_config_visibility_status = not config_open
|
97 |
-
return
|
98 |
|
99 |
|
100 |
def update_sources_number_display(sources_textbox, figures_cards, current_graphs, papers_html):
|
@@ -110,7 +113,21 @@ def update_sources_number_display(sources_textbox, figures_cards, current_graphs
|
|
110 |
|
111 |
return gr.update(label=recommended_content_notif_label), gr.update(label=sources_notif_label), gr.update(label=figures_notif_label), gr.update(label=graphs_notif_label), gr.update(label=papers_notif_label)
|
112 |
|
113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
# # UI Layout Components
|
115 |
def cqa_tab(tab_name):
|
116 |
# State variables
|
@@ -142,7 +159,7 @@ def cqa_tab(tab_name):
|
|
142 |
|
143 |
# Papers subtab
|
144 |
with gr.Tab("Papers", elem_id="tab-citations", id=4) as tab_papers:
|
145 |
-
papers_summary, papers_html, citations_network, papers_modal = create_papers_tab()
|
146 |
|
147 |
# Graphs subtab
|
148 |
with gr.Tab("Graphs", elem_id="tab-graphs", id=5) as tab_graphs:
|
@@ -150,6 +167,8 @@ def cqa_tab(tab_name):
|
|
150 |
"<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>",
|
151 |
elem_id="graphs-container"
|
152 |
)
|
|
|
|
|
153 |
return {
|
154 |
"chatbot": chatbot,
|
155 |
"textbox": textbox,
|
@@ -162,6 +181,7 @@ def cqa_tab(tab_name):
|
|
162 |
"figures_cards": figures_cards,
|
163 |
"gallery_component": gallery_component,
|
164 |
"config_button": config_button,
|
|
|
165 |
"papers_html": papers_html,
|
166 |
"citations_network": citations_network,
|
167 |
"papers_summary": papers_summary,
|
@@ -170,10 +190,23 @@ def cqa_tab(tab_name):
|
|
170 |
"tab_figures": tab_figures,
|
171 |
"tab_graphs": tab_graphs,
|
172 |
"tab_papers": tab_papers,
|
173 |
-
"graph_container": graphs_container
|
|
|
|
|
|
|
174 |
}
|
175 |
|
176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
177 |
|
178 |
def event_handling(
|
179 |
main_tab_components,
|
@@ -190,7 +223,8 @@ def event_handling(
|
|
190 |
sources_textbox = main_tab_components["sources_textbox"]
|
191 |
figures_cards = main_tab_components["figures_cards"]
|
192 |
gallery_component = main_tab_components["gallery_component"]
|
193 |
-
config_button = main_tab_components["config_button"]
|
|
|
194 |
papers_html = main_tab_components["papers_html"]
|
195 |
citations_network = main_tab_components["citations_network"]
|
196 |
papers_summary = main_tab_components["papers_summary"]
|
@@ -200,9 +234,13 @@ def event_handling(
|
|
200 |
tab_graphs = main_tab_components["tab_graphs"]
|
201 |
tab_papers = main_tab_components["tab_papers"]
|
202 |
graphs_container = main_tab_components["graph_container"]
|
|
|
|
|
|
|
203 |
|
204 |
-
|
205 |
-
|
|
|
206 |
dropdown_sources = config_components["dropdown_sources"]
|
207 |
dropdown_reports = config_components["dropdown_reports"]
|
208 |
dropdown_external_sources = config_components["dropdown_external_sources"]
|
@@ -211,18 +249,18 @@ def event_handling(
|
|
211 |
after = config_components["after"]
|
212 |
output_query = config_components["output_query"]
|
213 |
output_language = config_components["output_language"]
|
214 |
-
close_config_modal = config_components["close_config_modal_button"]
|
215 |
|
216 |
new_sources_hmtl = gr.State([])
|
217 |
-
|
218 |
-
|
219 |
|
220 |
-
for button in [config_button, close_config_modal]:
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
|
227 |
if tab_name == "ClimateQ&A":
|
228 |
print("chat cqa - message sent")
|
@@ -265,10 +303,13 @@ def event_handling(
|
|
265 |
component.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs, papers_html], [tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
|
266 |
|
267 |
# Search for papers
|
268 |
-
for component in [textbox, examples_hidden]:
|
269 |
component.submit(find_papers, [component, after, dropdown_external_sources], [papers_html, citations_network, papers_summary])
|
270 |
|
271 |
|
|
|
|
|
|
|
272 |
|
273 |
def main_ui():
|
274 |
# config_open = gr.State(True)
|
@@ -278,11 +319,14 @@ def main_ui():
|
|
278 |
with gr.Tabs():
|
279 |
cqa_components = cqa_tab(tab_name = "ClimateQ&A")
|
280 |
local_cqa_components = cqa_tab(tab_name = "Beta - POC Adapt'Action")
|
|
|
281 |
|
282 |
create_about_tab()
|
283 |
|
284 |
event_handling(cqa_components, config_components, tab_name = 'ClimateQ&A')
|
285 |
-
event_handling(local_cqa_components, config_components, tab_name =
|
|
|
|
|
286 |
|
287 |
demo.queue()
|
288 |
|
|
|
12 |
from climateqa.engine.graph import make_graph_agent,make_graph_agent_poc
|
13 |
from climateqa.engine.chains.retrieve_papers import find_papers
|
14 |
from climateqa.chat import start_chat, chat_stream, finish_chat
|
15 |
+
from climateqa.engine.talk_to_data.main import ask_vanna
|
16 |
|
17 |
from front.tabs import (create_config_modal, create_examples_tab, create_papers_tab, create_figures_tab, create_chat_interface, create_about_tab)
|
18 |
from front.utils import process_figures
|
19 |
+
from gradio_modal import Modal
|
20 |
|
21 |
|
22 |
from utils import create_user_id
|
|
|
69 |
vectorstore_region = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_LOCAL_V2"))
|
70 |
|
71 |
llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
|
72 |
+
if os.environ["GRADIO_ENV"] == "local":
|
73 |
reranker = get_reranker("nano")
|
74 |
+
else :
|
75 |
reranker = get_reranker("large")
|
76 |
|
77 |
agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0.2)
|
|
|
95 |
|
96 |
# Function to update modal visibility
|
97 |
def update_config_modal_visibility(config_open):
|
98 |
+
print(config_open)
|
99 |
new_config_visibility_status = not config_open
|
100 |
+
return Modal(visible=new_config_visibility_status), new_config_visibility_status
|
101 |
|
102 |
|
103 |
def update_sources_number_display(sources_textbox, figures_cards, current_graphs, papers_html):
|
|
|
113 |
|
114 |
return gr.update(label=recommended_content_notif_label), gr.update(label=sources_notif_label), gr.update(label=figures_notif_label), gr.update(label=graphs_notif_label), gr.update(label=papers_notif_label)
|
115 |
|
116 |
+
def create_drias_tab():
|
117 |
+
with gr.Tab("Beta - Talk to DRIAS", elem_id="tab-vanna", id=6) as tab_vanna:
|
118 |
+
vanna_direct_question = gr.Textbox(label="Direct Question", placeholder="You can write direct question here",elem_id="direct-question", interactive=True)
|
119 |
+
with gr.Accordion("Details",elem_id = 'vanna-details', open=False) as vanna_details :
|
120 |
+
vanna_sql_query = gr.Textbox(label="SQL Query Used", elem_id="sql-query", interactive=False)
|
121 |
+
show_vanna_table = gr.Button("Show Table", elem_id="show-table")
|
122 |
+
with Modal(visible=False) as vanna_table_modal:
|
123 |
+
vanna_table = gr.DataFrame([], elem_id="vanna-table")
|
124 |
+
close_vanna_modal = gr.Button("Close", elem_id="close-vanna-modal")
|
125 |
+
close_vanna_modal.click(lambda: Modal(visible=False),None, [vanna_table_modal])
|
126 |
+
show_vanna_table.click(lambda: Modal(visible=True),None ,[vanna_table_modal])
|
127 |
+
|
128 |
+
vanna_display = gr.Plot()
|
129 |
+
vanna_direct_question.submit(ask_vanna, [vanna_direct_question], [vanna_sql_query ,vanna_table, vanna_display])
|
130 |
+
|
131 |
# # UI Layout Components
|
132 |
def cqa_tab(tab_name):
|
133 |
# State variables
|
|
|
159 |
|
160 |
# Papers subtab
|
161 |
with gr.Tab("Papers", elem_id="tab-citations", id=4) as tab_papers:
|
162 |
+
papers_direct_search, papers_summary, papers_html, citations_network, papers_modal = create_papers_tab()
|
163 |
|
164 |
# Graphs subtab
|
165 |
with gr.Tab("Graphs", elem_id="tab-graphs", id=5) as tab_graphs:
|
|
|
167 |
"<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>",
|
168 |
elem_id="graphs-container"
|
169 |
)
|
170 |
+
|
171 |
+
|
172 |
return {
|
173 |
"chatbot": chatbot,
|
174 |
"textbox": textbox,
|
|
|
181 |
"figures_cards": figures_cards,
|
182 |
"gallery_component": gallery_component,
|
183 |
"config_button": config_button,
|
184 |
+
"papers_direct_search" : papers_direct_search,
|
185 |
"papers_html": papers_html,
|
186 |
"citations_network": citations_network,
|
187 |
"papers_summary": papers_summary,
|
|
|
190 |
"tab_figures": tab_figures,
|
191 |
"tab_graphs": tab_graphs,
|
192 |
"tab_papers": tab_papers,
|
193 |
+
"graph_container": graphs_container,
|
194 |
+
# "vanna_sql_query": vanna_sql_query,
|
195 |
+
# "vanna_table" : vanna_table,
|
196 |
+
# "vanna_display": vanna_display
|
197 |
}
|
198 |
|
199 |
+
def config_event_handling(main_tabs_components : list[dict], config_componenets : dict):
|
200 |
+
config_open = config_componenets["config_open"]
|
201 |
+
config_modal = config_componenets["config_modal"]
|
202 |
+
close_config_modal = config_componenets["close_config_modal_button"]
|
203 |
+
|
204 |
+
for button in [close_config_modal] + [main_tab_component["config_button"] for main_tab_component in main_tabs_components]:
|
205 |
+
button.click(
|
206 |
+
fn=update_config_modal_visibility,
|
207 |
+
inputs=[config_open],
|
208 |
+
outputs=[config_modal, config_open]
|
209 |
+
)
|
210 |
|
211 |
def event_handling(
|
212 |
main_tab_components,
|
|
|
223 |
sources_textbox = main_tab_components["sources_textbox"]
|
224 |
figures_cards = main_tab_components["figures_cards"]
|
225 |
gallery_component = main_tab_components["gallery_component"]
|
226 |
+
# config_button = main_tab_components["config_button"]
|
227 |
+
papers_direct_search = main_tab_components["papers_direct_search"]
|
228 |
papers_html = main_tab_components["papers_html"]
|
229 |
citations_network = main_tab_components["citations_network"]
|
230 |
papers_summary = main_tab_components["papers_summary"]
|
|
|
234 |
tab_graphs = main_tab_components["tab_graphs"]
|
235 |
tab_papers = main_tab_components["tab_papers"]
|
236 |
graphs_container = main_tab_components["graph_container"]
|
237 |
+
# vanna_sql_query = main_tab_components["vanna_sql_query"]
|
238 |
+
# vanna_table = main_tab_components["vanna_table"]
|
239 |
+
# vanna_display = main_tab_components["vanna_display"]
|
240 |
|
241 |
+
|
242 |
+
# config_open = config_components["config_open"]
|
243 |
+
# config_modal = config_components["config_modal"]
|
244 |
dropdown_sources = config_components["dropdown_sources"]
|
245 |
dropdown_reports = config_components["dropdown_reports"]
|
246 |
dropdown_external_sources = config_components["dropdown_external_sources"]
|
|
|
249 |
after = config_components["after"]
|
250 |
output_query = config_components["output_query"]
|
251 |
output_language = config_components["output_language"]
|
252 |
+
# close_config_modal = config_components["close_config_modal_button"]
|
253 |
|
254 |
new_sources_hmtl = gr.State([])
|
255 |
+
ttd_data = gr.State([])
|
256 |
+
|
257 |
|
258 |
+
# for button in [config_button, close_config_modal]:
|
259 |
+
# button.click(
|
260 |
+
# fn=update_config_modal_visibility,
|
261 |
+
# inputs=[config_open],
|
262 |
+
# outputs=[config_modal, config_open]
|
263 |
+
# )
|
264 |
|
265 |
if tab_name == "ClimateQ&A":
|
266 |
print("chat cqa - message sent")
|
|
|
303 |
component.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs, papers_html], [tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
|
304 |
|
305 |
# Search for papers
|
306 |
+
for component in [textbox, examples_hidden, papers_direct_search]:
|
307 |
component.submit(find_papers, [component, after, dropdown_external_sources], [papers_html, citations_network, papers_summary])
|
308 |
|
309 |
|
310 |
+
# if tab_name == "Beta - POC Adapt'Action": # Not untill results are good enough
|
311 |
+
# # Drias search
|
312 |
+
# textbox.submit(ask_vanna, [textbox], [vanna_sql_query ,vanna_table, vanna_display])
|
313 |
|
314 |
def main_ui():
|
315 |
# config_open = gr.State(True)
|
|
|
319 |
with gr.Tabs():
|
320 |
cqa_components = cqa_tab(tab_name = "ClimateQ&A")
|
321 |
local_cqa_components = cqa_tab(tab_name = "Beta - POC Adapt'Action")
|
322 |
+
create_drias_tab()
|
323 |
|
324 |
create_about_tab()
|
325 |
|
326 |
event_handling(cqa_components, config_components, tab_name = 'ClimateQ&A')
|
327 |
+
event_handling(local_cqa_components, config_components, tab_name = "Beta - POC Adapt'Action")
|
328 |
+
|
329 |
+
config_event_handling([cqa_components,local_cqa_components] ,config_components)
|
330 |
|
331 |
demo.queue()
|
332 |
|
climateqa/chat.py
CHANGED
@@ -53,6 +53,13 @@ def log_interaction_to_azure(history, output_query, sources, docs, share_client,
|
|
53 |
print(f"Error logging on Azure Blob Storage: {e}")
|
54 |
error_msg = f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)"
|
55 |
raise gr.Error(error_msg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
# Main chat function
|
58 |
async def chat_stream(
|
@@ -121,6 +128,7 @@ async def chat_stream(
|
|
121 |
used_documents = []
|
122 |
retrieved_contents = []
|
123 |
answer_message_content = ""
|
|
|
124 |
|
125 |
# Define processing steps
|
126 |
steps_display = {
|
@@ -142,6 +150,14 @@ async def chat_stream(
|
|
142 |
history, used_documents, retrieved_contents = handle_retrieved_documents(
|
143 |
event, history, used_documents, retrieved_contents
|
144 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
if event["event"] == "on_chain_end" and event["name"] == "answer_search" :
|
146 |
docs = event["data"]["input"]["documents"]
|
147 |
docs_html = convert_to_docs_to_html(docs)
|
@@ -184,7 +200,7 @@ async def chat_stream(
|
|
184 |
sub_questions = [q["question"] + "-> relevant sources : " + str(q["sources"]) for q in event["data"]["output"]["questions_list"]]
|
185 |
history[-1].content += "Decompose question into sub-questions:\n\n - " + "\n - ".join(sub_questions)
|
186 |
|
187 |
-
yield history, docs_html, output_query, output_language, related_contents, graphs_html
|
188 |
|
189 |
except Exception as e:
|
190 |
print(f"Event {event} has failed")
|
@@ -195,4 +211,4 @@ async def chat_stream(
|
|
195 |
# Call the function to log interaction
|
196 |
log_interaction_to_azure(history, output_query, sources, docs, share_client, user_id)
|
197 |
|
198 |
-
yield history, docs_html, output_query, output_language, related_contents, graphs_html
|
|
|
53 |
print(f"Error logging on Azure Blob Storage: {e}")
|
54 |
error_msg = f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)"
|
55 |
raise gr.Error(error_msg)
|
56 |
+
|
57 |
+
def handle_numerical_data(event):
    """Extract the DRIAS dataframe and SQL query from a graph-stream event.

    Returns:
        tuple: (drias_data, drias_sql_query) when *event* marks the end of the
        ``retrieve_drias_data`` chain, otherwise (None, None).
    """
    is_drias_end = (
        event["name"] == "retrieve_drias_data"
        and event["event"] == "on_chain_end"
    )
    if not is_drias_end:
        return None, None
    output = event["data"]["output"]
    return output["drias_data"], output["drias_sql_query"]
|
63 |
|
64 |
# Main chat function
|
65 |
async def chat_stream(
|
|
|
128 |
used_documents = []
|
129 |
retrieved_contents = []
|
130 |
answer_message_content = ""
|
131 |
+
vanna_data = {}
|
132 |
|
133 |
# Define processing steps
|
134 |
steps_display = {
|
|
|
150 |
history, used_documents, retrieved_contents = handle_retrieved_documents(
|
151 |
event, history, used_documents, retrieved_contents
|
152 |
)
|
153 |
+
# Handle Vanna retrieval
|
154 |
+
# if event["event"] == "on_chain_end" and event["name"] in ["retrieve_documents","retrieve_local_data"] and event["data"]["output"] != None:
|
155 |
+
# df_output_vanna, sql_query = handle_numerical_data(
|
156 |
+
# event
|
157 |
+
# )
|
158 |
+
# vanna_data = {"df_output": df_output_vanna, "sql_query": sql_query}
|
159 |
+
|
160 |
+
|
161 |
if event["event"] == "on_chain_end" and event["name"] == "answer_search" :
|
162 |
docs = event["data"]["input"]["documents"]
|
163 |
docs_html = convert_to_docs_to_html(docs)
|
|
|
200 |
sub_questions = [q["question"] + "-> relevant sources : " + str(q["sources"]) for q in event["data"]["output"]["questions_list"]]
|
201 |
history[-1].content += "Decompose question into sub-questions:\n\n - " + "\n - ".join(sub_questions)
|
202 |
|
203 |
+
yield history, docs_html, output_query, output_language, related_contents, graphs_html#, vanna_data
|
204 |
|
205 |
except Exception as e:
|
206 |
print(f"Event {event} has failed")
|
|
|
211 |
# Call the function to log interaction
|
212 |
log_interaction_to_azure(history, output_query, sources, docs, share_client, user_id)
|
213 |
|
214 |
+
yield history, docs_html, output_query, output_language, related_contents, graphs_html#, vanna_data
|
climateqa/engine/talk_to_data/main.py
CHANGED
@@ -19,7 +19,7 @@ VANNA_MODEL = os.getenv('VANNA_MODEL')
|
|
19 |
|
20 |
#Vanna object
|
21 |
vn = MyVanna(config = {"temperature": 0, "api_key": OPENAI_API_KEY, 'model': VANNA_MODEL, 'pc_api_key': PC_API_KEY, 'index_name': INDEX_NAME, "top_k" : 4})
|
22 |
-
db_vanna_path = os.path.join(os.path.dirname(__file__), "
|
23 |
vn.connect_to_sqlite(db_vanna_path)
|
24 |
|
25 |
llm = get_llm(provider="openai")
|
|
|
19 |
|
20 |
#Vanna object
|
21 |
vn = MyVanna(config = {"temperature": 0, "api_key": OPENAI_API_KEY, 'model': VANNA_MODEL, 'pc_api_key': PC_API_KEY, 'index_name': INDEX_NAME, "top_k" : 4})
|
22 |
+
db_vanna_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), "data/drias/drias.db")
|
23 |
vn.connect_to_sqlite(db_vanna_path)
|
24 |
|
25 |
llm = get_llm(provider="openai")
|
climateqa/engine/talk_to_data/myVanna.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dotenv import load_dotenv
|
2 |
+
from climateqa.engine.talk_to_data.vanna_class import MyCustomVectorDB
|
3 |
+
from vanna.openai import OpenAI_Chat
|
4 |
+
import os
|
5 |
+
|
6 |
+
load_dotenv()
|
7 |
+
|
8 |
+
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
|
9 |
+
|
10 |
+
class MyVanna(MyCustomVectorDB, OpenAI_Chat):
    """Vanna text-to-SQL agent: Pinecone vector store + OpenAI chat.

    Combines the custom Pinecone-backed vector DB (training-data storage and
    retrieval) with OpenAI chat for SQL generation. The same *config* dict
    feeds both bases (api keys, model, index name, top_k).
    """

    def __init__(self, config=None):
        # Both bases are initialised explicitly with the same config — the
        # documented Vanna multiple-inheritance pattern (not cooperative super()).
        MyCustomVectorDB.__init__(self, config=config)
        OpenAI_Chat.__init__(self, config=config)
|
climateqa/engine/talk_to_data/utils.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import openai
|
3 |
+
import pandas as pd
|
4 |
+
from geopy.geocoders import Nominatim
|
5 |
+
import sqlite3
|
6 |
+
import ast
|
7 |
+
|
8 |
+
|
9 |
+
def detect_location_with_openai(api_key, sentence):
    """Detect locations in a sentence using OpenAI's chat completions API.

    Args:
        api_key: OpenAI API key (set globally on the ``openai`` module).
        sentence: Free-text user question that may mention places.

    Returns:
        str: A slice of the model's answer expected to contain the location
        text. NOTE(review): this is NOT a parsed list — see the comment on the
        return line; confirm against real model responses.
    """
    # NOTE(review): mutates module-level state; concurrent calls using
    # different keys would race on openai.api_key.
    openai.api_key = api_key

    prompt = f"""
    Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
    Return the result as a Python list. If no locations are mentioned, return an empty list.

    Sentence: "{sentence}"
    """

    response = openai.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": "You are a helpful assistant skilled in identifying locations in text."},
            {"role": "user", "content": prompt}
        ],
        max_tokens=100,
        temperature=0
    )

    # Fragile parsing: takes the SECOND line of the completion and strips two
    # characters from each end (assumed shape: ['Paris']). It breaks when the
    # model answers on a single line or lists several locations — TODO confirm
    # the response format and replace with ast.literal_eval on the list.
    return response.choices[0].message.content.split("\n")[1][2:-2]
|
33 |
+
|
34 |
+
|
35 |
+
def detectTable(sql_query):
    """Return every table reference following a FROM clause in *sql_query*.

    Accepts back-tick, double-quote and single-quote quoted identifiers as
    well as bare names, including dotted ``schema.table`` chains.
    """
    from_target = r'(?i)\bFROM\s+((?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)(?:\.(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+))*)'
    return re.findall(from_target, sql_query)
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
def loc2coords(location : str):
    """Geocode a place name to (latitude, longitude) via Nominatim.

    Args:
        location: Free-text place name (e.g. "Toulouse").

    Returns:
        tuple: (latitude, longitude) floats.

    Raises:
        ValueError: If the geocoder finds no match. (Previously this crashed
            with an opaque ``AttributeError`` on ``None``; it also shadowed
            the *location* parameter with the geocoder result.)
    """
    geolocator = Nominatim(user_agent="city_to_latlong")
    match = geolocator.geocode(location)
    if match is None:
        raise ValueError(f"Could not geocode location: {location!r}")
    return (match.latitude, match.longitude)
|
46 |
+
|
47 |
+
|
48 |
+
def coords2loc(coords : tuple):
    """Reverse-geocode a (latitude, longitude) pair to a readable address.

    Logs and falls back to "Unknown Location" when the lookup fails for any
    reason (network error, no match, malformed input).
    """
    geolocator = Nominatim(user_agent="coords_to_city")
    try:
        return geolocator.reverse(coords).address
    except Exception as e:
        print(f"Error: {e}")
        return "Unknown Location"
|
56 |
+
|
57 |
+
|
58 |
+
def nearestNeighbourSQL(db: str, location: tuple, table : str):
    """Return the (lat, lon) grid point of *table* closest to *location*.

    Searches a +/-0.3 degree window around the (rounded) target and orders
    candidates by squared euclidean distance, so the nearest point in the
    window is returned — the previous version returned an arbitrary row,
    leaked the connection, and raised IndexError on an empty window.

    Args:
        db: Path to the SQLite database file.
        location: (latitude, longitude) tuple.
        table: Table name holding ``lat``/``lon`` columns.

    Returns:
        tuple: (lat, lon) of the nearest grid point.

    Raises:
        ValueError: If *table* is not a plain identifier, or no grid point
            falls inside the search window.
    """
    # Table names cannot be bound as SQL parameters; whitelist the identifier
    # instead of interpolating arbitrary text into the statement.
    if not re.fullmatch(r"\w+", table):
        raise ValueError(f"Invalid table name: {table!r}")
    lat = round(location[0], 3)
    long = round(location[1], 3)
    conn = sqlite3.connect(db)
    try:
        cursor = conn.cursor()
        cursor.execute(
            f"SELECT lat, lon FROM {table} "
            "WHERE lat BETWEEN ? AND ? AND lon BETWEEN ? AND ? "
            "ORDER BY (lat - ?) * (lat - ?) + (lon - ?) * (lon - ?) LIMIT 1",
            (lat - 0.3, lat + 0.3, long - 0.3, long + 0.3, lat, lat, long, long),
        )
        row = cursor.fetchone()
    finally:
        conn.close()  # previously never closed — connection leak per call
    if row is None:
        raise ValueError(f"No {table} grid point within 0.3 degrees of {location}")
    return row
|
66 |
+
|
67 |
+
def detect_relevant_tables(user_question, llm, table_names_list=None):
    """Ask the LLM which DRIAS tables are relevant to *user_question*.

    Args:
        user_question: The user's natural-language question.
        llm: LangChain-style chat model; ``llm.invoke(prompt).content`` must
            return text containing a Python list literal.
        table_names_list: Candidate table names; defaults to the DRIAS set.

    Returns:
        list: Table names chosen by the model.

    Raises:
        ValueError: If no list literal can be found in the model's answer.
    """
    if table_names_list is None:
        table_names_list = [
            "Frequency_of_rainy_days_index",
            "Winter_precipitation_total",
            "Summer_precipitation_total",
            "Annual_precipitation_total",
            # "Remarkable_daily_precipitation_total_(Q99)",
            "Frequency_of_remarkable_daily_precipitation",
            "Extreme_precipitation_intensity",
            "Mean_winter_temperature",
            "Mean_summer_temperature",
            "Number_of_tropical_nights",
            "Maximum_summer_temperature",
            "Number_of_days_with_Tx_above_30C",
            "Number_of_days_with_Tx_above_35C",
            "Drought_index",
        ]
    prompt = (
        f"You are helping to build a sql query to retrieve relevant data for a user question."
        f"The different tables are {table_names_list}."
        f"The user question is {user_question}. Write the relevant tables to use. Answer only a python list of table name."
    )
    answer = llm.invoke(prompt).content
    # The model may wrap its answer in a markdown fence. The previous
    # str.strip("```python\n") removed *characters* (any of ` p y t h o n \n)
    # from both ends, not a prefix — extract the first list literal instead.
    match = re.search(r"\[.*\]", answer, re.DOTALL)
    if match is None:
        raise ValueError(f"Could not find a table list in the LLM answer: {answer!r}")
    return ast.literal_eval(match.group(0))
|
91 |
+
|
92 |
+
def replace_coordonates(coords, query, coords_tables):
    """Rewrite user coordinates in *query* with dataset grid coordinates.

    Each textual occurrence of ``coords[0]`` (and its paired ``coords[1]``)
    is replaced, left to right, by the matching pair from *coords_tables*:
    the i-th occurrence gets the i-th pair.
    """
    lat_txt = str(coords[0])
    lon_txt = str(coords[1])
    occurrences = query.count(lat_txt)

    rewritten = query
    for i in range(occurrences):
        pair = coords_tables[i]
        rewritten = rewritten.replace(lat_txt, str(pair[0]), 1)
        rewritten = rewritten.replace(lon_txt, str(pair[1]), 1)
    return rewritten
|
climateqa/engine/talk_to_data/vanna_class.py
ADDED
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from vanna.base import VannaBase
|
2 |
+
from pinecone import Pinecone
|
3 |
+
from climateqa.engine.embeddings import get_embeddings_function
|
4 |
+
import pandas as pd
|
5 |
+
import hashlib
|
6 |
+
|
7 |
+
class MyCustomVectorDB(VannaBase):
|
8 |
+
|
9 |
+
"""
|
10 |
+
VectorDB class for storing and retrieving vectors from Pinecone.
|
11 |
+
|
12 |
+
args :
|
13 |
+
config (dict) : Configuration dictionary containing the Pinecone API key and the index name :
|
14 |
+
- pc_api_key (str) : Pinecone API key
|
15 |
+
- index_name (str) : Pinecone index name
|
16 |
+
- top_k (int) : Number of top results to return (default = 2)
|
17 |
+
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(self, config):
    """Initialise the Pinecone-backed vector store.

    Args:
        config: Dict with 'pc_api_key' (Pinecone API key), 'index_name'
            (Pinecone index name) and optional 'top_k' (default 2).

    Raises:
        Exception: If the API key or index name is missing. (The previous
            try/except around ``dict.get`` could never fire for missing keys —
            ``get`` returns None instead of raising — so bad configs reached
            Pinecone with ``None`` credentials.)
    """
    super().__init__(config=config)
    self.api_key = config.get('pc_api_key') if config else None
    self.index_name = config.get('index_name') if config else None
    if not self.api_key or not self.index_name:
        raise Exception("Please provide the Pinecone API key and the index name")

    self.pc = Pinecone(api_key=self.api_key)
    self.index = self.pc.Index(self.index_name)
    # Number of matches returned by the get_related_* / get_similar_* queries.
    self.top_k = config.get('top_k', 2)
    self.embeddings = get_embeddings_function()
|
32 |
+
|
33 |
+
|
34 |
+
def check_embedding(self, id, namespace):
    """Return True when a vector with *id* already exists in *namespace*."""
    response = self.index.fetch(ids=[id], namespace=namespace)
    return response['vectors'] != {}
|
39 |
+
|
40 |
+
def generate_hash_id(self, data: str) -> str:
    """Return a deterministic hash used as a vector ID.

    Args:
        data: The input data to hash (e.g., a DDL statement or question text).

    Returns:
        str: 64-character hexadecimal SHA-256 digest of *data*.
    """
    return hashlib.sha256(data.encode('utf-8')).hexdigest()
|
56 |
+
|
57 |
+
def add_ddl(self, ddl: str, **kwargs) -> str:
    """Store a DDL statement in the 'ddl' namespace; return its vector id.

    Idempotent: when the id already exists the upsert is skipped (and logged).
    """
    vector_id = self.generate_hash_id(ddl) + '_ddl'
    if self.check_embedding(vector_id, 'ddl'):
        print(f"DDL having id {vector_id} already exists")
        return vector_id

    embedding = self.embeddings.embed_query(ddl)
    self.index.upsert(
        vectors=[(vector_id, embedding, {'ddl': ddl})],
        namespace='ddl',
    )
    return vector_id
|
70 |
+
|
71 |
+
def add_documentation(self, doc: str, **kwargs) -> str:
    """Store a documentation snippet in the 'documentation' namespace.

    Idempotent: when the id already exists the upsert is skipped (and logged).
    Returns the vector id.
    """
    vector_id = self.generate_hash_id(doc) + '_doc'
    if self.check_embedding(vector_id, 'documentation'):
        print(f"Documentation having id {vector_id} already exists")
        return vector_id

    embedding = self.embeddings.embed_query(doc)
    self.index.upsert(
        vectors=[(vector_id, embedding, {'doc': doc})],
        namespace='documentation',
    )
    return vector_id
|
84 |
+
|
85 |
+
def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
    """Store a question/SQL training pair in the 'question_sql' namespace.

    Idempotent on the question's hash; returns the vector id.
    """
    vector_id = self.generate_hash_id(question) + '_sql'
    if self.check_embedding(vector_id, 'question_sql'):
        print(f"Question-SQL pair having id {vector_id} already exists")
        return vector_id

    # Question and SQL are embedded together so retrieval can match either.
    embedding = self.embeddings.embed_query(question + sql)
    self.index.upsert(
        vectors=[(vector_id, embedding, {'question': question, 'sql': sql})],
        namespace='question_sql',
    )
    return vector_id
|
98 |
+
|
99 |
+
def get_related_ddl(self, question: str, **kwargs) -> list:
    """Return the top_k DDL statements most similar to *question*."""
    query_vector = self.embeddings.embed_query(question)
    response = self.index.query(
        vector=query_vector,
        top_k=self.top_k,
        namespace='ddl',
        include_metadata=True
    )
    return [hit['metadata']['ddl'] for hit in response['matches']]
|
108 |
+
|
109 |
+
def get_related_documentation(self, question: str, **kwargs) -> list:
    """Return the top_k documentation snippets most similar to *question*."""
    query_vector = self.embeddings.embed_query(question)
    response = self.index.query(
        vector=query_vector,
        top_k=self.top_k,
        namespace='documentation',
        include_metadata=True
    )
    return [hit['metadata']['doc'] for hit in response['matches']]
|
118 |
+
|
119 |
+
def get_similar_question_sql(self, question: str, **kwargs) -> list:
    """Return top_k (question, sql) training pairs most similar to *question*."""
    query_vector = self.embeddings.embed_query(question)
    response = self.index.query(
        vector=query_vector,
        top_k=self.top_k,
        namespace='question_sql',
        include_metadata=True
    )
    return [
        (hit['metadata']['question'], hit['metadata']['sql'])
        for hit in response['matches']
    ]
|
128 |
+
|
129 |
+
def get_training_data(self, **kwargs) -> pd.DataFrame:
    """Collect all stored training metadata across namespaces as a DataFrame.

    NOTE(review): queries each namespace without a query vector and with
    top_k=10000 — assumes the Pinecone client accepts vector-less queries and
    no namespace exceeds 10k entries; confirm against the client version.
    """
    rows = []
    for namespace in ('ddl', 'documentation', 'question_sql'):
        page = self.index.query(
            top_k=10000,
            namespace=namespace,
            include_metadata=True,
            include_values=False
        )
        rows.extend(hit['metadata'] for hit in page['matches'])
    return pd.DataFrame(rows)
|
148 |
+
|
149 |
+
|
150 |
+
|
151 |
+
def remove_training_data(self, id: str, **kwargs) -> bool:
    """Delete a training entry by its suffixed id; return True when handled.

    Fixes two defects in the previous version: the ddl/doc branches called
    ``self.Index`` (capital I — AttributeError at runtime), and all three
    branches deleted from "_ddl"/"_sql"/"_doc" namespaces that are never
    written — the add_* methods upsert into 'ddl', 'question_sql' and
    'documentation'.
    """
    if id.endswith("_ddl"):
        self.index.delete(ids=[id], namespace="ddl")
        return True
    if id.endswith("_sql"):
        self.index.delete(ids=[id], namespace="question_sql")
        return True
    if id.endswith("_doc"):
        self.index.delete(ids=[id], namespace="documentation")
        return True
    # Unknown suffix: nothing to delete.
    return False
|
164 |
+
|
165 |
+
def generate_embedding(self, text, **kwargs):
    # Unimplemented stub (returns None). Embeddings are produced through
    # self.embeddings.embed_query in the add_*/get_* methods instead, so this
    # hook is not exercised on the code paths visible here — presumably a
    # VannaBase abstract method that must exist; confirm before relying on it.
    pass
|
168 |
+
|
169 |
+
|
170 |
+
def get_sql_prompt(
|
171 |
+
self,
|
172 |
+
initial_prompt : str,
|
173 |
+
question: str,
|
174 |
+
question_sql_list: list,
|
175 |
+
ddl_list: list,
|
176 |
+
doc_list: list,
|
177 |
+
**kwargs,
|
178 |
+
):
|
179 |
+
"""
|
180 |
+
Example:
|
181 |
+
```python
|
182 |
+
vn.get_sql_prompt(
|
183 |
+
question="What are the top 10 customers by sales?",
|
184 |
+
question_sql_list=[{"question": "What are the top 10 customers by sales?", "sql": "SELECT * FROM customers ORDER BY sales DESC LIMIT 10"}],
|
185 |
+
ddl_list=["CREATE TABLE customers (id INT, name TEXT, sales DECIMAL)"],
|
186 |
+
doc_list=["The customers table contains information about customers and their sales."],
|
187 |
+
)
|
188 |
+
|
189 |
+
```
|
190 |
+
|
191 |
+
This method is used to generate a prompt for the LLM to generate SQL.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
question (str): The question to generate SQL for.
|
195 |
+
question_sql_list (list): A list of questions and their corresponding SQL statements.
|
196 |
+
ddl_list (list): A list of DDL statements.
|
197 |
+
doc_list (list): A list of documentation.
|
198 |
+
|
199 |
+
Returns:
|
200 |
+
any: The prompt for the LLM to generate SQL.
|
201 |
+
"""
|
202 |
+
|
203 |
+
if initial_prompt is None:
|
204 |
+
initial_prompt = f"You are a {self.dialect} expert. " + \
|
205 |
+
"Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
|
206 |
+
|
207 |
+
initial_prompt = self.add_ddl_to_prompt(
|
208 |
+
initial_prompt, ddl_list, max_tokens=self.max_tokens
|
209 |
+
)
|
210 |
+
|
211 |
+
if self.static_documentation != "":
|
212 |
+
doc_list.append(self.static_documentation)
|
213 |
+
|
214 |
+
initial_prompt = self.add_documentation_to_prompt(
|
215 |
+
initial_prompt, doc_list, max_tokens=self.max_tokens
|
216 |
+
)
|
217 |
+
|
218 |
+
# initial_prompt = self.add_sql_to_prompt(
|
219 |
+
# initial_prompt, question_sql_list, max_tokens=self.max_tokens
|
220 |
+
# )
|
221 |
+
|
222 |
+
|
223 |
+
initial_prompt += (
|
224 |
+
"===Response Guidelines \n"
|
225 |
+
"1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
|
226 |
+
"2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
|
227 |
+
"3. If the provided context is insufficient, please give a sql query based on your knowledge and the context provided. \n"
|
228 |
+
"4. Please use the most relevant table(s). \n"
|
229 |
+
"5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
|
230 |
+
f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
|
231 |
+
f"7. Add a description of the table in the result of the sql query, if relevant. \n"
|
232 |
+
"8 Make sure to include the relevant KPI in the SQL query. The query should return impactfull data \n"
|
233 |
+
# f"8. If a set of latitude,longitude is provided, make a intermediate query to find the nearest value in the table and replace the coordinates in the sql query. \n"
|
234 |
+
# "7. Add a description of the table in the result of the sql query."
|
235 |
+
# "7. If the question is about a specific latitude, longitude, query an interval of 0.3 and keep only the first set of coordinate. \n"
|
236 |
+
# "7. Table names should be included in the result of the sql query. Use for example Mean_winter_temperature AS table_name in the query \n"
|
237 |
+
)
|
238 |
+
|
239 |
+
|
240 |
+
message_log = [self.system_message(initial_prompt)]
|
241 |
+
|
242 |
+
for example in question_sql_list:
|
243 |
+
if example is None:
|
244 |
+
print("example is None")
|
245 |
+
else:
|
246 |
+
if example is not None and "question" in example and "sql" in example:
|
247 |
+
message_log.append(self.user_message(example["question"]))
|
248 |
+
message_log.append(self.assistant_message(example["sql"]))
|
249 |
+
|
250 |
+
message_log.append(self.user_message(question))
|
251 |
+
|
252 |
+
return message_log
|
253 |
+
|
254 |
+
|
255 |
+
# def get_sql_prompt(
|
256 |
+
# self,
|
257 |
+
# initial_prompt : str,
|
258 |
+
# question: str,
|
259 |
+
# question_sql_list: list,
|
260 |
+
# ddl_list: list,
|
261 |
+
# doc_list: list,
|
262 |
+
# **kwargs,
|
263 |
+
# ):
|
264 |
+
# """
|
265 |
+
# Example:
|
266 |
+
# ```python
|
267 |
+
# vn.get_sql_prompt(
|
268 |
+
# question="What are the top 10 customers by sales?",
|
269 |
+
# question_sql_list=[{"question": "What are the top 10 customers by sales?", "sql": "SELECT * FROM customers ORDER BY sales DESC LIMIT 10"}],
|
270 |
+
# ddl_list=["CREATE TABLE customers (id INT, name TEXT, sales DECIMAL)"],
|
271 |
+
# doc_list=["The customers table contains information about customers and their sales."],
|
272 |
+
# )
|
273 |
+
|
274 |
+
# ```
|
275 |
+
|
276 |
+
# This method is used to generate a prompt for the LLM to generate SQL.
|
277 |
+
|
278 |
+
# Args:
|
279 |
+
# question (str): The question to generate SQL for.
|
280 |
+
# question_sql_list (list): A list of questions and their corresponding SQL statements.
|
281 |
+
# ddl_list (list): A list of DDL statements.
|
282 |
+
# doc_list (list): A list of documentation.
|
283 |
+
|
284 |
+
# Returns:
|
285 |
+
# any: The prompt for the LLM to generate SQL.
|
286 |
+
# """
|
287 |
+
|
288 |
+
# if initial_prompt is None:
|
289 |
+
# initial_prompt = f"You are a {self.dialect} expert. " + \
|
290 |
+
# "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
|
291 |
+
|
292 |
+
# initial_prompt = self.add_ddl_to_prompt(
|
293 |
+
# initial_prompt, ddl_list, max_tokens=self.max_tokens
|
294 |
+
# )
|
295 |
+
|
296 |
+
# if self.static_documentation != "":
|
297 |
+
# doc_list.append(self.static_documentation)
|
298 |
+
|
299 |
+
# initial_prompt = self.add_documentation_to_prompt(
|
300 |
+
# initial_prompt, doc_list, max_tokens=self.max_tokens
|
301 |
+
# )
|
302 |
+
|
303 |
+
# initial_prompt += (
|
304 |
+
# "===Response Guidelines \n"
|
305 |
+
# "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
|
306 |
+
# "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
|
307 |
+
# "3. If the provided context is insufficient, please explain why it can't be generated. \n"
|
308 |
+
# "4. Please use the most relevant table(s). \n"
|
309 |
+
# "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
|
310 |
+
# f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
|
311 |
+
# )
|
312 |
+
|
313 |
+
# message_log = [self.system_message(initial_prompt)]
|
314 |
+
|
315 |
+
# for example in question_sql_list:
|
316 |
+
# if example is None:
|
317 |
+
# print("example is None")
|
318 |
+
# else:
|
319 |
+
# if example is not None and "question" in example and "sql" in example:
|
320 |
+
# message_log.append(self.user_message(example["question"]))
|
321 |
+
# message_log.append(self.assistant_message(example["sql"]))
|
322 |
+
|
323 |
+
# message_log.append(self.user_message(question))
|
324 |
+
|
325 |
+
# return message_log
|
front/tabs/tab_papers.py
CHANGED
@@ -3,6 +3,8 @@ from gradio_modal import Modal
|
|
3 |
|
4 |
|
5 |
def create_papers_tab():
|
|
|
|
|
6 |
with gr.Accordion(
|
7 |
visible=True,
|
8 |
elem_id="papers-summary-popup",
|
@@ -32,5 +34,5 @@ def create_papers_tab():
|
|
32 |
papers_modal
|
33 |
)
|
34 |
|
35 |
-
return papers_summary, papers_html, citations_network, papers_modal
|
36 |
|
|
|
3 |
|
4 |
|
5 |
def create_papers_tab():
|
6 |
+
direct_search_textbox = gr.Textbox(label="Direct search for papers", placeholder= "What is climate change ?", elem_id="papers-search")
|
7 |
+
|
8 |
with gr.Accordion(
|
9 |
visible=True,
|
10 |
elem_id="papers-summary-popup",
|
|
|
34 |
papers_modal
|
35 |
)
|
36 |
|
37 |
+
return direct_search_textbox, papers_summary, papers_html, citations_network, papers_modal
|
38 |
|
requirements.txt
CHANGED
@@ -19,3 +19,5 @@ langchain-community==0.2
|
|
19 |
msal==1.31
|
20 |
matplotlib==3.9.2
|
21 |
gradio-modal==0.0.4
|
|
|
|
|
|
19 |
msal==1.31
|
20 |
matplotlib==3.9.2
|
21 |
gradio-modal==0.0.4
|
22 |
+
vanna==0.7.5
|
23 |
+
geopy==2.4.1
|
sandbox/talk_to_data/20250306 - CQA - Drias.ipynb
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"## Import the function in main.py"
|
8 |
+
]
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"cell_type": "code",
|
12 |
+
"execution_count": null,
|
13 |
+
"metadata": {},
|
14 |
+
"outputs": [],
|
15 |
+
"source": [
|
16 |
+
"import sys\n",
|
17 |
+
"import os\n",
|
18 |
+
"sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))\n",
|
19 |
+
"\n",
|
20 |
+
"%load_ext autoreload\n",
|
21 |
+
"%autoreload 2\n",
|
22 |
+
"\n",
|
23 |
+
"from climateqa.engine.talk_to_data.main import ask_vanna\n"
|
24 |
+
]
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"cell_type": "markdown",
|
28 |
+
"metadata": {},
|
29 |
+
"source": [
|
30 |
+
"## Create a human query"
|
31 |
+
]
|
32 |
+
},
|
33 |
+
{
|
34 |
+
"cell_type": "code",
|
35 |
+
"execution_count": null,
|
36 |
+
"metadata": {},
|
37 |
+
"outputs": [],
|
38 |
+
"source": [
|
39 |
+
"query = \"Comment vont évoluer les températures à marseille ?\""
|
40 |
+
]
|
41 |
+
},
|
42 |
+
{
|
43 |
+
"cell_type": "markdown",
|
44 |
+
"metadata": {},
|
45 |
+
"source": [
|
46 |
+
"## Call the function ask vanna, it gives an output of a the sql query and the dataframe of the result (tuple)"
|
47 |
+
]
|
48 |
+
},
|
49 |
+
{
|
50 |
+
"cell_type": "code",
|
51 |
+
"execution_count": null,
|
52 |
+
"metadata": {},
|
53 |
+
"outputs": [],
|
54 |
+
"source": [
|
55 |
+
"sql_query, df, fig = ask_vanna(query)\n",
|
56 |
+
"print(df.head())\n",
|
57 |
+
"fig.show()"
|
58 |
+
]
|
59 |
+
}
|
60 |
+
],
|
61 |
+
"metadata": {
|
62 |
+
"kernelspec": {
|
63 |
+
"display_name": "climateqa",
|
64 |
+
"language": "python",
|
65 |
+
"name": "python3"
|
66 |
+
},
|
67 |
+
"language_info": {
|
68 |
+
"codemirror_mode": {
|
69 |
+
"name": "ipython",
|
70 |
+
"version": 3
|
71 |
+
},
|
72 |
+
"file_extension": ".py",
|
73 |
+
"mimetype": "text/x-python",
|
74 |
+
"name": "python",
|
75 |
+
"nbconvert_exporter": "python",
|
76 |
+
"pygments_lexer": "ipython3",
|
77 |
+
"version": "3.11.9"
|
78 |
+
}
|
79 |
+
},
|
80 |
+
"nbformat": 4,
|
81 |
+
"nbformat_minor": 2
|
82 |
+
}
|
sandbox/talk_to_data/20250306 - CQA - Step_by_step_vanna.ipynb
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import sys\n",
|
10 |
+
"import os\n",
|
11 |
+
"sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))\n",
|
12 |
+
"\n",
|
13 |
+
"%load_ext autoreload\n",
|
14 |
+
"%autoreload 2\n",
|
15 |
+
"\n",
|
16 |
+
"from climateqa.engine.talk_to_data.main import ask_vanna\n",
|
17 |
+
"\n",
|
18 |
+
"import sqlite3\n",
|
19 |
+
"import os\n",
|
20 |
+
"import pandas as pd"
|
21 |
+
]
|
22 |
+
},
|
23 |
+
{
|
24 |
+
"cell_type": "markdown",
|
25 |
+
"metadata": {},
|
26 |
+
"source": [
|
27 |
+
"# Imports"
|
28 |
+
]
|
29 |
+
},
|
30 |
+
{
|
31 |
+
"cell_type": "code",
|
32 |
+
"execution_count": null,
|
33 |
+
"metadata": {},
|
34 |
+
"outputs": [],
|
35 |
+
"source": [
|
36 |
+
"from climateqa.engine.talk_to_data.myVanna import MyVanna\n",
|
37 |
+
"from climateqa.engine.talk_to_data.utils import loc2coords, detect_location_with_openai, detectTable, nearestNeighbourSQL, detect_relevant_tables, replace_coordonates#,nearestNeighbourPostgres\n",
|
38 |
+
"\n",
|
39 |
+
"from climateqa.engine.llm import get_llm"
|
40 |
+
]
|
41 |
+
},
|
42 |
+
{
|
43 |
+
"cell_type": "markdown",
|
44 |
+
"metadata": {},
|
45 |
+
"source": [
|
46 |
+
"# Vanna Ask\n"
|
47 |
+
]
|
48 |
+
},
|
49 |
+
{
|
50 |
+
"cell_type": "code",
|
51 |
+
"execution_count": null,
|
52 |
+
"metadata": {},
|
53 |
+
"outputs": [],
|
54 |
+
"source": [
|
55 |
+
"from dotenv import load_dotenv\n",
|
56 |
+
"\n",
|
57 |
+
"load_dotenv()\n",
|
58 |
+
"\n",
|
59 |
+
"llm = get_llm(provider=\"openai\")\n",
|
60 |
+
"\n",
|
61 |
+
"OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')\n",
|
62 |
+
"PC_API_KEY = os.getenv('VANNA_PINECONE_API_KEY')\n",
|
63 |
+
"INDEX_NAME = os.getenv('VANNA_INDEX_NAME')\n",
|
64 |
+
"VANNA_MODEL = os.getenv('VANNA_MODEL')\n",
|
65 |
+
"\n",
|
66 |
+
"ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))\n",
|
67 |
+
"\n",
|
68 |
+
"#Vanna object\n",
|
69 |
+
"vn = MyVanna(config = {\"temperature\": 0, \"api_key\": OPENAI_API_KEY, 'model': VANNA_MODEL, 'pc_api_key': PC_API_KEY, 'index_name': INDEX_NAME, \"top_k\" : 4})\n",
|
70 |
+
"\n",
|
71 |
+
"db_vanna_path = ROOT_PATH + \"/data/drias/drias.db\"\n",
|
72 |
+
"vn.connect_to_sqlite(db_vanna_path)\n"
|
73 |
+
]
|
74 |
+
},
|
75 |
+
{
|
76 |
+
"cell_type": "markdown",
|
77 |
+
"metadata": {},
|
78 |
+
"source": [
|
79 |
+
"# User query"
|
80 |
+
]
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"cell_type": "code",
|
84 |
+
"execution_count": null,
|
85 |
+
"metadata": {},
|
86 |
+
"outputs": [],
|
87 |
+
"source": [
|
88 |
+
"query = \"Quelle sera la température à Marseille sur les prochaines années ?\""
|
89 |
+
]
|
90 |
+
},
|
91 |
+
{
|
92 |
+
"cell_type": "markdown",
|
93 |
+
"metadata": {},
|
94 |
+
"source": [
|
95 |
+
"## Detect location"
|
96 |
+
]
|
97 |
+
},
|
98 |
+
{
|
99 |
+
"cell_type": "code",
|
100 |
+
"execution_count": null,
|
101 |
+
"metadata": {},
|
102 |
+
"outputs": [],
|
103 |
+
"source": [
|
104 |
+
"location = detect_location_with_openai(OPENAI_API_KEY, query)\n",
|
105 |
+
"print(location)"
|
106 |
+
]
|
107 |
+
},
|
108 |
+
{
|
109 |
+
"cell_type": "markdown",
|
110 |
+
"metadata": {},
|
111 |
+
"source": [
|
112 |
+
"## Convert location to longitude, latitude coordonate"
|
113 |
+
]
|
114 |
+
},
|
115 |
+
{
|
116 |
+
"cell_type": "code",
|
117 |
+
"execution_count": null,
|
118 |
+
"metadata": {},
|
119 |
+
"outputs": [],
|
120 |
+
"source": [
|
121 |
+
"coords = loc2coords(location)\n",
|
122 |
+
"user_input = query.lower().replace(location.lower(), f\"lat, long : {coords}\")\n",
|
123 |
+
"print(user_input)"
|
124 |
+
]
|
125 |
+
},
|
126 |
+
{
|
127 |
+
"cell_type": "markdown",
|
128 |
+
"metadata": {},
|
129 |
+
"source": [
|
130 |
+
"# Find closest coordonates and replace lat,lon\n"
|
131 |
+
]
|
132 |
+
},
|
133 |
+
{
|
134 |
+
"cell_type": "code",
|
135 |
+
"execution_count": null,
|
136 |
+
"metadata": {},
|
137 |
+
"outputs": [],
|
138 |
+
"source": [
|
139 |
+
"relevant_tables = detect_relevant_tables(user_input, llm) \n",
|
140 |
+
"coords_tables = [nearestNeighbourSQL(db_vanna_path, coords, relevant_tables[i]) for i in range(len(relevant_tables))]\n",
|
141 |
+
"user_input_with_coords = replace_coordonates(coords, user_input, coords_tables)\n",
|
142 |
+
"print(user_input_with_coords)"
|
143 |
+
]
|
144 |
+
},
|
145 |
+
{
|
146 |
+
"cell_type": "markdown",
|
147 |
+
"metadata": {},
|
148 |
+
"source": [
|
149 |
+
"# Ask Vanna with correct coordonates"
|
150 |
+
]
|
151 |
+
},
|
152 |
+
{
|
153 |
+
"cell_type": "code",
|
154 |
+
"execution_count": null,
|
155 |
+
"metadata": {},
|
156 |
+
"outputs": [],
|
157 |
+
"source": [
|
158 |
+
"user_input_with_coords"
|
159 |
+
]
|
160 |
+
},
|
161 |
+
{
|
162 |
+
"cell_type": "code",
|
163 |
+
"execution_count": null,
|
164 |
+
"metadata": {},
|
165 |
+
"outputs": [],
|
166 |
+
"source": [
|
167 |
+
"sql_query, result_dataframe, figure = vn.ask(user_input_with_coords, print_results=False, allow_llm_to_see_data=True, auto_train=False)\n",
|
168 |
+
"print(result_dataframe.head())"
|
169 |
+
]
|
170 |
+
},
|
171 |
+
{
|
172 |
+
"cell_type": "code",
|
173 |
+
"execution_count": null,
|
174 |
+
"metadata": {},
|
175 |
+
"outputs": [],
|
176 |
+
"source": [
|
177 |
+
"result_dataframe"
|
178 |
+
]
|
179 |
+
},
|
180 |
+
{
|
181 |
+
"cell_type": "code",
|
182 |
+
"execution_count": null,
|
183 |
+
"metadata": {},
|
184 |
+
"outputs": [],
|
185 |
+
"source": [
|
186 |
+
"figure"
|
187 |
+
]
|
188 |
+
},
|
189 |
+
{
|
190 |
+
"cell_type": "code",
|
191 |
+
"execution_count": null,
|
192 |
+
"metadata": {},
|
193 |
+
"outputs": [],
|
194 |
+
"source": []
|
195 |
+
}
|
196 |
+
],
|
197 |
+
"metadata": {
|
198 |
+
"kernelspec": {
|
199 |
+
"display_name": "climateqa",
|
200 |
+
"language": "python",
|
201 |
+
"name": "python3"
|
202 |
+
},
|
203 |
+
"language_info": {
|
204 |
+
"codemirror_mode": {
|
205 |
+
"name": "ipython",
|
206 |
+
"version": 3
|
207 |
+
},
|
208 |
+
"file_extension": ".py",
|
209 |
+
"mimetype": "text/x-python",
|
210 |
+
"name": "python",
|
211 |
+
"nbconvert_exporter": "python",
|
212 |
+
"pygments_lexer": "ipython3",
|
213 |
+
"version": "3.11.9"
|
214 |
+
}
|
215 |
+
},
|
216 |
+
"nbformat": 4,
|
217 |
+
"nbformat_minor": 2
|
218 |
+
}
|
style.css
CHANGED
@@ -481,14 +481,13 @@ a {
|
|
481 |
max-height: calc(100vh - 190px) !important;
|
482 |
overflow: hidden;
|
483 |
}
|
484 |
-
|
485 |
div#tab-examples,
|
486 |
div#sources-textbox,
|
487 |
div#tab-config {
|
488 |
height: calc(100vh - 190px) !important;
|
489 |
overflow-y: scroll !important;
|
490 |
}
|
491 |
-
|
492 |
div#sources-figures,
|
493 |
div#graphs-container,
|
494 |
div#tab-citations {
|
@@ -606,3 +605,16 @@ a {
|
|
606 |
#checkbox-config:checked {
|
607 |
display: block;
|
608 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
481 |
max-height: calc(100vh - 190px) !important;
|
482 |
overflow: hidden;
|
483 |
}
|
|
|
484 |
div#tab-examples,
|
485 |
div#sources-textbox,
|
486 |
div#tab-config {
|
487 |
height: calc(100vh - 190px) !important;
|
488 |
overflow-y: scroll !important;
|
489 |
}
|
490 |
+
div#tab-vanna,
|
491 |
div#sources-figures,
|
492 |
div#graphs-container,
|
493 |
div#tab-citations {
|
|
|
605 |
#checkbox-config:checked {
|
606 |
display: block;
|
607 |
}
|
608 |
+
|
609 |
+
#vanna-display {
|
610 |
+
max-height: 300px;
|
611 |
+
/* overflow-y: scroll; */
|
612 |
+
}
|
613 |
+
#sql-query{
|
614 |
+
max-height: 100px;
|
615 |
+
overflow-y:scroll;
|
616 |
+
}
|
617 |
+
#vanna-details{
|
618 |
+
max-height: 500px;
|
619 |
+
overflow-y:scroll;
|
620 |
+
}
|