update OpenAI usage from Vanna
Browse files- app.py +10 -1
- climateqa/engine/talk_to_data/main.py +5 -18
- climateqa/engine/talk_to_data/utils.py +10 -16
app.py
CHANGED
@@ -13,6 +13,7 @@ from climateqa.engine.graph import make_graph_agent,make_graph_agent_poc
|
|
13 |
from climateqa.engine.chains.retrieve_papers import find_papers
|
14 |
from climateqa.chat import start_chat, chat_stream, finish_chat
|
15 |
from climateqa.engine.talk_to_data.main import ask_vanna
|
|
|
16 |
|
17 |
from front.tabs import (create_config_modal, create_examples_tab, create_papers_tab, create_figures_tab, create_chat_interface, create_about_tab)
|
18 |
from front.utils import process_figures
|
@@ -77,6 +78,14 @@ else :
|
|
77 |
agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0.2)
|
78 |
agent_poc = make_graph_agent_poc(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0, version="v4")#TODO put back default 0.2
|
79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
81 |
async def chat(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
|
82 |
print("chat cqa - message received")
|
@@ -126,7 +135,7 @@ def create_drias_tab():
|
|
126 |
show_vanna_table.click(lambda: Modal(visible=True),None ,[vanna_table_modal])
|
127 |
|
128 |
vanna_display = gr.Plot()
|
129 |
-
vanna_direct_question.submit(
|
130 |
|
131 |
# # UI Layout Components
|
132 |
def cqa_tab(tab_name):
|
|
|
13 |
from climateqa.engine.chains.retrieve_papers import find_papers
|
14 |
from climateqa.chat import start_chat, chat_stream, finish_chat
|
15 |
from climateqa.engine.talk_to_data.main import ask_vanna
|
16 |
+
from climateqa.engine.talk_to_data.myVanna import MyVanna
|
17 |
|
18 |
from front.tabs import (create_config_modal, create_examples_tab, create_papers_tab, create_figures_tab, create_chat_interface, create_about_tab)
|
19 |
from front.utils import process_figures
|
|
|
78 |
agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0.2)
|
79 |
agent_poc = make_graph_agent_poc(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0, version="v4")#TODO put back default 0.2
|
80 |
|
81 |
+
#Vanna object
|
82 |
+
|
83 |
+
vn = MyVanna(config = {"temperature": 0, "api_key": os.getenv('THEO_API_KEY'), 'model': os.getenv('VANNA_MODEL'), 'pc_api_key': os.getenv('VANNA_PINECONE_API_KEY'), 'index_name': os.getenv('VANNA_INDEX_NAME'), "top_k" : 4})
|
84 |
+
db_vanna_path = os.path.join(os.getcwd(), "data/drias/drias.db")
|
85 |
+
vn.connect_to_sqlite(db_vanna_path)
|
86 |
+
|
87 |
+
def ask_vanna_query(query):
|
88 |
+
return ask_vanna(vn, db_vanna_path, query)
|
89 |
|
90 |
async def chat(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
|
91 |
print("chat cqa - message received")
|
|
|
135 |
show_vanna_table.click(lambda: Modal(visible=True),None ,[vanna_table_modal])
|
136 |
|
137 |
vanna_display = gr.Plot()
|
138 |
+
vanna_direct_question.submit(ask_vanna_query, [vanna_direct_question], [vanna_sql_query ,vanna_table, vanna_display])
|
139 |
|
140 |
# # UI Layout Components
|
141 |
def cqa_tab(tab_name):
|
climateqa/engine/talk_to_data/main.py
CHANGED
@@ -4,24 +4,10 @@ import sqlite3
|
|
4 |
import os
|
5 |
import pandas as pd
|
6 |
from climateqa.engine.llm import get_llm
|
7 |
-
|
8 |
-
from dotenv import load_dotenv
|
9 |
import ast
|
10 |
|
11 |
-
load_dotenv()
|
12 |
-
|
13 |
-
|
14 |
-
OPENAI_API_KEY = os.getenv('THEO_API_KEY')
|
15 |
-
PC_API_KEY = os.getenv('VANNA_PINECONE_API_KEY')
|
16 |
-
INDEX_NAME = os.getenv('VANNA_INDEX_NAME')
|
17 |
-
VANNA_MODEL = os.getenv('VANNA_MODEL')
|
18 |
|
19 |
|
20 |
-
#Vanna object
|
21 |
-
vn = MyVanna(config = {"temperature": 0, "api_key": OPENAI_API_KEY, 'model': VANNA_MODEL, 'pc_api_key': PC_API_KEY, 'index_name': INDEX_NAME, "top_k" : 4})
|
22 |
-
db_vanna_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), "data/drias/drias.db")
|
23 |
-
vn.connect_to_sqlite(db_vanna_path)
|
24 |
-
|
25 |
llm = get_llm(provider="openai")
|
26 |
|
27 |
def ask_llm_to_add_table_names(sql_query, llm):
|
@@ -33,9 +19,10 @@ def ask_llm_column_names(sql_query, llm):
|
|
33 |
columns_list = ast.literal_eval(columns.strip("```python\n").strip())
|
34 |
return columns_list
|
35 |
|
36 |
-
def ask_vanna(query):
|
|
|
37 |
try :
|
38 |
-
location = detect_location_with_openai(
|
39 |
if location:
|
40 |
|
41 |
coords = loc2coords(location)
|
@@ -51,10 +38,10 @@ def ask_vanna(query):
|
|
51 |
|
52 |
else :
|
53 |
empty_df = pd.DataFrame()
|
54 |
-
empty_fig =
|
55 |
return "", empty_df, empty_fig
|
56 |
except Exception as e:
|
57 |
print(f"Error: {e}")
|
58 |
empty_df = pd.DataFrame()
|
59 |
-
empty_fig =
|
60 |
return "", empty_df, empty_fig
|
|
|
4 |
import os
|
5 |
import pandas as pd
|
6 |
from climateqa.engine.llm import get_llm
|
|
|
|
|
7 |
import ast
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
|
|
|
|
|
|
|
|
|
|
|
11 |
llm = get_llm(provider="openai")
|
12 |
|
13 |
def ask_llm_to_add_table_names(sql_query, llm):
|
|
|
19 |
columns_list = ast.literal_eval(columns.strip("```python\n").strip())
|
20 |
return columns_list
|
21 |
|
22 |
+
def ask_vanna(vn,db_vanna_path, query):
|
23 |
+
|
24 |
try :
|
25 |
+
location = detect_location_with_openai(query)
|
26 |
if location:
|
27 |
|
28 |
coords = loc2coords(location)
|
|
|
38 |
|
39 |
else :
|
40 |
empty_df = pd.DataFrame()
|
41 |
+
empty_fig = None
|
42 |
return "", empty_df, empty_fig
|
43 |
except Exception as e:
|
44 |
print(f"Error: {e}")
|
45 |
empty_df = pd.DataFrame()
|
46 |
+
empty_fig = None
|
47 |
return "", empty_df, empty_fig
|
climateqa/engine/talk_to_data/utils.py
CHANGED
@@ -4,13 +4,13 @@ import pandas as pd
|
|
4 |
from geopy.geocoders import Nominatim
|
5 |
import sqlite3
|
6 |
import ast
|
|
|
7 |
|
8 |
-
|
9 |
-
def detect_location_with_openai(api_key, sentence):
|
10 |
"""
|
11 |
-
Detects locations in a sentence using OpenAI's API.
|
12 |
"""
|
13 |
-
|
14 |
|
15 |
prompt = f"""
|
16 |
Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
|
@@ -19,18 +19,12 @@ def detect_location_with_openai(api_key, sentence):
|
|
19 |
Sentence: "{sentence}"
|
20 |
"""
|
21 |
|
22 |
-
response =
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
max_tokens=100,
|
29 |
-
temperature=0
|
30 |
-
)
|
31 |
-
|
32 |
-
return response.choices[0].message.content.split("\n")[1][2:-2]
|
33 |
-
|
34 |
|
35 |
def detectTable(sql_query):
|
36 |
pattern = r'(?i)\bFROM\s+((?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)(?:\.(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+))*)'
|
|
|
4 |
from geopy.geocoders import Nominatim
|
5 |
import sqlite3
|
6 |
import ast
|
7 |
+
from climateqa.engine.llm import get_llm
|
8 |
|
9 |
+
def detect_location_with_openai(sentence):
|
|
|
10 |
"""
|
11 |
+
Detects locations in a sentence using OpenAI's API via LangChain.
|
12 |
"""
|
13 |
+
llm = get_llm()
|
14 |
|
15 |
prompt = f"""
|
16 |
Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
|
|
|
19 |
Sentence: "{sentence}"
|
20 |
"""
|
21 |
|
22 |
+
response = llm.invoke(prompt)
|
23 |
+
location_list = ast.literal_eval(response.content.strip("```python\n").strip())
|
24 |
+
if location_list:
|
25 |
+
return location_list[0]
|
26 |
+
else:
|
27 |
+
return ""
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
def detectTable(sql_query):
|
30 |
pattern = r'(?i)\bFROM\s+((?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)(?:\.(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+))*)'
|