|
from vanna.base import VannaBase |
|
from pinecone import Pinecone |
|
from climateqa.engine.embeddings import get_embeddings_function |
|
import pandas as pd |
|
import hashlib |
|
|
|
class MyCustomVectorDB(VannaBase):
    """
    VectorDB class for storing and retrieving vectors from Pinecone.

    Training data lives in three Pinecone namespaces -- 'ddl',
    'documentation' and 'question_sql' -- one vector per training item,
    keyed by a SHA-256 hash of the content plus a type suffix
    ('_ddl', '_doc', '_sql').

    args :
        config (dict) : Configuration dictionary containing the Pinecone API key and the index name :
            - pc_api_key (str) : Pinecone API key (required)
            - index_name (str) : Pinecone index name (required)
            - top_k (int) : Number of top results to return (default = 2)

    Raises:
        ValueError: if 'pc_api_key' or 'index_name' is missing from config.
    """

    def __init__(self, config):
        super().__init__(config=config)

        # dict.get() never raises for a missing key, so the original
        # try/except around these lines could not trigger; validate
        # explicitly so a missing key fails fast with a clear message.
        self.api_key = config.get('pc_api_key')
        self.index_name = config.get('index_name')
        if not self.api_key or not self.index_name:
            raise ValueError("Please provide the Pinecone API key and the index name")

        self.pc = Pinecone(api_key=self.api_key)
        self.index = self.pc.Index(self.index_name)
        self.top_k = config.get('top_k', 2)
        self.embeddings = get_embeddings_function()

    def check_embedding(self, id, namespace):
        """Return True if a vector with the given id already exists in the namespace."""
        fetched = self.index.fetch(ids=[id], namespace=namespace)
        # fetch() returns an empty 'vectors' mapping when the id is absent.
        return fetched['vectors'] != {}

    def generate_hash_id(self, data: str) -> str:
        """
        Generate a unique hash ID for the given data.

        Args:
            data (str): The input data to hash (e.g., a concatenated string of user attributes).

        Returns:
            str: A unique hash ID as a hexadecimal string.
        """
        data_bytes = data.encode('utf-8')
        hash_object = hashlib.sha256(data_bytes)
        return hash_object.hexdigest()

    def _upsert_if_new(self, id: str, embedding_text: str, metadata: dict,
                       namespace: str, label: str) -> str:
        """Embed and upsert a training item unless its id already exists.

        Shared implementation behind add_ddl / add_documentation /
        add_question_sql, which only differ in namespace and metadata.

        Args:
            id (str): Deterministic vector id (content hash + type suffix).
            embedding_text (str): Text to embed as the vector value.
            metadata (dict): Metadata stored alongside the vector.
            namespace (str): Target Pinecone namespace.
            label (str): Human-readable item kind, used in the log message.

        Returns:
            str: The id, whether newly inserted or already present.
        """
        if self.check_embedding(id, namespace):
            print(f"{label} having id {id} already exists")
            return id

        self.index.upsert(
            vectors=[(id, self.embeddings.embed_query(embedding_text), metadata)],
            namespace=namespace,
        )
        return id

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """Store a DDL statement in the 'ddl' namespace; return its id."""
        id = self.generate_hash_id(ddl) + '_ddl'
        return self._upsert_if_new(id, ddl, {'ddl': ddl}, 'ddl', 'DDL')

    def add_documentation(self, doc: str, **kwargs) -> str:
        """Store a documentation snippet in the 'documentation' namespace; return its id."""
        id = self.generate_hash_id(doc) + '_doc'
        return self._upsert_if_new(id, doc, {'doc': doc}, 'documentation', 'Documentation')

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Store a question/SQL pair in the 'question_sql' namespace; return its id.

        The embedding is computed over question + sql so retrieval can match
        on either part; the id is derived from the question only.
        """
        id = self.generate_hash_id(question) + '_sql'
        return self._upsert_if_new(
            id,
            question + sql,
            {'question': question, 'sql': sql},
            'question_sql',
            'Question-SQL pair',
        )

    def _query_matches(self, question: str, namespace: str) -> list:
        """Similarity-search a namespace for the question; return the raw match list."""
        res = self.index.query(
            vector=self.embeddings.embed_query(question),
            top_k=self.top_k,
            namespace=namespace,
            include_metadata=True,
        )
        return res['matches']

    def get_related_ddl(self, question: str, **kwargs) -> list:
        """Return up to top_k DDL strings most similar to the question."""
        return [match['metadata']['ddl'] for match in self._query_matches(question, 'ddl')]

    def get_related_documentation(self, question: str, **kwargs) -> list:
        """Return up to top_k documentation snippets most similar to the question."""
        return [match['metadata']['doc'] for match in self._query_matches(question, 'documentation')]

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        """Return up to top_k (question, sql) tuples most similar to the question."""
        return [
            (match['metadata']['question'], match['metadata']['sql'])
            for match in self._query_matches(question, 'question_sql')
        ]

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        """Collect the metadata of all stored training items into a DataFrame.

        Returns:
            pd.DataFrame: One row per stored item, columns taken from the
            per-namespace metadata ('ddl', 'doc', 'question'/'sql').
        """
        list_of_data = []

        # NOTE(review): this query passes neither a vector nor an id, which
        # the Pinecone query API normally requires -- confirm this works
        # against the deployed client version before relying on it.
        for namespace in ('ddl', 'documentation', 'question_sql'):
            data = self.index.query(
                top_k=10000,
                namespace=namespace,
                include_metadata=True,
                include_values=False,
            )
            for match in data['matches']:
                list_of_data.append(match['metadata'])

        return pd.DataFrame(list_of_data)

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Delete the training item with the given id from its namespace.

        The id suffix identifies the namespace the item was inserted into.

        Args:
            id (str): Item id as returned by one of the add_* methods.

        Returns:
            bool: True if a delete was issued, False for an unknown suffix.
        """
        # Fixes vs. the original: `self.Index` (capital I) does not exist,
        # and the deletes targeted namespaces '_ddl'/'_sql'/'_doc' while the
        # add_* methods insert into 'ddl'/'question_sql'/'documentation',
        # so nothing was ever actually removed.
        suffix_to_namespace = {
            '_ddl': 'ddl',
            '_sql': 'question_sql',
            '_doc': 'documentation',
        }
        for suffix, namespace in suffix_to_namespace.items():
            if id.endswith(suffix):
                self.index.delete(ids=[id], namespace=namespace)
                return True

        return False

    def generate_embedding(self, text, **kwargs):
        """Not implemented; embeddings are produced via self.embeddings instead.

        Kept as a no-op to satisfy the VannaBase interface -- presumably an
        abstract method on the base class; verify before removing.
        """
        pass

    def get_sql_prompt(
        self,
        initial_prompt: str,
        question: str,
        question_sql_list: list,
        ddl_list: list,
        doc_list: list,
        **kwargs,
    ):
        """
        Example:
        ```python
        vn.get_sql_prompt(
            question="What are the top 10 customers by sales?",
            question_sql_list=[{"question": "What are the top 10 customers by sales?", "sql": "SELECT * FROM customers ORDER BY sales DESC LIMIT 10"}],
            ddl_list=["CREATE TABLE customers (id INT, name TEXT, sales DECIMAL)"],
            doc_list=["The customers table contains information about customers and their sales."],
        )
        ```

        This method is used to generate a prompt for the LLM to generate SQL.

        Args:
            initial_prompt (str): Prompt preamble; a default is built when None.
            question (str): The question to generate SQL for.
            question_sql_list (list): A list of questions and their corresponding SQL statements.
            ddl_list (list): A list of DDL statements.
            doc_list (list): A list of documentation.

        Returns:
            any: The prompt for the LLM to generate SQL, as a message log.
        """
        if initial_prompt is None:
            initial_prompt = f"You are a {self.dialect} expert. " + \
                "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "

        initial_prompt = self.add_ddl_to_prompt(
            initial_prompt, ddl_list, max_tokens=self.max_tokens
        )

        if self.static_documentation != "":
            doc_list.append(self.static_documentation)

        initial_prompt = self.add_documentation_to_prompt(
            initial_prompt, doc_list, max_tokens=self.max_tokens
        )

        initial_prompt += (
            "===Response Guidelines \n"
            "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
            "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
            "3. If the provided context is insufficient, please give a sql query based on your knowledge and the context provided. \n"
            "4. Please use the most relevant table(s). \n"
            "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
            f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
            f"7. Add a description of the table in the result of the sql query, if relevant. \n"
            "8. Make sure to include the relevant KPI in the SQL query. The query should return impactful data \n"
        )

        message_log = [self.system_message(initial_prompt)]

        # Few-shot examples: each well-formed pair becomes a user/assistant turn.
        for example in question_sql_list:
            if example is None:
                print("example is None")
            elif "question" in example and "sql" in example:
                message_log.append(self.user_message(example["question"]))
                message_log.append(self.assistant_message(example["sql"]))

        message_log.append(self.user_message(question))

        return message_log
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|