Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import duckdb | |
import spaces | |
import gradio as gr | |
import pandas as pd | |
from llama_cpp import Llama | |
# from dotenv import load_dotenv | |
from huggingface_hub import hf_hub_download | |
# load_dotenv() | |
# Number of visible lines in the read-only text areas shown in the result tabs.
TAB_LINES = 8

# MotherDuck access token, read from the MD_TOKEN environment variable.
md_token = os.getenv('MD_TOKEN')
if not md_token:
    # Without this check the f-string below would embed the literal string
    # "None" as the token instead of failing with an explicit message.
    raise RuntimeError("MD_TOKEN environment variable is not set; cannot connect to MotherDuck.")

# Single module-level connection to the MotherDuck database `my_db`; every
# helper and request handler in this app reuses it.
conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}")
# Extra CSS injected into the Gradio page: a light page background, a centred
# `.logo` class, and blue `.gr-button`s with a darker shade on hover.
custom_css = (
    "\n"
    ".gradio-container {\n"
    "background-color: #f0f4f8;\n"
    "}\n"
    ".logo {\n"
    "max-width: 200px;\n"
    "margin: 20px auto;\n"
    "display: block;\n"
    "}\n"
    ".gr-button {\n"
    "background-color: #4a90e2 !important;\n"
    "}\n"
    ".gr-button:hover {\n"
    "background-color: #3a7bc8 !important;\n"
    "}\n"
)
print('Loading Model...')

# Fetch the q8_0 GGUF weights of DuckDB-NSQL-7B from the Hugging Face Hub into
# the working directory; hf_hub_download returns the local file path.
_model_path = hf_hub_download(
    repo_id="motherduckdb/DuckDB-NSQL-7B-v0.1-GGUF",
    filename="DuckDB-NSQL-7B-v0.1-q8_0.gguf",
    local_dir='.',
)

# Load the model once at import time; it is shared by every request.
# n_ctx is the context window in tokens; n_gpu_layers=0 offloads no layers
# to a GPU, so inference runs on the CPU.
llama = Llama(model_path=_model_path, n_ctx=2048, n_gpu_layers=0)

print('Model Loaded...')
def get_databases():
    """Return the names of the databases visible on the shared connection.

    Runs ``PRAGMA show_databases`` and keeps the first column of every row.
    """
    return [name for (name, *_) in conn.execute("PRAGMA show_databases").fetchall()]
def get_tables(database):
    """Make ``database`` the current database and list its tables.

    Side effect: issues ``USE`` on the shared module-level connection, so the
    selected database stays current for later queries on that connection.

    Args:
        database: database name as returned by get_databases(); None or ""
            means nothing is selected.

    Returns:
        List of table names from ``SHOW TABLES`` ([] when nothing is selected).
    """
    if not database:
        # The dropdown can be cleared; `USE None` would be a SQL error.
        return []
    # Quote as an identifier (doubling embedded quotes) instead of pasting the
    # raw name into the statement, so unusual names neither break nor inject SQL.
    quoted_name = '"' + str(database).replace('"', '""') + '"'
    conn.execute(f"USE {quoted_name}")
    return [row[0] for row in conn.execute("SHOW TABLES").fetchall()]
def update_tables(selected_db):
    """Refresh the table dropdown after a database is selected.

    Returns a ``gr.update`` that sets the dropdown choices to the tables of
    ``selected_db`` and clears the current value. Without the reset, a table
    picked for the previously selected database would stay selected.
    """
    return gr.update(choices=get_tables(selected_db), value=None)
def get_schema(table):
    """Return the ``CREATE TABLE`` statement DuckDB stores for ``table``.

    The lookup is limited to the connection's current database (the one made
    current by ``USE`` in get_tables), so a same-named table in another
    attached database is not picked by mistake.

    Args:
        table: table name, e.g. as listed by ``SHOW TABLES``.

    Returns:
        The ``sql`` column of ``duckdb_tables()`` for that table.

    Raises:
        ValueError: if no such table exists in the current database.
    """
    # The table name is bound as a parameter, not formatted into the SQL text.
    row = conn.execute(
        "SELECT sql FROM duckdb_tables() "
        "WHERE table_name = ? AND database_name = current_database()",
        [table],
    ).fetchone()
    if row is None:
        raise ValueError(f"Table {table!r} was not found in the current database.")
    return row[0]
def get_prompt(schema, query_input):
    """Build the DuckDB-NSQL instruction prompt.

    The prompt has an Instruction, an Input (the table schema), a Question
    (the user's text) and an open Response header for the model to complete.
    It starts and ends with a newline.
    """
    prompt_lines = [
        "",
        "### Instruction:",
        "Your task is to generate valid duckdb SQL to answer the following question.",
        "### Input:",
        "Here is the database schema that the SQL query will run on:",
        f"{schema}",
        "### Question:",
        f"{query_input}",
        "### Response (use duckdb shorthand if possible):",
        "",
    ]
    return "\n".join(prompt_lines)
