|
from climateqa.engine.talk_to_data.myVanna import MyVanna |
|
from climateqa.engine.talk_to_data.utils import loc2coords, detect_location_with_openai, detectTable, nearestNeighbourSQL, detect_relevant_tables, replace_coordonates |
|
import sqlite3 |
|
import os |
|
import pandas as pd |
|
from climateqa.engine.llm import get_llm |
|
|
|
from dotenv import load_dotenv |
|
import ast |
|
|
|
# Populate os.environ from a .env file (python-dotenv) so that the os.getenv
# calls below can see values defined there.
load_dotenv()

# Credentials / settings read from the environment. os.getenv returns None for
# any variable that is not set; nothing here validates that they are present.
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')

PC_API_KEY = os.getenv('VANNA_PINECONE_API_KEY')

INDEX_NAME = os.getenv('VANNA_INDEX_NAME')

VANNA_MODEL = os.getenv('VANNA_MODEL')

# Module-level Vanna client shared by ask_vanna(); it is constructed at import
# time, so importing this module builds it. How each config key is interpreted
# is up to MyVanna (defined elsewhere) — presumably an OpenAI model for SQL
# generation plus a Pinecone index for retrieval with top_k=4; TODO confirm.
vn = MyVanna(config = {"temperature": 0, "api_key": OPENAI_API_KEY, 'model': VANNA_MODEL, 'pc_api_key': PC_API_KEY, 'index_name': INDEX_NAME, "top_k" : 4})

# SQLite database queried by Vanna: "data/drias/drias.db" under the directory
# three levels above this file's directory (presumably the repository root).
db_vanna_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), "data/drias/drias.db")

# Attach the client to that database; also done at import time.
vn.connect_to_sqlite(db_vanna_path)

# Shared chat model; ask_vanna() passes it to detect_relevant_tables().
llm = get_llm(provider="openai")
|
|
|
def ask_llm_to_add_table_names(sql_query, llm):
    """Ask the LLM to rewrite *sql_query* so each result row shows its source table.

    Args:
        sql_query: The SQL query to rewrite.
        llm: A chat model exposing ``invoke(prompt)`` whose result has a
            ``.content`` string.

    Returns:
        The rewritten SQL text, with surrounding whitespace removed and any
        enclosing Markdown code fence (e.g. ```sql ... ```) stripped.
    """
    import re

    answer = llm.invoke(f"Make the following sql query display the source table in the rows {sql_query}. Just answer the query. The answer should not include ```sql\n").content

    # The prompt asks for no code fence, but a model may still add one; remove it
    # here so the result is plain SQL. An optional language tag on the opening
    # fence is only consumed when it is followed by a newline.
    fenced = re.fullmatch(r"\s*```(?:[\w+-]*[ \t]*\n)?(.*?)```\s*", answer, flags=re.DOTALL)
    return (fenced.group(1) if fenced else answer).strip()
|
|
|
def ask_llm_column_names(sql_query, llm):
    """Ask the LLM which columns *sql_query* selects and return them as a list.

    Args:
        sql_query: The SQL query to inspect.
        llm: A chat model exposing ``invoke(prompt)`` whose result has a
            ``.content`` string.

    Returns:
        The selected column names, parsed from the model's answer.

    Raises:
        ValueError, SyntaxError: If the answer (after removing an optional
            Markdown code fence) is not a valid Python literal, as raised by
            ``ast.literal_eval``.
        ValueError: If the literal parses to something other than a list or
            tuple.
    """
    import re

    answer = llm.invoke(f"From the given sql query, list the columns that are being selected. The answer should only be a python list. Just answer the list. The SQL query : {sql_query}").content

    # Remove a surrounding Markdown fence such as ```python / ```json / ```.
    # str.strip("```python\n") would strip a *set of characters*, not a prefix,
    # so fences with other language tags would leave e.g. "json" in the text.
    fenced = re.fullmatch(r"\s*```(?:[\w+-]*[ \t]*\n)?(.*?)```\s*", answer, flags=re.DOTALL)
    literal = (fenced.group(1) if fenced else answer).strip()

    # literal_eval only evaluates Python literals, so it is safe on model output.
    columns = ast.literal_eval(literal)
    if not isinstance(columns, (list, tuple)):
        raise ValueError(f"Expected a list of column names from the LLM, got: {answer!r}")
    return list(columns)
|
|
|
def _empty_vanna_result():
    """Return the ``("", empty DataFrame, {})`` triple used when no answer is produced."""
    return "", pd.DataFrame(), {}


def ask_vanna(query):
    """Answer a natural-language question by generating and running SQL with Vanna.

    The location named in *query* is detected with OpenAI and turned into
    coordinates via ``loc2coords``. The query is lower-cased and the location
    text replaced with ``"lat, long : <coords>"``. For each table returned by
    ``detect_relevant_tables``, ``nearestNeighbourSQL`` is queried against the
    module's SQLite database (presumably to find the nearest stored
    coordinates — TODO confirm). ``replace_coordonates`` folds those results
    back into the question, which is then passed to ``vn.ask``.

    Args:
        query: The user's question in natural language.

    Returns:
        A ``(sql_query, result_dataframe, figure)`` tuple from ``vn.ask``.
        If no location is detected, or if any step raises, returns
        ``("", pd.DataFrame(), {})`` instead; errors are logged with their
        traceback rather than propagated.
    """
    import logging

    try:
        location = detect_location_with_openai(OPENAI_API_KEY, query)
        if not location:
            return _empty_vanna_result()

        coords = loc2coords(location)
        # Note: the whole query is lower-cased before the location is substituted.
        user_input = query.lower().replace(location.lower(), f"lat, long : {coords}")

        relevant_tables = detect_relevant_tables(user_input, llm)
        coords_tables = [
            nearestNeighbourSQL(db_vanna_path, coords, table) for table in relevant_tables
        ]
        user_input_with_coords = replace_coordonates(coords, user_input, coords_tables)

        sql_query, result_dataframe, figure = vn.ask(
            user_input_with_coords,
            print_results=False,
            allow_llm_to_see_data=True,
            auto_train=False,
        )
        return sql_query, result_dataframe, figure
    except Exception as e:
        # Boundary handler: degrade to an empty answer rather than crash the
        # caller, but keep the traceback so failures remain debuggable.
        logging.getLogger(__name__).exception("Error: %s", e)
        return _empty_vanna_result()