# Author: Mustehson
# Last commit: 902da82 "Change Model" (verified)
import os
import torch
import duckdb
import spaces
import lancedb
import gradio as gr
import pandas as pd
import pyarrow as pa
from langchain import hub
from langsmith import traceable
from sentence_transformers import SentenceTransformer
from langchain_huggingface.llms import HuggingFacePipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
# Height (number of lines) of the text areas shown in the output tabs.
TAB_LINES = 8
#----------CONNECT TO DATABASE----------
# Read-only connection to the MotherDuck database `my_db`.
# The token query parameter is only appended when MD_TOKEN is actually set.
# Interpolating a missing value would send the literal string "None" as the token.
# Leaving it off lets MotherDuck use its own token discovery instead
# (presumably the motherduck_token environment variable — see MotherDuck auth docs).
md_token = os.getenv('MD_TOKEN')
_md_dsn = "md:my_db"
if md_token:
    _md_dsn += f"?motherduck_token={md_token}"
conn = duckdb.connect(_md_dsn, read_only=True)
#---------------------------------------
# Select the torch device (used below for the embedding model): CUDA when a
# GPU is visible, otherwise the CPU. The choice is printed for the logs.
if not torch.cuda.is_available():
    device = torch.device("cpu")
    print("Using CPU")
else:
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
#---------------------------------------
#--------------LanceDB-------------
# LanceDB connection used to log every generated SQL query together with its
# embedding. Connection parameters come from environment variables.
lance_db = lancedb.connect(
    uri=os.getenv('lancedb_uri'),
    api_key=os.getenv('lancedb_api_key'),
    region=os.getenv('lancedb_region'),
)
# One row per generated query: its embedding vector and the SQL text.
# NOTE(review): `pa.list_(pa.float32())` is variable-length; LanceDB vector
# search/indexing generally expects a fixed-size list (e.g. 384 floats for
# all-MiniLM-L6-v2). Confirm before relying on similarity search.
lance_schema = pa.schema([
    pa.field("vector", pa.list_(pa.float32())),
    pa.field("sql-query", pa.utf8()),
])
# Create the table on first run. If creation fails (presumably because it already
# exists), open the existing one instead. Catch `Exception` rather than using a
# bare `except:` so that KeyboardInterrupt / SystemExit still propagate.
try:
    table = lance_db.create_table(name="SQL-Queries", schema=lance_schema)
except Exception:
    table = lance_db.open_table(name="SQL-Queries")
#---------------------------------------
#-------LOAD HUGGINGFACE PIPELINE-------
# Text-to-SQL model: loaded 4-bit NF4 quantized with double quantization and
# bfloat16 compute, placed on devices automatically.
_SQLCODER_MODEL_ID = "defog/llama-3-sqlcoder-8b"
tokenizer = AutoTokenizer.from_pretrained(_SQLCODER_MODEL_ID)
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
model = AutoModelForCausalLM.from_pretrained(
    _SQLCODER_MODEL_ID,
    quantization_config=quantization_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
# return_full_text=False: only the generated continuation is returned, not
# the prompt, so the output is just the SQL answer.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=1024,
    return_full_text=False,
)
# LangChain wrapper exposing `.invoke(prompt)` (used by generate_sql).
hf = HuggingFacePipeline(pipeline=pipe)
#---------------------------------------
#-----LOAD PROMPT FROM LANGCHAIN HUB-----
# Prompt template pulled from the LangChain Hub; get_prompt formats it with
# `schema` and `query_input` keyword arguments.
prompt = hub.pull("sql-agent-prompt")
#---------------------------------------
#-----LOAD EMBEDDING MODEL-----
# Sentence-transformer used by embed_query to embed generated SQL before it is
# logged to LanceDB. Runs on the `device` selected above.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2", device=device)
#---------------------------------------
#--------------ALL UTILS----------------
# Get Databases
def get_schemas():
    """Return the names of the schemas visible on ``conn``.

    The ``information_schema`` and ``pg_catalog`` system schemas are left out.
    Used to populate the schema dropdown.
    """
    schemas_sql = (
        "\n"
        "SELECT DISTINCT schema_name\n"
        "FROM information_schema.schemata\n"
        "WHERE schema_name NOT IN ('information_schema', 'pg_catalog')\n"
    )
    rows = conn.execute(schemas_sql).fetchall()
    # Each row is a 1-tuple holding the schema name.
    return [name for (name,) in rows]
# Get Tables
def get_tables(schema_name):
    """Return the names of the tables that belong to ``schema_name``.

    The schema name is bound as a query parameter instead of being formatted
    into the SQL text. A name containing a quote therefore can neither break nor
    inject into the query.

    Args:
        schema_name: Schema selected in the schema dropdown.

    Returns:
        A list of table names (empty when the schema has no tables).
    """
    rows = conn.execute(
        "SELECT table_name FROM information_schema.tables WHERE table_schema = ?",
        [schema_name],
    ).fetchall()
    return [row[0] for row in rows]
