# timeki's picture
# update OpenAI usage from Vanna
# 8bd064f
from climateqa.engine.talk_to_data.myVanna import MyVanna
from climateqa.engine.talk_to_data.utils import loc2coords, detect_location_with_openai, detectTable, nearestNeighbourSQL, detect_relevant_tables, replace_coordonates
import sqlite3
import os
import pandas as pd
from climateqa.engine.llm import get_llm
import ast
# Module-level LLM client created once at import time with the "openai" provider;
# ask_vanna passes it to detect_relevant_tables.
llm = get_llm(provider="openai")
def ask_llm_to_add_table_names(sql_query, llm):
    """Ask the LLM to rewrite a SQL query so its rows show the source table.

    Args:
        sql_query: SQL text inserted verbatim into the prompt.
        llm: Object exposing ``invoke(prompt)``; the returned value must have
            a ``content`` attribute.

    Returns:
        The ``content`` of the LLM response, returned unmodified.
    """
    prompt = f"Make the following sql query display the source table in the rows {sql_query}. Just answer the query. The answer should not include ```sql\n"
    response = llm.invoke(prompt)
    return response.content
def ask_llm_column_names(sql_query, llm):
    """Ask the LLM which columns a SQL query selects and parse its reply.

    Args:
        sql_query: SQL text inserted verbatim into the prompt.
        llm: Object exposing ``invoke(prompt)``; the returned value must have
            a string ``content`` attribute.

    Returns:
        The Python literal parsed from the reply, which is a list when the
        model follows the prompt.

    Raises:
        ValueError, SyntaxError: Propagated from ``ast.literal_eval`` when the
            reply does not contain a parsable Python literal.
    """
    reply = llm.invoke(f"From the given sql query, list the columns that are being selected. The answer should only be a python list. Just answer the list. The SQL query : {sql_query}").content
    text = reply.strip()

    # Models often wrap the list in a Markdown fence (any language label),
    # pad it with whitespace, or add a lead-in sentence. str.strip(chars)
    # removes a *set of characters*, not a prefix, so it cannot reliably peel
    # those off. Parsing only the outermost [...] span handles all of them.
    start = text.find("[")
    end = text.rfind("]")
    if start != -1 and end > start:
        text = text[start:end + 1]

    # literal_eval only evaluates literals, so the LLM reply is never executed.
    return ast.literal_eval(text)
def ask_vanna(vn, db_vanna_path, query):
    """Run a location-based question through Vanna and return its answer.

    The location detected in ``query`` is converted to coordinates, and the
    lowercased query has that location replaced by ``lat, long : <coords>``.
    For each table reported by ``detect_relevant_tables``,
    ``nearestNeighbourSQL`` is called with the database path and coordinates.
    ``replace_coordonates`` then builds the final prompt passed to ``vn.ask``.

    Args:
        vn: Object exposing ``ask(...)`` that returns a 3-item result
            (presumably a Vanna instance; confirm against the caller).
        db_vanna_path: Database path forwarded to ``nearestNeighbourSQL``.
        query: The user's question as a string.

    Returns:
        ``(sql_query, result_dataframe, figure)`` from ``vn.ask``. When no
        location is detected or any exception is raised, returns
        ``("", <empty DataFrame>, None)`` instead.
    """
    try:
        place = detect_location_with_openai(query)
        if not place:
            # Nothing to geolocate: hand back the empty placeholder result.
            return "", pd.DataFrame(), None

        coords = loc2coords(place)
        # Case-insensitive substitution done by lowercasing both sides, so the
        # text sent downstream is entirely lowercase.
        prompt = query.lower().replace(place.lower(), f"lat, long : {coords}")
        tables = detect_relevant_tables(prompt, llm)
        nearest_per_table = [
            nearestNeighbourSQL(db_vanna_path, coords, table) for table in tables
        ]
        prompt_with_coords = replace_coordonates(coords, prompt, nearest_per_table)

        # Unpack explicitly: a result that is not a 3-item sequence must raise
        # here so the fallback below is returned.
        sql_query, result_dataframe, figure = vn.ask(
            prompt_with_coords,
            print_results=False,
            allow_llm_to_see_data=True,
            auto_train=False,
        )
        return sql_query, result_dataframe, figure
    except Exception as e:
        print(f"Error: {e}")
        return "", pd.DataFrame(), None