timeki committed on
Commit
8bd064f
·
1 Parent(s): e5c9448

update OpenAI usage from Vanna

Browse files
app.py CHANGED
@@ -13,6 +13,7 @@ from climateqa.engine.graph import make_graph_agent,make_graph_agent_poc
13
  from climateqa.engine.chains.retrieve_papers import find_papers
14
  from climateqa.chat import start_chat, chat_stream, finish_chat
15
  from climateqa.engine.talk_to_data.main import ask_vanna
 
16
 
17
  from front.tabs import (create_config_modal, create_examples_tab, create_papers_tab, create_figures_tab, create_chat_interface, create_about_tab)
18
  from front.utils import process_figures
@@ -77,6 +78,14 @@ else :
77
  agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0.2)
78
  agent_poc = make_graph_agent_poc(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0, version="v4")#TODO put back default 0.2
79
 
 
 
 
 
 
 
 
 
80
 
81
  async def chat(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
82
  print("chat cqa - message received")
@@ -126,7 +135,7 @@ def create_drias_tab():
126
  show_vanna_table.click(lambda: Modal(visible=True),None ,[vanna_table_modal])
127
 
128
  vanna_display = gr.Plot()
129
- vanna_direct_question.submit(ask_vanna, [vanna_direct_question], [vanna_sql_query ,vanna_table, vanna_display])
130
 
131
  # # UI Layout Components
132
  def cqa_tab(tab_name):
 
13
  from climateqa.engine.chains.retrieve_papers import find_papers
14
  from climateqa.chat import start_chat, chat_stream, finish_chat
15
  from climateqa.engine.talk_to_data.main import ask_vanna
16
+ from climateqa.engine.talk_to_data.myVanna import MyVanna
17
 
18
  from front.tabs import (create_config_modal, create_examples_tab, create_papers_tab, create_figures_tab, create_chat_interface, create_about_tab)
19
  from front.utils import process_figures
 
78
  agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0.2)
79
  agent_poc = make_graph_agent_poc(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0, version="v4")#TODO put back default 0.2
80
 
81
+ #Vanna object
82
+
83
+ vn = MyVanna(config = {"temperature": 0, "api_key": os.getenv('THEO_API_KEY'), 'model': os.getenv('VANNA_MODEL'), 'pc_api_key': os.getenv('VANNA_PINECONE_API_KEY'), 'index_name': os.getenv('VANNA_INDEX_NAME'), "top_k" : 4})
84
+ db_vanna_path = os.path.join(os.getcwd(), "data/drias/drias.db")
85
+ vn.connect_to_sqlite(db_vanna_path)
86
+
87
+ def ask_vanna_query(query):
88
+ return ask_vanna(vn, db_vanna_path, query)
89
 
90
  async def chat(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
91
  print("chat cqa - message received")
 
135
  show_vanna_table.click(lambda: Modal(visible=True),None ,[vanna_table_modal])
136
 
137
  vanna_display = gr.Plot()
138
+ vanna_direct_question.submit(ask_vanna_query, [vanna_direct_question], [vanna_sql_query ,vanna_table, vanna_display])
139
 
140
  # # UI Layout Components
141
  def cqa_tab(tab_name):
climateqa/engine/talk_to_data/main.py CHANGED
@@ -4,24 +4,10 @@ import sqlite3
4
  import os
5
  import pandas as pd
6
  from climateqa.engine.llm import get_llm
7
-
8
- from dotenv import load_dotenv
9
  import ast
10
 
11
- load_dotenv()
12
-
13
-
14
- OPENAI_API_KEY = os.getenv('THEO_API_KEY')
15
- PC_API_KEY = os.getenv('VANNA_PINECONE_API_KEY')
16
- INDEX_NAME = os.getenv('VANNA_INDEX_NAME')
17
- VANNA_MODEL = os.getenv('VANNA_MODEL')
18
 
19
 
20
- #Vanna object
21
- vn = MyVanna(config = {"temperature": 0, "api_key": OPENAI_API_KEY, 'model': VANNA_MODEL, 'pc_api_key': PC_API_KEY, 'index_name': INDEX_NAME, "top_k" : 4})
22
- db_vanna_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), "data/drias/drias.db")
23
- vn.connect_to_sqlite(db_vanna_path)
24
-
25
  llm = get_llm(provider="openai")
