File size: 1,895 Bytes
f78a3ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import streamlit as st
import time  # For measuring time taken
from sql_query_generator.config import Instructions
from sql_query_generator.utils import load_json, format_prompt
from sql_query_generator.generator import load_model, generate_sql

# Load Schema and Metadata
# Both JSON inputs are required to build any prompt. The paths are relative,
# so presumably they are resolved against the working directory the app is
# launched from. If either file is missing, report it and halt this script run.
try:
    schema, metadata = (
        load_json("data/table_create.json"),
        load_json("data/tables_metadata.json"),
    )
except FileNotFoundError:
    st.error(
        "Schema or metadata files not found. Please ensure they are in the `data/` directory."
    )
    st.stop()

# Load model and tokenizer
# Streamlit re-executes this entire script on every widget interaction, so a
# bare load_model() call would reload the model each time the user submits a
# query. st.cache_resource keeps one instance and reuses it across reruns.
# NOTE(review): the cached objects are shared by all sessions of this server,
# so generate_sql must not mutate them per request — confirm.
@st.cache_resource
def _load_cached_model():
    """Return the ``(tokenizer, model)`` pair from ``load_model``, cached across reruns."""
    return load_model()


tokenizer, model = _load_cached_model()

# Streamlit Interface
st.title("Dynamic SQL Query Generator")
st.markdown(
    "Enter your query below to generate a SQL query based on the provided schema and metadata."
)

# Inference device passed to generate_sql. No GPU toggle is exposed in the UI
# yet, so generation always runs on the CPU.
device = "cpu"

# Sidebar: sample questions shown to the user as a bulleted list.
example_queries = (
    "Get fire budget for Ada county for the year 2023",
    "Retrieve population count for all cities",
)
with st.sidebar:
    st.header("Example Queries")
    for example in example_queries:
        st.write(f"- {example}")

# Main Input
# Strip surrounding whitespace so a blank or whitespace-only entry does not
# trigger a (potentially slow) model call.
question = st.text_input("Enter your query:").strip()

if question:
    # perf_counter is monotonic and high-resolution, unlike time.time(), so the
    # reported duration cannot be skewed by system clock adjustments.
    start_time = time.perf_counter()

    # st.spinner shows the progress message only while generation is running,
    # instead of leaving a stale "Generating..." line on the page afterwards.
    with st.spinner("Generating SQL for the query..."):
        # Format prompt
        formatted_prompt = format_prompt(question, schema, metadata, Instructions)
        prompt_inputs = {
            "formatted_prompt": formatted_prompt,
            "schema": schema,
            "metadata": metadata,
            "instructions": Instructions,
        }

        # Generate SQL Query
        sql_query = generate_sql(question, prompt_inputs, tokenizer, model, device=device)

    # Covers both prompt formatting and generation, as before.
    time_taken = time.perf_counter() - start_time

    # Display Results
    st.subheader("Generated SQL Query:")
    st.code(sql_query, language="sql")

    # Show time taken
    st.subheader("Time Taken:")
    st.write(f"{time_taken:.2f} seconds")