File size: 3,503 Bytes
816ccb1
 
 
 
9d0df30
816ccb1
 
9d0df30
 
816ccb1
 
 
9d0df30
 
 
 
 
 
 
 
 
 
816ccb1
 
451c534
 
9d0df30
816ccb1
 
 
 
 
 
 
9d0df30
27e057c
451c534
9d0df30
451c534
 
 
6d95389
9d0df30
 
 
 
 
4cc23e8
9d0df30
 
 
 
451c534
9d0df30
27e057c
9d0df30
27e057c
816ccb1
6d95389
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
816ccb1
 
 
 
 
 
 
 
6d95389
9676b96
 
816ccb1
 
23adcaf
6d95389
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Select the compute device: prefer a CUDA GPU, fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the fine-tuned text-to-SQL model (bfloat16 to halve memory) and the
# base FLAN-T5 tokenizer (the fine-tune did not change the vocabulary).
model_name = "aarohanverma/text2sql-flan-t5-base-qlora-finetuned"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")

# bfloat16 kernels are poorly supported on many CPUs, so promote the weights
# to FP32 when no GPU is available. Must happen after .to(device).
if device.type == "cpu":
    model = model.float()

# Best-effort torch.compile (PyTorch 2.0+) for faster inference; a failure
# (older PyTorch, unsupported backend) is non-fatal and leaves the eager model.
try:
    model = torch.compile(model)
except Exception as e:
    print("torch.compile optimization failed:", e)

def generate_sql(context: str, query: str) -> str:
    """Generate a SQL statement from a schema description and a natural language query.

    Builds a Context/Query/Response prompt, then decodes deterministically
    with beam search and a repetition penalty.

    Args:
        context: Database schema (DDL, sample data) the query should run against.
        query: Natural language description of the desired SQL operation.

    Returns:
        The first generated SQL statement, terminated with a single ';'.
    """
    prompt = f"""Context:
{context}
Query:
{query}
Response:
"""
    # Tokenize with truncation so oversized schemas cannot exceed the encoder limit.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)

    # Some exported checkpoints lack a decoder start token; fall back to the pad token.
    if model.config.decoder_start_token_id is None:
        model.config.decoder_start_token_id = tokenizer.pad_token_id

    # Deterministic beam-search decoding under no_grad (no backprop buffers).
    # FIX: the original passed temperature=0.1, which is ignored (and warned
    # about by transformers) because do_sample defaults to False — removed.
    # FIX: forward the attention mask explicitly instead of letting the model
    # infer it from pad tokens.
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            decoder_start_token_id=model.config.decoder_start_token_id,
            max_new_tokens=100,
            num_beams=5,
            repetition_penalty=1.2,
            early_stopping=True,
        )

    # Decode, then keep only the first statement and normalize the terminator.
    generated_sql = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    generated_sql = generated_sql.split(";")[0].strip() + ";"  # Keep only the first valid SQL query
    return generated_sql

# Markdown shown as the Gradio interface description: overview, usage steps,
# and a worked schema/query example. Rendered verbatim in the UI — do not
# edit casually, this is user-facing text.
guide_text = """
**Overview:**
This app uses a fine-tuned FLAN-T5 model to generate SQL queries based on your inputs.

**How to Use:**
- **Context:** Enter your database schema (table definitions, DDL statements, sample data).
- **Query:** Enter a natural language query describing the desired SQL operation.
- Click **Generate SQL** to see the model-generated SQL query.

**Example:**
- **Context:**
CREATE TABLE students (id INT PRIMARY KEY, name VARCHAR(100), age INT, grade CHAR(1)); INSERT INTO students (id, name, age, grade) VALUES (1, 'Alice', 14, 'A'), (2, 'Bob', 15, 'B');

- **Query:**
Retrieve the names of students who are 15 years old.

The generated SQL might look like:
SELECT name FROM students WHERE age = 15;
"""

# Assemble the Gradio UI: two text inputs wired to generate_sql, plain-text output.
context_input = gr.Textbox(lines=8, label="Context", placeholder="Enter table schema, sample data, etc.")
query_input = gr.Textbox(lines=2, label="Query", placeholder="Enter your natural language query here...")

iface = gr.Interface(
    fn=generate_sql,
    inputs=[context_input, query_input],
    outputs="text",
    title="Text-to-SQL Generator",
    description=guide_text,
    # Default theme avoids remote theme-loading warnings; flagging_mode
    # replaces the deprecated allow_flagging argument.
    theme="default",
    flagging_mode="never",
)

iface.launch()