dgjx's picture
Update app.py
5e1e57a verified
raw
history blame
3.37 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Load the tokenizer and the model.
# A newer checkpoint is used here to improve performance.
model_name = "defog/sqlcoder-7b-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Half-precision weights reduce the memory footprint.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,
)
# Budget for newly generated tokens only; the prompt length is not counted
# against it (unlike `max_length`, which caps prompt + completion together).
MAX_NEW_TOKENS = 150


def generate_sql(
    user_question: str,
    instructions: str,
    create_table_statements: str,
) -> str:
    """Generate a SQL query that answers ``user_question`` for the given schema.

    Args:
        user_question: Natural-language question to answer.
        instructions: Extra guidance inserted verbatim into the prompt
            (may be empty).
        create_table_statements: DDL text inserted verbatim into the prompt.

    Returns:
        The text the model produced after the prompt, cut at the closing
        Markdown code fence if one was generated, with surrounding
        whitespace removed.
    """
    # NOTE(review): this template uses Llama-3 style header tokens, while the
    # file loads "defog/sqlcoder-7b-2" -- confirm the template matches the
    # prompt format documented for the checkpoint actually in use.
    prompt = f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
Generate a SQL query to answer this question: `{user_question}`
{instructions}
DDL statements:
{create_table_statements}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
The following SQL query best answers the question `{user_question}`:
```sql
"""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    # `max_new_tokens` instead of `max_length`: the prompt alone (question +
    # DDL) is typically longer than 150 tokens, so a total-length cap of 150
    # leaves no room for the answer.
    outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)

    # Decode only the continuation so the prompt is not echoed back to the UI.
    completion = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    )

    # The prompt opens a ```sql fence; anything after the model closes it is
    # not part of the query.
    return completion.split("```", 1)[0].strip()
# Example values pre-filled into the three input boxes of the UI.
# These are plain (non-f) string literals on purpose: none of them contains a
# placeholder, and an f-string would treat any `{` / `}` later added to the
# DDL as a replacement field.
question = "What are our top 3 products by revenue in the New York region?"

instructions = """- if the question cannot be answered given the database schema, return "I do not know"
- recall that the current date in YYYY-MM-DD format is 2024-09-15
"""

schema = """CREATE TABLE products (
product_id INTEGER PRIMARY KEY, -- Unique ID for each product
name VARCHAR(50), -- Name of the product
price DECIMAL(10,2), -- Price of each unit of the product
quantity INTEGER -- Current quantity in stock
);
CREATE TABLE customers (
customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
name VARCHAR(50), -- Name of the customer
address VARCHAR(100) -- Mailing address of the customer
);
CREATE TABLE salespeople (
salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson
name VARCHAR(50), -- Name of the salesperson
region VARCHAR(50) -- Geographic sales region
);
CREATE TABLE sales (
sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
product_id INTEGER, -- ID of product sold
customer_id INTEGER, -- ID of customer who made purchase
salesperson_id INTEGER, -- ID of salesperson who made the sale
sale_date DATE, -- Date the sale occurred
quantity INTEGER -- Quantity of product sold
);
CREATE TABLE product_suppliers (
supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
product_id INTEGER, -- Product ID supplied
supply_price DECIMAL(10,2) -- Unit price charged by supplier
);
-- sales.product_id can be joined with products.product_id
-- sales.customer_id can be joined with customers.customer_id
-- sales.salesperson_id can be joined with salespeople.salesperson_id
-- product_suppliers.product_id can be joined with products.product_id
"""
# Gradio UI: three text boxes (pre-filled with the example question,
# instructions and schema) are passed to generate_sql, and its return value
# is shown as plain text.
# The title and description name the checkpoint this file actually loads
# ("defog/sqlcoder-7b-2").
demo = gr.Interface(
    fn=generate_sql,
    title="SQLCoder-7b-2",
    description=(
        "Defog's SQLCoder-7B-2 is a state-of-the-art model for generating "
        "SQL queries from natural language."
    ),
    inputs=[
        gr.Textbox(
            label="User Question",
            placeholder="Enter your question here...",
            value=question,
        ),
        gr.Textbox(
            label="Instructions (optional)",
            placeholder="Enter any additional instructions here...",
            value=instructions,
        ),
        gr.Textbox(
            label="Create Table Statements",
            placeholder="Enter DDL statements here...",
            value=schema,
        ),
    ],
    outputs="text",
)
def main():
    """Start the Gradio app with ``share=True``."""
    demo.launch(share=True)


# Only launch when executed as a script, not when imported.
if __name__ == "__main__":
    main()