dgjx committed on
Commit
b2d9f7b
·
verified ·
1 Parent(s): dd3a93a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -4
app.py CHANGED
@@ -7,7 +7,7 @@ model_name = "defog/sqlcoder-7b-2" # 使用更新的模型以提高性能
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto") # 降低内存占用
9
 
10
- def generate_sql(user_question, create_table_statements, instructions=""):
11
  prompt = f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
12
 
13
  Generate a SQL query to answer this question: `{user_question}`
@@ -27,16 +27,57 @@ The following SQL query best answers the question `{user_question}`:
27
  return sql_query
28
 
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
 
 
 
 
 
 
 
 
31
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  demo = gr.Interface(
34
  fn=generate_sql,
35
  inputs=[
36
  gr.Markdown("## SQL Query Generator"),
37
- gr.Textbox(label="User Question", placeholder="请输入您的问题...", value="从纽约的客户那里获得的总收入是多少?"),
38
- gr.Textbox(label="Create Table Statements", placeholder="请输入DDL语句...", value="CREATE TABLE customers (id INT, city VARCHAR(50), revenue DECIMAL);"),
39
- gr.Textbox(label="Instructions (可选)", placeholder="请输入额外说明...", value="")
40
  ],
41
  outputs="text",
42
  )
 
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto") # 降低内存占用
9
 
10
+ def generate_sql(user_question, instructions, create_table_statements):
11
  prompt = f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
12
 
13
  Generate a SQL query to answer this question: `{user_question}`
 
27
  return sql_query
28
 
29
 
30
+ question = f"What are our top 3 products by revenue in the New York region?"
31
+ instructions = f"""- if the question cannot be answered given the database schema, return "I do not know"
32
+ - recall that the current date in YYYY-MM-DD format is 2024-09-15
33
+ """
34
+ schema = f"""CREATE TABLE products (
35
+ product_id INTEGER PRIMARY KEY, -- Unique ID for each product
36
+ name VARCHAR(50), -- Name of the product
37
+ price DECIMAL(10,2), -- Price of each unit of the product
38
+ quantity INTEGER -- Current quantity in stock
39
+ );
40
+
41
+ CREATE TABLE customers (
42
+ customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
43
+ name VARCHAR(50), -- Name of the customer
44
+ address VARCHAR(100) -- Mailing address of the customer
45
+ );
46
+
47
+ CREATE TABLE salespeople (
48
+ salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson
49
+ name VARCHAR(50), -- Name of the salesperson
50
+ region VARCHAR(50) -- Geographic sales region
51
+ );
52
 
53
+ CREATE TABLE sales (
54
+ sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
55
+ product_id INTEGER, -- ID of product sold
56
+ customer_id INTEGER, -- ID of customer who made purchase
57
+ salesperson_id INTEGER, -- ID of salesperson who made the sale
58
+ sale_date DATE, -- Date the sale occurred
59
+ quantity INTEGER -- Quantity of product sold
60
+ );
61
 
62
+ CREATE TABLE product_suppliers (
63
+ supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
64
+ product_id INTEGER, -- Product ID supplied
65
+ supply_price DECIMAL(10,2) -- Unit price charged by supplier
66
+ );
67
+
68
+ -- sales.product_id can be joined with products.product_id
69
+ -- sales.customer_id can be joined with customers.customer_id
70
+ -- sales.salesperson_id can be joined with salespeople.salesperson_id
71
+ -- product_suppliers.product_id can be joined with products.product_id
72
+ """
73
 
74
  demo = gr.Interface(
75
  fn=generate_sql,
76
  inputs=[
77
  gr.Markdown("## SQL Query Generator"),
78
+ gr.Textbox(label="User Question", placeholder="请输入您的问题...", value=question),
79
+ gr.Textbox(label="Instructions (可选)", placeholder="请输入额外说明...", value=instructions),
80
+ gr.Textbox(label="Create Table Statements", placeholder="请输入DDL语句...", value=schema),
81
  ],
82
  outputs="text",
83
  )