# Update Tables
def update_tables(schema_name):
    """Refresh the table dropdown after the selected schema changes.

    Args:
        schema_name: Newly selected schema.

    Returns:
        A Gradio update setting the dropdown choices to the tables in
        ``schema_name``. It also clears the current selection (``value=None``):
        the previously chosen table generally belongs to the old schema and
        would otherwise remain selected although it is not among the new
        choices.
    """
    return gr.update(choices=get_tables(schema_name), value=None)
# Get Schema
def get_table_schema(table):
    """Return the CREATE TABLE statement for ``table`` with a qualified name.

    Looks ``table`` up in ``duckdb_tables()``. The table name in the DDL is then
    rewritten to ``database.schema.table`` so that SQL generated from it
    references the table unambiguously.

    Args:
        table: Table name as chosen in the table dropdown.

    Returns:
        The DDL string with the fully qualified table name.

    Raises:
        ValueError: If no table named ``table`` exists.
    """
    # Bind the name as a parameter instead of formatting it into the SQL text.
    # NOTE(review): if several schemas contain a table with this name, the
    # first row returned is used — the selected schema is not passed in here.
    row = conn.execute(
        "SELECT sql, database_name, schema_name FROM duckdb_tables() WHERE table_name = ?;",
        [table],
    ).fetchone()
    if row is None:
        raise ValueError(f"Table '{table}' was not found.")
    ddl_create, parent_database, schema_name = row

    full_path = f"{parent_database}.{schema_name}.{table}"
    # The DDL omits the schema prefix for tables in the default `main` schema.
    old_path = table if schema_name == "main" else f"{schema_name}.{table}"

    # Rewrite only the name in the CREATE TABLE header. A global str.replace
    # would also rewrite column names or literals that contain the table name.
    header = f"CREATE TABLE {old_path}"
    if ddl_create.startswith(header):
        return f"CREATE TABLE {full_path}" + ddl_create[len(header):]
    # Fallback for DDL in an unexpected shape: replace the first occurrence only.
    return ddl_create.replace(old_path, full_path, 1)
# Get Prompt
def get_prompt(schema, query_input):
    """Fill the hub prompt template with the table DDL and the user's question.

    Args:
        schema: CREATE TABLE statement of the selected table.
        query_input: Natural-language question typed by the user.

    Returns:
        The formatted prompt text.
    """
    template_vars = {"schema": schema, "query_input": query_input}
    return prompt.format(**template_vars)
@spaces.GPU(duration=60)
@traceable()
def generate_sql(prompt):
    """Run the SQL-coder pipeline on ``prompt`` and return the SQL text.

    Requests a GPU slot of up to 60 seconds via ``spaces.GPU``. The call is
    traced by LangSmith through ``traceable``. Leading and trailing whitespace
    of the model output is removed.
    """
    completion = hf.invoke(prompt)
    sql_text = completion.strip()
    return sql_text
@spaces.GPU(duration=10)
def embed_query(sql_query):
    """Embed ``sql_query`` with the sentence-transformer model.

    Args:
        sql_query: SQL text to embed, or None.

    Returns:
        The L2-normalized embedding as a list of floats, or None when
        ``sql_query`` is None. Previously the None case fell through to an
        unbound local and raised UnboundLocalError.
    """
    print(f'Creating Embeddings {sql_query}')
    if sql_query is None:
        return None
    return embedding_model.encode(sql_query, normalize_embeddings=True).tolist()
def log2lancedb(embeddings, sql_query):
    """Append one row (query text plus its embedding) to the LanceDB table.

    Args:
        embeddings: Embedding vector of ``sql_query`` (as from embed_query).
        sql_query: The generated SQL text.
    """
    record = {"sql-query": sql_query, "vector": embeddings}
    table.add([record])
    print('Added to Lance DB.')
#---------------------------------------
# Generate SQL
def _text2sql_outputs(schema, prompt_text, sql_query, result):
    """Map text2sql's four Gradio output components to their new values."""
    return {
        table_schema: schema,
        input_prompt: prompt_text,
        generated_query: sql_query,
        result_output: result,
    }


def _error_frame(message):
    """Wrap ``message`` in a one-row DataFrame with an ``error`` column."""
    return pd.DataFrame([{"error": message}])