26
 
27
  def ask_llm_to_add_table_names(sql_query, llm):
@@ -33,9 +19,10 @@ def ask_llm_column_names(sql_query, llm):
33
  columns_list = ast.literal_eval(columns.strip("```python\n").strip())
34
  return columns_list
35
 
36
- def ask_vanna(query):
 
37
  try :
38
- location = detect_location_with_openai(OPENAI_API_KEY, query)
39
  if location:
40
 
41
  coords = loc2coords(location)
@@ -51,10 +38,10 @@ def ask_vanna(query):
51
 
52
  else :
53
  empty_df = pd.DataFrame()
54
- empty_fig = {}
55
  return "", empty_df, empty_fig
56
  except Exception as e:
57
  print(f"Error: {e}")
58
  empty_df = pd.DataFrame()
59
- empty_fig = {}
60
  return "", empty_df, empty_fig
 
4
  import os
5
  import pandas as pd
6
  from climateqa.engine.llm import get_llm
 
 
7
  import ast
8
 
 
 
 
 
 
 
 
9
 
10
 
 
 
 
 
 
11
  llm = get_llm(provider="openai")
12
 
13
  def ask_llm_to_add_table_names(sql_query, llm):
 
19
  columns_list = ast.literal_eval(columns.strip("```python\n").strip())
20
  return columns_list
21
 
22
+ def ask_vanna(vn,db_vanna_path, query):
23
+
24
  try :
25
+ location = detect_location_with_openai(query)
26
  if location:
27
 
28
  coords = loc2coords(location)
 
38
 
39
  else :
40
  empty_df = pd.DataFrame()
41
+ empty_fig = None
42
  return "", empty_df, empty_fig
43
  except Exception as e:
44
  print(f"Error: {e}")
45
  empty_df = pd.DataFrame()
46
+ empty_fig = None
47
  return "", empty_df, empty_fig
climateqa/engine/talk_to_data/utils.py CHANGED
@@ -4,13 +4,13 @@ import pandas as pd
4
  from geopy.geocoders import Nominatim
5
  import sqlite3
6
  import ast
 
7
 
8
-
9
- def detect_location_with_openai(api_key, sentence):
10
  """
11
- Detects locations in a sentence using OpenAI's API.
12
  """
13
- openai.api_key = api_key
14
 
15
  prompt = f"""
16
  Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
@@ -19,18 +19,12 @@ def detect_location_with_openai(api_key, sentence):
19
  Sentence: "{sentence}"
20
  """
21
 
22
- response = openai.chat.completions.create(
23
- model="gpt-4o-mini",
24
- messages=[
25
- {"role": "system", "content": "You are a helpful assistant skilled in identifying locations in text."},
26
- {"role": "user", "content": prompt}
27
- ],
28
- max_tokens=100,
29
- temperature=0
30
- )
31
-
32
- return response.choices[0].message.content.split("\n")[1][2:-2]
33
-
34
 
35
  def detectTable(sql_query):
36
  pattern = r'(?i)\bFROM\s+((?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)(?:\.(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+))*)'
 
4
  from geopy.geocoders import Nominatim
5
  import sqlite3
6
  import ast
7
+ from climateqa.engine.llm import get_llm
8
 
9
+ def detect_location_with_openai(sentence):
 
10
  """
11
+ Detects locations in a sentence using OpenAI's API via LangChain.
12
  """
13
+ llm = get_llm()
14
 
15
  prompt = f"""
16
  Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
 
19
  Sentence: "{sentence}"
20
  """
21
 
22
+ response = llm.invoke(prompt)
23
+ location_list = ast.literal_eval(response.content.strip("```python\n").strip())
24
+ if location_list:
25
+ return location_list[0]
26
+ else:
27
+ return ""
 
 
 
 
 
 
28
 
29
  def detectTable(sql_query):
30
  pattern = r'(?i)\bFROM\s+((?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)(?:\.(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+))*)'