timeki commited on
Commit
668db15
·
1 Parent(s): f0e965f

working version of vanna

Browse files
.gitignore CHANGED
@@ -15,4 +15,5 @@ sandbox/
15
  climateqa/talk_to_data/database/
16
  *.db
17
 
18
- data_ingestion/
 
 
15
  climateqa/talk_to_data/database/
16
  *.db
17
 
18
+ data_ingestion/
19
+ .vscode
app.py CHANGED
@@ -141,7 +141,7 @@ def cqa_tab(tab_name):
141
  "<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>",
142
  elem_id="graphs-container"
143
  )
144
- with gr.Tab("Vanna", elem_id="tab-vanna", id=6) as tab_vanna:
145
  vanna_table = gr.DataFrame([], elem_id="vanna-display")
146
  vanna_display = gr.Plot()
147
 
@@ -226,13 +226,13 @@ def event_handling(
226
  # Event for textbox
227
  (textbox
228
  .submit(start_chat, [textbox, chatbot, search_only], [textbox, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{textbox.elem_id}")
229
- .then(chat, [textbox, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs, ttd_data], concurrency_limit=8, api_name=f"chat_{textbox.elem_id}")
230
  .then(finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}")
231
  )
232
  # Event for examples_hidden
233
  (examples_hidden
234
  .change(start_chat, [examples_hidden, chatbot, search_only], [examples_hidden, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{examples_hidden.elem_id}")
235
- .then(chat, [examples_hidden, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs, ttd_data], concurrency_limit=8, api_name=f"chat_{examples_hidden.elem_id}")
236
  .then(finish_chat, None, [examples_hidden], api_name=f"finish_chat_{examples_hidden.elem_id}")
237
  )
238
 
@@ -249,7 +249,7 @@ def event_handling(
249
  for component in [textbox, examples_hidden]:
250
  component.submit(find_papers, [component, after, dropdown_external_sources], [papers_html, citations_network, papers_summary])
251
 
252
- # ttd_data.change(lambda x: x["df_output"], inputs=[ttd_data], outputs=[vanna_display])
253
  textbox.submit(ask_vanna, [textbox], [vanna_table, vanna_display])
254
 
255
  def main_ui():
 
141
  "<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>",
142
  elem_id="graphs-container"
143
  )
144
+ with gr.Tab("DRIAS", elem_id="tab-vanna", id=6) as tab_vanna:
145
  vanna_table = gr.DataFrame([], elem_id="vanna-display")
146
  vanna_display = gr.Plot()
147
 
 
226
  # Event for textbox
227
  (textbox
228
  .submit(start_chat, [textbox, chatbot, search_only], [textbox, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{textbox.elem_id}")
229
+ .then(chat, [textbox, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name=f"chat_{textbox.elem_id}")
230
  .then(finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}")
231
  )
232
  # Event for examples_hidden
233
  (examples_hidden
234
  .change(start_chat, [examples_hidden, chatbot, search_only], [examples_hidden, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{examples_hidden.elem_id}")
235
+ .then(chat, [examples_hidden, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name=f"chat_{examples_hidden.elem_id}")
236
  .then(finish_chat, None, [examples_hidden], api_name=f"finish_chat_{examples_hidden.elem_id}")
237
  )
238
 
 
249
  for component in [textbox, examples_hidden]:
250
  component.submit(find_papers, [component, after, dropdown_external_sources], [papers_html, citations_network, papers_summary])
251
 
252
+ # Drias search
253
  textbox.submit(ask_vanna, [textbox], [vanna_table, vanna_display])
254
 
255
  def main_ui():
climateqa/chat.py CHANGED
@@ -58,7 +58,8 @@ def handle_numerical_data(event):
58
  if event["name"] == "retrieve_drias_data" and event["event"] == "on_chain_end":
59
  numerical_data = event["data"]["output"]["drias_data"]
60
  sql_query = event["data"]["output"]["drias_sql_query"]
61
- return numerical_data, sql_query
 
62
 
63
  # Main chat function
64
  async def chat_stream(
@@ -148,12 +149,12 @@ async def chat_stream(
148
  history, used_documents = handle_retrieved_documents(
149
  event, history, used_documents
150
  )
151
- # Handle document retrieval
152
- if event["event"] == "on_chain_end" and event["name"] in ["retrieve_documents","retrieve_local_data"] and event["data"]["output"] != None:
153
- df_output_vanna, sql_query = handle_numerical_data(
154
- event
155
- )
156
- vanna_data = {"df_output": df_output_vanna, "sql_query": sql_query}
157
 
158
 
159
  if event["event"] == "on_chain_end" and event["name"] == "answer_search" :
@@ -198,7 +199,7 @@ async def chat_stream(
198
  sub_questions = [q["question"] for q in event["data"]["output"]["questions_list"]]
199
  history[-1].content += "Decompose question into sub-questions:\n\n - " + "\n - ".join(sub_questions)
200
 
201
- yield history, docs_html, output_query, output_language, related_contents, graphs_html, vanna_data
202
 
203
  except Exception as e:
204
  print(f"Event {event} has failed")
@@ -209,4 +210,4 @@ async def chat_stream(
209
  # Call the function to log interaction
210
  log_interaction_to_azure(history, output_query, sources, docs, share_client, user_id)
211
 
212
- yield history, docs_html, output_query, output_language, related_contents, graphs_html, vanna_data
 
58
  if event["name"] == "retrieve_drias_data" and event["event"] == "on_chain_end":
59
  numerical_data = event["data"]["output"]["drias_data"]
60
  sql_query = event["data"]["output"]["drias_sql_query"]
61
+ return numerical_data, sql_query
62
+ return None, None
63
 
64
  # Main chat function
65
  async def chat_stream(
 
149
  history, used_documents = handle_retrieved_documents(
150
  event, history, used_documents
151
  )
152
+ # Handle Vanna retrieval
153
+ # if event["event"] == "on_chain_end" and event["name"] in ["retrieve_documents","retrieve_local_data"] and event["data"]["output"] != None:
154
+ # df_output_vanna, sql_query = handle_numerical_data(
155
+ # event
156
+ # )
157
+ # vanna_data = {"df_output": df_output_vanna, "sql_query": sql_query}
158
 
159
 
160
  if event["event"] == "on_chain_end" and event["name"] == "answer_search" :
 
199
  sub_questions = [q["question"] for q in event["data"]["output"]["questions_list"]]
200
  history[-1].content += "Decompose question into sub-questions:\n\n - " + "\n - ".join(sub_questions)
201
 
202
+ yield history, docs_html, output_query, output_language, related_contents, graphs_html#, vanna_data
203
 
204
  except Exception as e:
205
  print(f"Event {event} has failed")
 
210
  # Call the function to log interaction
211
  log_interaction_to_azure(history, output_query, sources, docs, share_client, user_id)
212
 
213
+ yield history, docs_html, output_query, output_language, related_contents, graphs_html#, vanna_data
climateqa/engine/chains/intent_categorization.py CHANGED
@@ -57,6 +57,7 @@ def make_intent_categorization_node(llm):
57
  categorization_chain = make_intent_categorization_chain(llm)
58
 
59
  def categorize_message(state):
 
60
  print("---- Categorize_message ----")
61
 
62
  output = categorization_chain.invoke({"input": state["user_input"]})
 
57
  categorization_chain = make_intent_categorization_chain(llm)
58
 
59
  def categorize_message(state):
60
+ print("Input Message : ", state["user_input"])
61
  print("---- Categorize_message ----")
62
 
63
  output = categorization_chain.invoke({"input": state["user_input"]})
climateqa/engine/chains/query_transformation.py CHANGED
@@ -293,6 +293,8 @@ def make_query_transform_node(llm,k_final=15):
293
  "n_questions":n_questions,
294
  "handled_questions_index":[],
295
  }
 
 
296
  return new_state
297
 
298
  return transform_query
 
293
  "n_questions":n_questions,
294
  "handled_questions_index":[],
295
  }
296
+ print("New questions")
297
+ print(new_questions)
298
  return new_state
299
 
300
  return transform_query
climateqa/engine/graph.py CHANGED
@@ -75,7 +75,7 @@ def route_intent(state):
75
  def chitchat_route_intent(state):
76
  intent = state["search_graphs_chitchat"]
77
  if intent is True:
78
- return "retrieve_graphs_chitchat"
79
  elif intent is False:
80
  return END
81
 
 
75
  def chitchat_route_intent(state):
76
  intent = state["search_graphs_chitchat"]
77
  if intent is True:
78
+ return END #TODO
79
  elif intent is False:
80
  return END
81
 
climateqa/engine/talk_to_data/deprecated_vanna_remote.py DELETED
@@ -1,167 +0,0 @@
1
- # from vanna.remote import VannaDefault
2
- # from pinecone import Pinecone
3
- # from climateqa.engine.embeddings import get_embeddings_function
4
- # import pandas as pd
5
- # import hashlib
6
-
7
- # class MyCustomVectorDB(VannaDefault):
8
-
9
- # """
10
- # VectorDB class for storing and retrieving vectors from Pinecone.
11
-
12
- # args :
13
- # config (dict) : Configuration dictionary containing the Pinecone API key and the index name :
14
- # - pc_api_key (str) : Pinecone API key
15
- # - index_name (str) : Pinecone index name
16
- # - top_k (int) : Number of top results to return (default = 2)
17
-
18
- # """
19
-
20
- # def __init__(self,config, **kwargs):
21
- # super().__init__(**kwargs)
22
- # try :
23
- # self.api_key = config.get('pc_api_key')
24
- # self.index_name = config.get('index_name')
25
- # except :
26
- # raise Exception("Please provide the Pinecone API key and the index name")
27
-
28
- # self.pc = Pinecone(api_key = self.api_key)
29
- # self.index = self.pc.Index(self.index_name)
30
- # self.top_k = config.get('top_k', 2)
31
- # self.embeddings = get_embeddings_function()
32
-
33
-
34
- # def check_embedding(self, id, namespace):
35
- # fetched = self.index.fetch(ids = [id], namespace = namespace)
36
- # if fetched['vectors'] == {}:
37
- # return False
38
- # return True
39
-
40
- # def generate_hash_id(self, data: str) -> str:
41
- # """
42
- # Generate a unique hash ID for the given data.
43
-
44
- # Args:
45
- # data (str): The input data to hash (e.g., a concatenated string of user attributes).
46
-
47
- # Returns:
48
- # str: A unique hash ID as a hexadecimal string.
49
- # """
50
-
51
- # data_bytes = data.encode('utf-8')
52
- # hash_object = hashlib.sha256(data_bytes)
53
- # hash_id = hash_object.hexdigest()
54
-
55
- # return hash_id
56
-
57
- # def add_ddl(self, ddl: str, **kwargs) -> str:
58
- # id = self.generate_hash_id(ddl) + '_ddl'
59
-
60
- # if self.check_embedding(id, 'ddl'):
61
- # print(f"DDL having id {id} already exists")
62
- # return id
63
-
64
- # self.index.upsert(
65
- # vectors = [(id, self.embeddings.embed_query(ddl), {'ddl': ddl})],
66
- # namespace = 'ddl'
67
- # )
68
-
69
- # return id
70
-
71
- # def add_documentation(self, doc: str, **kwargs) -> str:
72
- # id = self.generate_hash_id(doc) + '_doc'
73
-
74
- # if self.check_embedding(id, 'documentation'):
75
- # print(f"Documentation having id {id} already exists")
76
- # return id
77
-
78
- # self.index.upsert(
79
- # vectors = [(id, self.embeddings.embed_query(doc), {'doc': doc})],
80
- # namespace = 'documentation'
81
- # )
82
-
83
- # return id
84
-
85
- # def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
86
- # id = self.generate_hash_id(question) + '_sql'
87
-
88
- # if self.check_embedding(id, 'question_sql'):
89
- # print(f"Question-SQL pair having id {id} already exists")
90
- # return id
91
-
92
- # self.index.upsert(
93
- # vectors = [(id, self.embeddings.embed_query(question + sql), {'question': question, 'sql': sql})],
94
- # namespace = 'question_sql'
95
- # )
96
-
97
- # return id
98
-
99
- # def get_related_ddl(self, question: str, **kwargs) -> list:
100
- # res = self.index.query(
101
- # vector=self.embeddings.embed_query(question),
102
- # top_k=self.top_k,
103
- # namespace='ddl',
104
- # include_metadata=True
105
- # )
106
-
107
- # print([match['metadata']['ddl'] for match in res['matches']])
108
-
109
- # return [match['metadata']['ddl'] for match in res['matches']]
110
-
111
- # def get_related_documentation(self, question: str, **kwargs) -> list:
112
- # res = self.index.query(
113
- # vector=self.embeddings.embed_query(question),
114
- # top_k=self.top_k,
115
- # namespace='documentation',
116
- # include_metadata=True
117
- # )
118
-
119
- # return [match['metadata']['doc'] for match in res['matches']]
120
-
121
- # def get_similar_quetion_sql(self, question: str, **kwargs) -> list:
122
- # res = self.index.query(
123
- # vector=self.embeddings.embed_query(question),
124
- # top_k=self.top_k,
125
- # namespace='question_sql',
126
- # include_metadata=True
127
- # )
128
-
129
- # return [(match['metadata']['question'], match['metadata']['sql']) for match in res['matches']]
130
-
131
- # def get_training_data(self, **kwargs) -> pd.DataFrame:
132
-
133
- # list_of_data = []
134
-
135
- # namespaces = ['ddl', 'documentation', 'question_sql']
136
-
137
- # for namespace in namespaces:
138
-
139
- # data = self.index.query(
140
- # top_k=10000,
141
- # namespace=namespace,
142
- # include_metadata=True,
143
- # include_values=False
144
- # )
145
-
146
- # for match in data['matches']:
147
- # list_of_data.append(match['metadata'])
148
-
149
- # return pd.DataFrame(list_of_data)
150
-
151
-
152
-
153
- # def remove_training_data(self, id: str, **kwargs) -> bool:
154
- # if id.endswith("_ddl"):
155
- # self.Index.delete(ids=[id], namespace="_ddl")
156
- # return True
157
- # if id.endswith("_sql"):
158
- # self.index.delete(ids=[id], namespace="_sql")
159
- # return True
160
-
161
- # if id.endswith("_doc"):
162
- # self.Index.delete(ids=[id], namespace="_doc")
163
- # return True
164
-
165
- # return False
166
-
167
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
climateqa/engine/talk_to_data/main.py CHANGED
@@ -4,6 +4,7 @@ import sqlite3
4
  import os
5
  import pandas as pd
6
  from climateqa.engine.llm import get_llm
 
7
 
8
  from dotenv import load_dotenv
9
 
@@ -17,7 +18,7 @@ VANNA_MODEL = os.getenv('VANNA_MODEL')
17
 
18
 
19
  #Vanna object
20
- vn = MyVanna(config = {"temperature": 0, "api_key": OPENAI_API_KEY, 'model': VANNA_MODEL, 'pc_api_key': PC_API_KEY, 'index_name': INDEX_NAME})
21
  db_vanna_path = os.path.join(os.path.dirname(__file__), "database/drias.db")
22
  vn.connect_to_sqlite(db_vanna_path)
23
 
@@ -68,13 +69,11 @@ def ask_vanna(query):
68
  sql_with_table_names = llm.invoke(f"Make the following sql query display the source table in the rows {sql_query_new_coords}. Just answer the query. The answer should not include ```sql\n").content
69
  print("execute sql query : ", sql_with_table_names)
70
  db = sqlite3.connect(db_vanna_path)
71
- # if "lat" not in sql_with_table_names:
72
- # sql_with_table_names = sql_with_table_names.replace("SELECT", "SELECT lat, lon,")
73
- # result = db.cursor().execute(sql_with_table_names).fetchall()
74
  result = db.cursor().execute(sql_query_new_coords).fetchall()
75
- # df = pd.DataFrame(result, columns = list(result_dataframe.columns))
76
- # df = pd.DataFrame(result, columns=["data_name"] + list(result_dataframe.columns))
77
- df = pd.DataFrame(result)
 
78
 
79
  plotly_code = vn.generate_plotly_code(
80
  question="query",
 
4
  import os
5
  import pandas as pd
6
  from climateqa.engine.llm import get_llm
7
+ import ast
8
 
9
  from dotenv import load_dotenv
10
 
 
18
 
19
 
20
  #Vanna object
21
+ vn = MyVanna(config = {"temperature": 0, "api_key": OPENAI_API_KEY, 'model': VANNA_MODEL, 'pc_api_key': PC_API_KEY, 'index_name': INDEX_NAME, "top_k" : 4})
22
  db_vanna_path = os.path.join(os.path.dirname(__file__), "database/drias.db")
23
  vn.connect_to_sqlite(db_vanna_path)
24
 
 
69
  sql_with_table_names = llm.invoke(f"Make the following sql query display the source table in the rows {sql_query_new_coords}. Just answer the query. The answer should not include ```sql\n").content
70
  print("execute sql query : ", sql_with_table_names)
71
  db = sqlite3.connect(db_vanna_path)
 
 
 
72
  result = db.cursor().execute(sql_query_new_coords).fetchall()
73
+ columns = llm.invoke(f"From the given sql query, list the columns that are being selected. The answer should only be a python list. Just answer the list. The SQL query : {sql_query_new_coords}").content
74
+ columns_list = ast.literal_eval(columns.strip("```python\n").strip())
75
+ print("column list : ",columns_list)
76
+ df = pd.DataFrame(result, columns=columns_list)
77
 
78
  plotly_code = vn.generate_plotly_code(
79
  question="query",
climateqa/engine/talk_to_data/step_by_step_vanna copy.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
climateqa/engine/talk_to_data/vanna_class.py CHANGED
@@ -228,6 +228,7 @@ class MyCustomVectorDB(VannaBase):
228
  "4. Please use the most relevant table(s). \n"
229
  "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
230
  f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
 
231
  # "7. Add a description of the table in the result of the sql query."
232
  # "7. If the question is about a specific latitude, longitude, query an interval of 0.3 and keep only the first set of coordinate. \n"
233
  # "7. Table names should be included in the result of the sql query. Use for example Mean_winter_temperature AS table_name in the query \n"
 
228
  "4. Please use the most relevant table(s). \n"
229
  "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
230
  f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
231
+ f"7. Add a description of the table in the result of the sql query, and latitude, logitude if relevant. \n"
232
  # "7. Add a description of the table in the result of the sql query."
233
  # "7. If the question is about a specific latitude, longitude, query an interval of 0.3 and keep only the first set of coordinate. \n"
234
  # "7. Table names should be included in the result of the sql query. Use for example Mean_winter_temperature AS table_name in the query \n"
style.css CHANGED
@@ -606,3 +606,8 @@ a {
606
  #checkbox-config:checked {
607
  display: block;
608
  }
 
 
 
 
 
 
606
  #checkbox-config:checked {
607
  display: block;
608
  }
609
+
610
+ #vanna-display {
611
+ height: 400px;
612
+ overflow-y: auto;
613
+ }