Spaces:
Sleeping
Sleeping
import os | |
import torch | |
import duckdb | |
import spaces | |
import lancedb | |
import gradio as gr | |
import pandas as pd | |
import pyarrow as pa | |
from langchain import hub | |
from langsmith import traceable | |
from sentence_transformers import SentenceTransformer | |
from langchain_huggingface.llms import HuggingFacePipeline | |
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline | |
# Height (in text lines) of the read-only textboxes shown in the output tabs.
TAB_LINES = 8
#----------CONNECT TO DATABASE----------
# Read-only MotherDuck connection to the "my_db" database. The access token is
# read from the MD_TOKEN environment variable.
md_token = os.getenv('MD_TOKEN')
if not md_token:
    # Without this check a missing token is formatted into the connection
    # string as the literal text "None", and the failure only shows up later
    # as a (presumably) confusing authentication error.
    raise RuntimeError("MD_TOKEN environment variable is not set; it is required to connect to MotherDuck.")
conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
#---------------------------------------
# Torch device for the sentence-embedding model: CUDA when a GPU is visible,
# otherwise the CPU. The choice is echoed to stdout for the Space logs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
else:
    print("Using CPU")
#---------------------------------------
#--------------LanceDB-------------
# LanceDB connection; its "SQL-Queries" table stores every generated SQL query
# together with its embedding. Connection details come from the environment.
lance_db = lancedb.connect(
    uri=os.getenv('lancedb_uri'),
    api_key=os.getenv('lancedb_api_key'),
    region=os.getenv('lancedb_region')
)
# One row per logged query: its embedding vector and the SQL text.
lance_schema = pa.schema([
    pa.field("vector", pa.list_(pa.float32())),
    pa.field("sql-query", pa.utf8())
])
# Create the table on first run; if creation fails (presumably because the
# table already exists) open the existing one instead. `except Exception`
# rather than a bare `except:` so KeyboardInterrupt / SystemExit still propagate.
try:
    table = lance_db.create_table(name="SQL-Queries", schema=lance_schema)
except Exception:
    table = lance_db.open_table(name="SQL-Queries")
#---------------------------------------
#-------LOAD HUGGINGFACE PIPELINE-------
# SQL-generation LLM: defog/llama-3-sqlcoder-8b loaded with 4-bit NF4
# quantization (double quantization, bfloat16 compute), wrapped in a
# text-generation pipeline and exposed to LangChain as `hf`.
_SQLCODER_CHECKPOINT = "defog/llama-3-sqlcoder-8b"

tokenizer = AutoTokenizer.from_pretrained(_SQLCODER_CHECKPOINT)
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    _SQLCODER_CHECKPOINT,
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=1024,
    # Return only the newly generated text, not the prompt echoed back.
    return_full_text=False,
)
hf = HuggingFacePipeline(pipeline=pipe)
#---------------------------------------
#-----LOAD PROMPT FROM LANGCHAIN HUB-----
# Prompt template pulled from the LangChain Hub; get_prompt() formats it with
# the `schema` and `query_input` variables.
prompt = hub.pull("sql-agent-prompt")
#---------------------------------------
#-----LOAD EMBEDDING MODEL-----
# all-MiniLM-L6-v2 sentence embedder, used to vectorize each generated SQL
# query before it is logged to LanceDB.
embedding_model = SentenceTransformer(
    model_name_or_path="all-MiniLM-L6-v2",
    device=device,
)
#---------------------------------------
#--------------ALL UTILS----------------
# Get Schemas
def get_schemas():
    """List the schemas of the connected database.

    The system schemas 'information_schema' and 'pg_catalog' are excluded.

    Returns:
        list[str]: Distinct schema names, in the order the query yields them.
    """
    query = """
SELECT DISTINCT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('information_schema', 'pg_catalog')
"""
    rows = conn.execute(query).fetchall()
    # Each row is a 1-tuple holding the schema name.
    return [schema_name for (schema_name,) in rows]
# Get Tables
def get_tables(schema_name):
    """Return the names of the tables that live in `schema_name`.

    Args:
        schema_name: Schema selected in the UI. None (nothing selected) binds
            as SQL NULL, which matches no schema, so an empty list results.

    Returns:
        list[str]: Table names from information_schema.tables.
    """
    # Bind the schema name as a query parameter instead of splicing it into
    # the SQL text: the value comes from the UI / Gradio API, so an f-string
    # allowed SQL injection and broke on names containing a quote.
    rows = conn.execute(
        "SELECT table_name FROM information_schema.tables WHERE table_schema = ?",
        [schema_name],
    ).fetchall()
    return [row[0] for row in rows]
