Spaces:
Running
Running
File size: 2,061 Bytes
9e11341 cacc96f 9e11341 7794a8e cacc96f 0d53dda cacc96f 7794a8e 9e11341 cacc96f 9e11341 7794a8e cacc96f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import gradio as gr
import mysql.connector
from mysql.connector import Error
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hugging Face model ID of the text-to-SQL checkpoint. The tokenizer and the
# model weights are both loaded once, at import time, from this same ID.
model_name = "premai-io/prem-1B-SQL"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForCausalLM.from_pretrained(model_name),
)
# Table definition given to the model as context. It never changes, so it is
# built once at import time instead of on every request.
_SCHEMA_INFO = """
CREATE TABLE sales (
pizza_id DECIMAL(8,2) PRIMARY KEY,
order_id DECIMAL(8,2),
pizza_name_id VARCHAR(14),
quantity DECIMAL(4,2),
order_date DATE,
order_time VARCHAR(8),
unit_price DECIMAL(5,2),
total_price DECIMAL(5,2),
pizza_size VARCHAR(3),
pizza_category VARCHAR(7),
pizza_ingredients VARCHAR(97),
pizza_name VARCHAR(42)
);
"""

# Every section of the prompt template starts with this marker. If the model
# keeps going past the query and opens a new section, that tail is dropped.
_SECTION_MARKER = "###"


def generate_sql(natural_language_query, *, max_new_tokens=256):
    """Generate a SQL query that answers a natural-language question.

    The question is embedded in a fixed prompt together with ``_SCHEMA_INFO``
    and completed by the module-level ``model``/``tokenizer``.

    Args:
        natural_language_query: Question text, e.g. as typed into the Gradio
            textbox.
        max_new_tokens: Maximum number of tokens to generate after the prompt.
            The budget is counted separately from the prompt, so a long
            question does not reduce the room left for the SQL.

    Returns:
        The generated SQL, stripped of surrounding whitespace. Returns ``""``
        when the question is empty or blank, or when the model produced no
        usable text.
    """
    # Nothing to translate: skip the (slow) model call entirely.
    if not natural_language_query or not natural_language_query.strip():
        return ""

    prompt = f"""### Task: Generate a SQL query to answer the following question.
### Database Schema:
{_SCHEMA_INFO}
### Question: {natural_language_query}
### SQL Query:"""

    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    # Some causal-LM tokenizers define no pad token. Fall back to EOS
    # explicitly instead of passing None into generate().
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    outputs = model.generate(
        inputs["input_ids"],
        # Pass the mask explicitly so generate() does not have to infer it.
        attention_mask=inputs.get("attention_mask"),
        max_new_tokens=max_new_tokens,
        temperature=0.1,
        do_sample=True,
        top_p=0.95,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=pad_token_id,
    )

    # outputs[0] is the prompt tokens followed by the continuation. Decode only
    # the continuation rather than re-splitting the echoed prompt text.
    completion = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    sql_query, _, _ = completion.partition(_SECTION_MARKER)
    return sql_query.strip()
def main():
    """Wrap generate_sql in a single-textbox Gradio UI and start serving it."""
    # One free-text input in, one text output (the generated SQL) back.
    interface_options = {
        "fn": generate_sql,
        "inputs": "text",
        "outputs": "text",
        "title": "Natural Language to SQL Query Generator",
        "description": "Enter a natural language query to generate the corresponding SQL query.",
    }
    demo = gr.Interface(**interface_options)
    demo.launch()


# Start the web UI only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
|