File size: 3,242 Bytes
90d8ee3
df8873e
687f4c4
90d8ee3
df8873e
 
 
2725c13
90d8ee3
b2d9f7b
dd3a93a
 
 
 
 
 
 
 
 
 
 
df8873e
 
2725c13
df8873e
2725c13
df8873e
90d8ee3
1c4b940
b2d9f7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd3a93a
b2d9f7b
 
 
 
 
 
 
 
dd3a93a
b2d9f7b
 
 
 
 
 
 
 
 
 
 
dd3a93a
1c4b940
 
 
 
b2d9f7b
 
 
1c4b940
 
 
df8873e
90d8ee3
 
c802c0e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load the SQLCoder checkpoint and its tokenizer once, when the module is imported.
model_name = "defog/sqlcoder-7b-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# float16 weights take half the memory of float32; device_map="auto" lets
# transformers decide where to place the weights on the available devices.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,
)

def generate_sql(user_question, instructions, create_table_statements, *, max_new_tokens=400):
    """Generate a SQL query answering *user_question* with the module-level model.

    Args:
        user_question: The natural-language question to answer.
        instructions: Extra guidance inserted into the prompt (may be empty).
        create_table_statements: DDL text describing the schema the query targets.
        max_new_tokens: Maximum number of tokens to generate after the prompt.
            Keyword-only so positional callers (e.g. the Gradio interface) are
            unaffected.

    Returns:
        The generated SQL text only: the prompt is not echoed back and the
        closing Markdown code fence, if the model emits one, is removed.
    """
    # NOTE(review): this is a Llama-3 chat-style template; confirm it is the
    # prompt format expected by the checkpoint named in `model_name`.
    prompt = f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{user_question}`
{instructions}

DDL statements:
{create_table_statements}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{user_question}`:
```sql
"""

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[1]

    # `max_length` counts the prompt as well, and this prompt (with the DDL)
    # is already longer than any small fixed limit, so bound only the newly
    # generated tokens.
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # Decode only the continuation; decoding outputs[0] whole would return
    # the prompt followed by the answer.
    completion = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    # The prompt opens a ```sql fence, so keep only the text before the
    # model's closing fence (or everything, if it never closes it).
    sql_query = completion.split("```", 1)[0].strip()
    return sql_query


# Example inputs used to pre-fill the UI text boxes. Plain string literals:
# nothing is interpolated, and an f-string would break on any literal `{`
# added to the DDL later.
question = "What are our top 3 products by revenue in the New York region?"
instructions = """- if the question cannot be answered given the database schema, return "I do not know"
- recall that the current date in YYYY-MM-DD format is 2024-09-15
"""
schema = """CREATE TABLE products (
  product_id INTEGER PRIMARY KEY, -- Unique ID for each product
  name VARCHAR(50), -- Name of the product
  price DECIMAL(10,2), -- Price of each unit of the product
  quantity INTEGER  -- Current quantity in stock
);

CREATE TABLE customers (
  customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
  name VARCHAR(50), -- Name of the customer
  address VARCHAR(100) -- Mailing address of the customer
);

CREATE TABLE salespeople (
  salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson 
  name VARCHAR(50), -- Name of the salesperson
  region VARCHAR(50) -- Geographic sales region 
);

CREATE TABLE sales (
  sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
  product_id INTEGER, -- ID of product sold
  customer_id INTEGER,  -- ID of customer who made purchase
  salesperson_id INTEGER, -- ID of salesperson who made the sale
  sale_date DATE, -- Date the sale occurred 
  quantity INTEGER -- Quantity of product sold
);

CREATE TABLE product_suppliers (
  supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
  product_id INTEGER, -- Product ID supplied
  supply_price DECIMAL(10,2) -- Unit price charged by supplier
);

-- sales.product_id can be joined with products.product_id
-- sales.customer_id can be joined with customers.customer_id 
-- sales.salesperson_id can be joined with salespeople.salesperson_id
-- product_suppliers.product_id can be joined with products.product_id
"""

# Gradio UI. Every component in `inputs` is passed to generate_sql as a
# positional argument, so the list must hold exactly its three text inputs.
demo = gr.Interface(
    fn=generate_sql,
    inputs=[
        gr.Textbox(label="User Question", placeholder="请输入您的问题...", value=question),
        gr.Textbox(label="Instructions (可选)", placeholder="请输入额外说明...", value=instructions),
        gr.Textbox(label="Create Table Statements", placeholder="请输入DDL语句...", value=schema),
    ],
    outputs="text",
    # Page heading. Placing a gr.Markdown in `inputs` instead would pass its
    # text to generate_sql as an unexpected fourth argument.
    title="SQL Query Generator",
)


def main():
    """Launch the Gradio demo with share=True (Gradio's public share-link option)."""
    demo.launch(share=True)


if __name__ == "__main__":
    main()