# Update Tables
def update_tables(schema_name):
    """Build the update for the table dropdown after a schema is chosen.

    Returns:
        A gr.update() payload whose `choices` are the tables of `schema_name`.
    """
    return gr.update(choices=get_tables(schema_name))
# Get Schema
def get_table_schema(table):
    """Return the CREATE TABLE statement of `table` with its name fully qualified.

    The DDL is read from DuckDB's duckdb_tables() function. The table name in
    the statement's header is rewritten to `<database>.<schema>.<table>` so the
    prompt shows the fully qualified name.

    Args:
        table: Table name selected in the UI.

    Returns:
        str: The (possibly rewritten) DDL.

    Raises:
        ValueError: if no table with that name exists.
    """
    # Bind the name as a parameter; it comes from the UI / Gradio API and must
    # not be able to inject SQL.
    row = conn.execute(
        "SELECT sql, database_name, schema_name FROM duckdb_tables() where table_name = ?;",
        [table],
    ).fetchone()
    if row is None:
        raise ValueError(f"Table {table!r} not found.")
    # NOTE(review): matching is by table name only; if the same name exists in
    # several schemas/databases, the first row returned is used.
    ddl_create, parent_database, schema_name = row
    full_path = f"{parent_database}.{schema_name}.{table}"
    # Tables in the default "main" schema are expected to appear unqualified in
    # the DDL; tables in other schemas as "<schema>.<table>".
    old_path = table if schema_name == "main" else f"{schema_name}.{table}"
    # Rewrite only the first occurrence after the TABLE keyword, i.e. the name
    # in the CREATE TABLE header. A global str.replace would also rewrite
    # column names or other identifiers that merely contain the table name.
    keyword_at = ddl_create.find("TABLE ")
    search_from = keyword_at + len("TABLE ") if keyword_at != -1 else 0
    name_at = ddl_create.find(old_path, search_from)
    if name_at == -1:
        return ddl_create
    return ddl_create[:name_at] + full_path + ddl_create[name_at + len(old_path):]
# Get Prompt
def get_prompt(schema, query_input):
    """Render the hub prompt template with the table DDL and the user's question."""
    template_vars = {"schema": schema, "query_input": query_input}
    return prompt.format(**template_vars)
def generate_sql(prompt):
    """Run the LLM on a rendered prompt and return its completion with surrounding whitespace trimmed."""
    return hf.invoke(prompt).strip()
def embed_query(sql_query):
    """Embed a SQL query string with the sentence-embedding model.

    Args:
        sql_query: The SQL text to embed, or None.

    Returns:
        list[float] | None: The normalized embedding as a plain list, or None
        when `sql_query` is None. (Previously the None case raised
        UnboundLocalError because `embeddings` was never assigned.)
    """
    print(f'Creating Emebeddings {sql_query}')
    if sql_query is None:
        return None
    return embedding_model.encode(sql_query, normalize_embeddings=True).tolist()
def log2lancedb(embeddings, sql_query):
    """Append one row (query text and its embedding) to the LanceDB "SQL-Queries" table."""
    row = {"sql-query": sql_query, "vector": embeddings}
    table.add([row])
    print('Added to Lance DB.')
#---------------------------------------
def _text2sql_outputs(schema, prompt_text, sql_query, result):
    """Map the four values produced by text2sql onto their Gradio components."""
    return {
        table_schema: schema,
        input_prompt: prompt_text,
        generated_query: sql_query,
        result_output: result,
    }


def _error_frame(message):
    """Wrap an error message in a one-row DataFrame for the Result tab."""
    return pd.DataFrame([{"error": message}])