def text2sql(table, query_input):
    """Translate a natural-language question into SQL and run it.

    Steps:
      1. Look up the DDL of ``table``.
      2. Build the LLM prompt.
      3. Generate SQL.
      4. Log the SQL and its embedding to LanceDB (best effort).
      5. Execute the SQL on the DuckDB connection.

    Args:
        table: Selected table name, or None when nothing is selected.
        query_input: The user's natural-language question.

    Returns:
        A dict mapping the ``table_schema``, ``input_prompt``,
        ``generated_query`` and ``result_output`` components to their new
        values. Failures are reported as a one-row ``error`` DataFrame in
        ``result_output`` rather than raised.
    """
    if table is None:
        return _text2sql_outputs("", "", "", _error_frame("❌ Please Select Table, Schema."))
    # Avoid spending a GPU generation call on an empty question.
    if not query_input or not query_input.strip():
        return _text2sql_outputs("", "", "", _error_frame("❌ Please Enter a Text Query."))

    try:
        schema = get_table_schema(table)
    except Exception as e:
        return _text2sql_outputs(
            "", "", "", _error_frame(f"❌ Unable to load the schema of table '{table}'. {e}")
        )
    print('Schema Generated...')

    prompt_text = get_prompt(schema, query_input)
    print('Prompt Generated...')

    try:
        print(f'Generating SQL... {model.device}')
        result = generate_sql(prompt_text)
        print('SQL Generated...')
    except Exception as e:
        return _text2sql_outputs(
            schema, prompt_text, "",
            _error_frame(f"❌ Unable to get the SQL query based on the text. {e}"),
        )

    # Logging is best-effort: an embedding/LanceDB failure must not prevent
    # the query from being executed and shown.
    try:
        embeddings = embed_query(result)
        log2lancedb(embeddings, result)
    except Exception as e:
        print("Error Generating and Logging Embeddings...")
        print(e)

    # The SQL is model output. `conn` is opened read_only, which limits what it can do.
    try:
        query_result = conn.sql(result).df()
    except Exception as e:
        return _text2sql_outputs(
            schema, prompt_text, result,
            _error_frame(f"❌ Unable to execute the generated SQL query. {e}"),
        )

    return _text2sql_outputs(schema, prompt_text, result, query_result)
# Custom CSS passed to gr.Blocks(css=...): a light page background, a centred
# `.logo` class, and blue buttons that override the theme colour via !important.
# NOTE(review): `.logo` and `.gr-button` only take effect if the rendered
# elements carry those classes (no component here sets elem_classes="logo").
# Verify against the Gradio version in use.
custom_css = """
.gradio-container {
background-color: #f0f4f8;
}
.logo {
max-width: 200px;
margin: 20px auto;
display: block;
}
.gr-button {
background-color: #4a90e2 !important;
}
.gr-button:hover {
background-color: #3a7bc8 !important;
}
"""
# Gradio UI: header, schema/table pickers and a question box on top, and one
# tab per text2sql output below.
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"),
    css=custom_css,
) as demo:
    gr.Image("logo.png", label=None, show_label=False, container=False, height=100)
    gr.Markdown("""
<div style='text-align: center;'>
<strong style='font-size: 36px;'>Datajoi SQL Agent</strong>
<br>
<span style='font-size: 20px;'>Generate and Run SQL queries based on a given text for the dataset.</span>
</div>
""")

    # Inputs: database pickers on the left, free-text question on the right.
    with gr.Row():
        with gr.Column(scale=1, variant='panel'):
            schema_dropdown = gr.Dropdown(
                choices=get_schemas(),
                label="Select Schema",
                interactive=True,
            )
            tables_dropdown = gr.Dropdown(
                choices=[],
                label="Available Tables",
                value=None,
            )
        with gr.Column(scale=2):
            query_input = gr.Textbox(
                lines=5,
                label="Text Query",
                placeholder="Enter your text query here...",
            )
            with gr.Row():
                # Wide empty column acting as a spacer before the button.
                with gr.Column(scale=7):
                    pass
                with gr.Column(scale=1):
                    generate_query_button = gr.Button("Run Query", variant="primary")

    # Outputs: one tab per value returned by text2sql.
    with gr.Tabs():
        with gr.Tab("Result"):
            result_output = gr.DataFrame(
                label="Query Results", value=[], interactive=False
            )
        with gr.Tab("SQL Query"):
            generated_query = gr.Textbox(
                lines=TAB_LINES, label="Generated SQL Query", value="", interactive=False
            )
        with gr.Tab("Prompt"):
            input_prompt = gr.Textbox(
                lines=TAB_LINES, label="Input Prompt", value="", interactive=False
            )
        with gr.Tab("Schema"):
            table_schema = gr.Textbox(
                lines=TAB_LINES, label="Table Schema", value="", interactive=False
            )

    # Event wiring: schema change refreshes tables; the button runs the agent.
    schema_dropdown.change(
        update_tables,
        inputs=schema_dropdown,
        outputs=tables_dropdown,
    )
    generate_query_button.click(
        text2sql,
        inputs=[tables_dropdown, query_input],
        outputs=[table_schema, input_prompt, generated_query, result_output],
    )

if __name__ == "__main__":
    demo.launch()