def generate_sql(prompt, *, temperature=0.1, max_tokens=1000):
    """Complete ``prompt`` with the loaded NSQL model and return the SQL text.

    Args:
        prompt: full prompt, typically built by get_prompt().
        temperature: sampling temperature passed to the model (default 0.1).
        max_tokens: maximum number of tokens to generate (default 1000).

    Returns:
        The ``text`` of the first completion choice, unmodified.
    """
    completion = llama(prompt, temperature=temperature, max_tokens=max_tokens)
    return completion["choices"][0]["text"]
def text2sql(table, query_input):
    """Turn a natural-language question into DuckDB SQL and run it.

    Reads the DDL of ``table``, builds the NSQL prompt from it, asks the model
    for a query and executes that query on the shared module-level connection.

    Args:
        table: name of the selected table, or None when nothing is selected.
        query_input: the user's question as plain text.

    Returns:
        A dict keyed by the Gradio output components ``table_schema``,
        ``input_prompt``, ``generated_query`` and ``result_output``. On any
        failure, ``result_output`` is a one-row DataFrame with an ``error``
        column. The other entries hold whatever was produced before the failure.
    """

    def _outputs(schema="", prompt="", query="", result=None):
        # Map every output component to its value in a single place.
        return {
            table_schema: schema,
            input_prompt: prompt,
            generated_query: query,
            result_output: result,
        }

    def _error_frame(message):
        return pd.DataFrame([{"error": message}])

    if table is None:
        return _outputs(result=_error_frame("β Please select a table before running a query."))
    if not query_input or not query_input.strip():
        # Avoid spending a full model inference on an empty question.
        return _outputs(result=_error_frame("β Please enter a text query."))

    try:
        schema = get_schema(table)
    except Exception as e:
        return _outputs(result=_error_frame(f"β Unable to read the schema of table {table}. {e}"))

    prompt = get_prompt(schema, query_input)
    try:
        sql = generate_sql(prompt)
    except Exception as e:
        return _outputs(
            schema,
            prompt,
            result=_error_frame(f"β Unable to get the SQL query based on the text. {e}"),
        )

    # NOTE(review): the model's output is executed verbatim on the shared
    # MotherDuck connection, so a crafted question could yield data-modifying
    # statements (DROP, DELETE, ...). Consider allowing read-only statements only.
    try:
        query_result = conn.sql(sql).df()
    except Exception as e:
        return _outputs(
            schema,
            prompt,
            sql,
            _error_frame(f"β Unable to run the generated SQL query. {e}"),
        )

    # The connection is module-level and shared by every request, so it must
    # stay open. Closing it here would make all later calls fail, including
    # get_tables, get_schema and this function.
    return _outputs(schema, prompt, sql, query_result)
# Database names are fetched once at startup to populate the first dropdown.
databases = get_databases()

# Page layout: logo and title, then a row with the database/table pickers on
# the left and the question box, run button and result tabs on the right.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css) as demo:
    # Relative path: logo.png is resolved from the working directory.
    gr.Image("logo.png", label=None, show_label=False, container=False, height=100)
    gr.Markdown("""
<div style='text-align: center;'>
<strong style='font-size: 36px;'>Datajoi SQL Agent</strong>
<br>
<span style='font-size: 20px;'>Generate and Run SQL queries based on a given text for the dataset.</span>
</div>
""")
    with gr.Row():
        with gr.Column(scale=1, variant='panel'):
            database_dropdown = gr.Dropdown(choices=databases, label="Select Database", interactive=True)
            # Starts empty; filled by update_tables when a database is chosen.
            tables_dropdown = gr.Dropdown(choices=[], label="Available Tables", value=None)
        with gr.Column(scale=2):
            query_input = gr.Textbox(lines=5, label="Text Query", placeholder="Enter your text query here...")
            generate_query_button = gr.Button("Run Query", variant="primary")
            # One tab per value returned by text2sql.
            with gr.Tabs():
                with gr.Tab("Result"):
                    result_output = gr.DataFrame(label="Query Results", value=[], interactive=False)
                with gr.Tab("SQL Query"):
                    generated_query = gr.Textbox(lines=TAB_LINES, label="Generated SQL Query", value="", interactive=False)
                with gr.Tab("Prompt"):
                    input_prompt = gr.Textbox(lines=TAB_LINES, label="Input Prompt", value="", interactive=False)
                with gr.Tab("Schema"):
                    table_schema = gr.Textbox(lines=TAB_LINES, label="Schema", value="", interactive=False)
    # Selecting a database refreshes the table list.
    database_dropdown.change(update_tables, inputs=database_dropdown, outputs=tables_dropdown)
    # text2sql returns a dict keyed by these output components.
    generate_query_button.click(text2sql, inputs=[tables_dropdown, query_input], outputs=[table_schema, input_prompt, generated_query, result_output])

# Start the Gradio server when the file is run as a script.
if __name__ == "__main__":
    demo.launch()