# Generate SQL
def text2sql(table, query_input):
    """Turn a plain-text question about `table` into SQL and run it.

    Steps: fetch the table's DDL, render the prompt, generate SQL with the
    LLM, log the query and its embedding to LanceDB (best effort), then
    execute the query on the DuckDB connection.

    Args:
        table: Table name selected in the UI, or None if none is selected.
        query_input: The user's question in plain text.

    Returns:
        dict mapping the table_schema, input_prompt, generated_query and
        result_output components to their new values. Failures are reported
        as a one-row DataFrame with an "error" column instead of being raised.
    """
    if table is None:
        return _text2sql_outputs("", "", "", _error_frame("❌ Please Select Table, Schema."))
    if not query_input or not query_input.strip():
        return _text2sql_outputs("", "", "", _error_frame("❌ Please enter a text query."))
    try:
        schema = get_table_schema(table)
    except Exception as e:
        return _text2sql_outputs("", "", "", _error_frame(f"❌ Unable to get the schema of table {table}. {e}"))
    print(f'Schema Generated...')
    # Named prompt_text so it does not shadow the module-level prompt template.
    prompt_text = get_prompt(schema, query_input)
    print(f'Prompt Generated...')
    try:
        print(f'Generating SQL... {model.device}')
        result = generate_sql(prompt_text)
        print('SQL Generated...')
    except Exception as e:
        return _text2sql_outputs(schema, prompt_text, "", _error_frame(f"❌ Unable to get the SQL query based on the text. {e}"))
    # Logging is best effort: an embedding or LanceDB failure must not stop
    # the generated query from running.
    try:
        embeddings = embed_query(result)
        log2lancedb(embeddings, result)
    except Exception as e:
        print("Error Generating and Logging Embeddings...")
        print(e)
    # The SQL is model-generated and executed as-is; this relies on `conn`
    # having been opened with read_only=True.
    try:
        query_result = conn.sql(result).df()
    except Exception as e:
        return _text2sql_outputs(schema, prompt_text, result, _error_frame(f"❌ Unable to run the generated SQL query. {e}"))
    return _text2sql_outputs(schema, prompt_text, result, query_result)
# Custom CSS styling
custom_css = """
.gradio-container {
background-color: #f0f4f8;
}
.logo {
max-width: 200px;
margin: 20px auto;
display: block;
}
.gr-button {
background-color: #4a90e2 !important;
}
.gr-button:hover {
background-color: #3a7bc8 !important;
}
"""

# HTML shown under the logo: app title and a one-line description.
_HEADER_HTML = """
<div style='text-align: center;'>
<strong style='font-size: 36px;'>Datajoi SQL Agent</strong>
<br>
<span style='font-size: 20px;'>Generate and Run SQL queries based on a given text for the dataset.</span>
</div>
"""

with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"),
    css=custom_css,
) as demo:
    # Header: logo and title.
    gr.Image("logo.png", label=None, show_label=False, container=False, height=100)
    gr.Markdown(_HEADER_HTML)

    # NOTE(review): this nesting was reconstructed from a copy that lost its
    # indentation; confirm it against the deployed layout.
    with gr.Row():
        # Left panel: choose a schema, then one of its tables.
        with gr.Column(scale=1, variant='panel'):
            schema_dropdown = gr.Dropdown(
                label="Select Schema",
                choices=get_schemas(),
                interactive=True,
            )
            tables_dropdown = gr.Dropdown(label="Available Tables", choices=[], value=None)
        # Right side: the question, the run button and the output tabs.
        with gr.Column(scale=2):
            query_input = gr.Textbox(
                label="Text Query",
                lines=5,
                placeholder="Enter your text query here...",
            )
            with gr.Row():
                # Empty wide column pushes the button to the right.
                with gr.Column(scale=7):
                    pass
                with gr.Column(scale=1):
                    generate_query_button = gr.Button("Run Query", variant="primary")
            with gr.Tabs():
                with gr.Tab("Result"):
                    result_output = gr.DataFrame(label="Query Results", value=[], interactive=False)
                with gr.Tab("SQL Query"):
                    generated_query = gr.Textbox(label="Generated SQL Query", lines=TAB_LINES, value="", interactive=False)
                with gr.Tab("Prompt"):
                    input_prompt = gr.Textbox(label="Input Prompt", lines=TAB_LINES, value="", interactive=False)
                with gr.Tab("Schema"):
                    table_schema = gr.Textbox(label="Table Schema", lines=TAB_LINES, value="", interactive=False)

    # Picking a schema repopulates the table dropdown.
    schema_dropdown.change(update_tables, inputs=schema_dropdown, outputs=tables_dropdown)
    # text2sql returns a dict keyed by these four output components.
    generate_query_button.click(
        text2sql,
        inputs=[tables_dropdown, query_input],
        outputs=[table_schema, input_prompt, generated_query, result_output],
    )

if __name__ == "__main__":
    demo.launch()