File size: 3,885 Bytes
90d8ee3
df8873e
687f4c4
5a153bd
 
90d8ee3
df8873e
fa87edc
5a153bd
df8873e
08343a9
 
 
 
 
 
 
90d8ee3
b2d9f7b
dd3a93a
 
 
 
 
 
 
 
 
 
 
df8873e
3a05353
08343a9
 
 
 
 
 
cf16b57
08343a9
 
 
 
 
 
 
90d8ee3
1c4b940
b2d9f7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd3a93a
b2d9f7b
 
 
 
 
 
 
 
dd3a93a
b2d9f7b
 
 
 
 
 
 
 
 
 
 
dd3a93a
1c4b940
 
b030295
5e1e57a
1c4b940
b030295
 
 
1c4b940
 
 
df8873e
90d8ee3
 
c802c0e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import sqlparse
# from modelscope import snapshot_download

# Load the model and tokenizer.
model_name = "defog/llama-3-sqlcoder-8b"  # newer SQLCoder release, used for better performance
# Alternative: fetch the weights through ModelScope instead of the Hugging Face Hub
# (requires re-enabling the commented-out `snapshot_download` import).
# model_name = snapshot_download("stevie/llama-3-sqlcoder-8b")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# float16 is used only when a GPU is present. On CPU, float16 kernels are slow or
# missing in some PyTorch versions; bfloat16 keeps the same memory footprint and is
# better supported there. generate_sql itself also supports a CPU-only host.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.bfloat16,
    device_map="auto",
    use_cache=True,
)

def generate_sql(user_question, instructions, create_table_statements):
    """Ask SQLCoder for a SQL query that answers a natural-language question.

    Args:
        user_question: The question to answer, in natural language.
        instructions: Extra guidance inserted verbatim into the prompt
            (may be an empty string).
        create_table_statements: DDL text describing the schema the query
            should target.

    Returns:
        Only the generated SQL (not the echoed prompt), reindented with
        ``sqlparse.format``.
    """
    prompt = f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{user_question}`
{instructions}

DDL statements:
{create_table_statements}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{user_question}`:
```sql
"""

    # The prompt already begins with <|begin_of_text|>, so stop the tokenizer
    # from prepending a second BOS token. Send the tensors to wherever the
    # model's first parameters live (works for both GPU and CPU loading).
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=400,
        do_sample=False,
        num_beams=1,
    )

    # generate() returns prompt + completion; decode only the new tokens so the
    # prompt is not echoed back to the user.
    prompt_len = inputs["input_ids"].shape[-1]
    completion = tokenizer.decode(generated_ids[0][prompt_len:], skip_special_tokens=True)
    # The prompt opens a ```sql fence; keep only the text before the closing fence.
    sql = completion.split("```", 1)[0].strip()

    # torch.cuda.synchronize() raises on hosts without CUDA, so only touch the
    # CUDA runtime when a GPU is available. Synchronize first so pending kernels
    # finish before cached memory blocks are released.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
    return sqlparse.format(sql, reindent=True)


# Example inputs used to pre-fill the Gradio form fields.
# Plain string literals: none of these contain interpolation placeholders.
question = "What are our top 3 products by revenue in the New York region?"
# Note: the "current date" given to the model is fixed, not computed at runtime.
instructions = """- if the question cannot be answered given the database schema, return "I do not know"
- recall that the current date in YYYY-MM-DD format is 2024-09-15
"""
schema = """CREATE TABLE products (
  product_id INTEGER PRIMARY KEY, -- Unique ID for each product
  name VARCHAR(50), -- Name of the product
  price DECIMAL(10,2), -- Price of each unit of the product
  quantity INTEGER  -- Current quantity in stock
);

CREATE TABLE customers (
  customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
  name VARCHAR(50), -- Name of the customer
  address VARCHAR(100) -- Mailing address of the customer
);

CREATE TABLE salespeople (
  salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson 
  name VARCHAR(50), -- Name of the salesperson
  region VARCHAR(50) -- Geographic sales region 
);

CREATE TABLE sales (
  sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
  product_id INTEGER, -- ID of product sold
  customer_id INTEGER,  -- ID of customer who made purchase
  salesperson_id INTEGER, -- ID of salesperson who made the sale
  sale_date DATE, -- Date the sale occurred 
  quantity INTEGER -- Quantity of product sold
);

CREATE TABLE product_suppliers (
  supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
  product_id INTEGER, -- Product ID supplied
  supply_price DECIMAL(10,2) -- Unit price charged by supplier
);

-- sales.product_id can be joined with products.product_id
-- sales.customer_id can be joined with customers.customer_id 
-- sales.salesperson_id can be joined with salespeople.salesperson_id
-- product_suppliers.product_id can be joined with products.product_id
"""

# Gradio UI: the three text boxes map positionally onto generate_sql's
# (user_question, instructions, create_table_statements) parameters and are
# pre-filled with the module-level example values question, instructions and schema.
demo = gr.Interface(
    fn=generate_sql,
    title="SQLCoder-8b",
    description="Defog's SQLCoder-8B is a state-of-the-art model for generating SQL queries from natural language.",
    inputs=[
        gr.Textbox(label="User Question", placeholder="Enter your question here...", value=question),
        gr.Textbox(label="Instructions (optional)", placeholder="Enter any additional instructions here...", value=instructions),
        gr.Textbox(label="Create Table Statements", placeholder="Enter DDL statements here...", value=schema),
    ],
    outputs="text",
)


if __name__ == "__main__":
    # share=True asks Gradio to also create a temporary public share link, which
    # exposes this model to anyone holding the URL; drop it for local-only use.
    demo.launch(share=True)