File size: 2,571 Bytes
28684d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abafbcc
28684d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from climateqa.engine.talk_to_data.myVanna import MyVanna
from climateqa.engine.talk_to_data.utils import loc2coords, detect_location_with_openai, detectTable, nearestNeighbourSQL, detect_relevant_tables, replace_coordonates
import sqlite3
import os
import pandas as pd
from climateqa.engine.llm import get_llm

from dotenv import load_dotenv
import ast

# Load variables from a local .env file into the process environment; this
# must run before the os.getenv() calls below read them.
load_dotenv()


# Credentials and Vanna configuration, all read from the environment
# (os.getenv returns None for any variable that is unset).
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
PC_API_KEY = os.getenv('VANNA_PINECONE_API_KEY')
INDEX_NAME = os.getenv('VANNA_INDEX_NAME')
VANNA_MODEL = os.getenv('VANNA_MODEL')


# Shared Vanna text-to-SQL object used by ask_vanna(). The config dict is passed
# straight to MyVanna; presumably 'pc_api_key'/'index_name' select a Pinecone
# index and 'top_k' the number of retrieved items — TODO confirm in myVanna.
vn = MyVanna(config = {"temperature": 0, "api_key": OPENAI_API_KEY, 'model': VANNA_MODEL, 'pc_api_key': PC_API_KEY, 'index_name': INDEX_NAME, "top_k" : 4})
# Path of the DRIAS SQLite database: "data/drias/drias.db" joined onto the
# directory obtained by applying os.path.dirname four times to this file's path
# (presumably the repository root — verify against the package layout).
db_vanna_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), "data/drias/drias.db")
# Attach the database so vn.ask() can execute the SQL it generates.
vn.connect_to_sqlite(db_vanna_path)

# OpenAI chat model; ask_vanna() passes it to detect_relevant_tables().
llm = get_llm(provider="openai")

def ask_llm_to_add_table_names(sql_query, llm):
    """Ask the LLM to rewrite ``sql_query`` so each row reports its source table.

    Args:
        sql_query: The SQL query to rewrite; it is embedded verbatim in the prompt.
        llm: A chat model exposing ``invoke(prompt)`` whose result has a
            ``.content`` attribute.

    Returns:
        The raw text of the model's reply, expected to be the rewritten query.
        The reply is returned as-is; no code fences are stripped.
    """
    prompt = (
        "Make the following sql query display the source table in the rows "
        f"{sql_query}. Just answer the query. The answer should not include ```sql\n"
    )
    reply = llm.invoke(prompt)
    return reply.content

def ask_llm_column_names(sql_query, llm):
    """Ask the LLM which columns ``sql_query`` selects and parse its reply.

    The model is asked to answer with a bare Python list literal. Its reply is
    tolerated when wrapped in a Markdown code fence (with any language tag) or
    surrounded by extra text: the span from the first ``[`` to the last ``]`` is
    extracted and parsed.

    Args:
        sql_query: The SQL query whose selected columns should be listed.
        llm: A chat model exposing ``invoke(prompt)`` whose result has a
            ``.content`` attribute.

    Returns:
        The object parsed from the reply by ``ast.literal_eval``, normally a
        list of column-name strings.

    Raises:
        ValueError, SyntaxError: if the reply does not contain a valid Python
            literal (raised by ``ast.literal_eval``).
    """
    columns = llm.invoke(f"From the given sql query, list the columns that are being selected. The answer should only be a python list. Just answer the list. The SQL query : {sql_query}").content
    # str.strip() treats its argument as a *set* of characters rather than a
    # prefix, so the former strip("```python\n") only removed fences by
    # accident and failed on e.g. "```json" fences or leading prose. Slicing
    # out the bracketed list is robust to both.
    text = columns.strip()
    start, end = text.find("["), text.rfind("]")
    if start != -1 and end > start:
        text = text[start:end + 1]
    # literal_eval only accepts Python literals, so untrusted model output
    # cannot execute code here.
    columns_list = ast.literal_eval(text)
    return columns_list

def _empty_vanna_result():
    """Return the ``(sql, dataframe, figure)`` placeholder used when there is no answer."""
    return "", pd.DataFrame(), {}


def ask_vanna(query):
    """Answer a natural-language question against the DRIAS database via Vanna.

    Steps:
    1. A location is detected in the question with ``detect_location_with_openai``
       and geocoded with ``loc2coords``.
    2. The location is replaced in the lower-cased question by its coordinates.
    3. For every table returned by ``detect_relevant_tables``, the nearest point
       is looked up with ``nearestNeighbourSQL``.
    4. ``replace_coordonates`` rewrites the question with those points before it
       is passed to ``vn.ask``.

    Args:
        query: The user's question in natural language.

    Returns:
        A ``(sql_query, result_dataframe, figure)`` tuple. ``("", pd.DataFrame(), {})``
        is returned when no location is detected or when any step raises. Any
        ``None`` component returned by ``vn.ask`` is replaced by the matching
        placeholder from that same tuple, so callers always receive one shape.
    """
    try:
        location = detect_location_with_openai(OPENAI_API_KEY, query)
        if not location:
            return _empty_vanna_result()

        coords = loc2coords(location)
        # The whole query (not only the location) is lower-cased, presumably so
        # the location match is case-insensitive; later steps see lower-case text.
        user_input = query.lower().replace(location.lower(), f"lat, long : {coords}")

        relevant_tables = detect_relevant_tables(user_input, llm)
        coords_tables = [
            nearestNeighbourSQL(db_vanna_path, coords, table)
            for table in relevant_tables
        ]
        user_input_with_coords = replace_coordonates(coords, user_input, coords_tables)

        sql_query, result_dataframe, figure = vn.ask(
            user_input_with_coords,
            print_results=False,
            allow_llm_to_see_data=True,
            auto_train=False,
        )
    except Exception as e:
        # Top-level boundary: report the failure and degrade to an empty answer.
        print(f"Error: {e}")
        return _empty_vanna_result()

    # NOTE(review): Vanna's ask() can return None for parts it could not produce
    # (e.g. when running or plotting the SQL fails); map those to the same
    # placeholders used above so the return shape is consistent.
    empty_sql, empty_df, empty_fig = _empty_vanna_result()
    return (
        empty_sql if sql_query is None else sql_query,
        empty_df if result_dataframe is None else result_dataframe,
        empty_fig if figure is None else figure,
    )