import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Set up device: use GPU if available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the fine-tuned model and tokenizer.
model_name = "aarohanverma/text2sql-flan-t5-base-qlora-finetuned"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
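# The tokenizer is loaded from the base checkpoint; the assumption here is that
# fine-tuning left the FLAN-T5 vocabulary unchanged, so the base tokenizer matches.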
# For CPU inference, convert the model to FP32 for better compatibility.
if device.type == "cpu":
    model = model.float()
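# (bfloat16 kernel support on CPU varies by hardware and PyTorch build, which is
# why FP32 is the safer default off-GPU.)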
# Optionally compile the model for speed improvements (requires PyTorch 2.0+).
try:
    model = torch.compile(model)
except Exception as e:
    print("torch.compile optimization failed:", e)
def generate_sql(context: str, query: str) -> str:
    """
    Generates a SQL query given the provided context and natural language query.
    Constructs a prompt from the inputs, then performs deterministic generation
    using beam search with repetition handling.
    """
    prompt = f"""Context:
{context}
Query:
{query}
Response:
"""
    # Tokenize the prompt with truncation and max length; move to device.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
    # Ensure the decoder start token is set.
    if model.config.decoder_start_token_id is None:
        model.config.decoder_start_token_id = tokenizer.pad_token_id
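    # (T5-family checkpoints conventionally use the pad token as the decoder
    # start token, so this fallback matches the pretrained configuration.)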
    # Run generation under no_grad so no gradient state is tracked during inference.
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            decoder_start_token_id=model.config.decoder_start_token_id,
            max_new_tokens=100,
            num_beams=5,  # Deterministic beam search; a sampling temperature would be ignored here.
            repetition_penalty=1.2,
            early_stopping=True,
        )
    # Decode and clean the generated SQL statement.
    generated_sql = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    generated_sql = generated_sql.split(";")[0].strip() + ";"  # Keep only the first statement.
    return generated_sql
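# Example usage (illustrative; the schema and question are hypothetical, and
# uncommenting this would trigger a model call at startup, before the UI loads):
#
#     demo_context = "CREATE TABLE employees (id INT, name VARCHAR(100), salary INT);"
#     demo_query = "List the names of employees earning more than 50000."
#     print(generate_sql(demo_context, demo_query))
#     # -> e.g. SELECT name FROM employees WHERE salary > 50000;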
# Guide text with detailed instructions and an example.
guide_text = """
**Overview:**
This app uses a fine-tuned FLAN-T5 model to generate SQL queries based on your inputs.

**How to Use:**
- **Context:** Enter your database schema (table definitions, DDL statements, sample data).
- **Query:** Enter a natural language query describing the desired SQL operation.
- Click **Generate SQL** to see the model-generated SQL query.

**Example:**
- **Context:**
  CREATE TABLE students (id INT PRIMARY KEY, name VARCHAR(100), age INT, grade CHAR(1)); INSERT INTO students (id, name, age, grade) VALUES (1, 'Alice', 14, 'A'), (2, 'Bob', 15, 'B');
- **Query:**
  Retrieve the names of students who are 15 years old.

The generated SQL might look like:
SELECT name FROM students WHERE age = 15;
"""
# Create Gradio interface.
iface = gr.Interface(
    fn=generate_sql,
    inputs=[
        gr.Textbox(lines=8, label="Context", placeholder="Enter table schema, sample data, etc."),
        gr.Textbox(lines=2, label="Query", placeholder="Enter your natural language query here..."),
    ],
    outputs="text",
    title="Text-to-SQL Generator",
    description=guide_text,
    theme="default",  # Use the default theme to avoid loading warnings.
    flagging_mode="never",  # flagging_mode replaces the deprecated allow_flagging.
)
iface.launch()
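# Outside Hugging Face Spaces, launch() also accepts standard Gradio options,
# e.g. (sketch):
# iface.launch(server_name="0.0.0.0", server_port=7860)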