dgjx commited on
Commit
dd3a93a
·
verified ·
1 Parent(s): 1c4b940

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -1
app.py CHANGED
@@ -8,7 +8,17 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto") # 降低内存占用
9
 
10
  def generate_sql(user_question, create_table_statements, instructions=""):
11
- prompt = f"Generate a SQL query to answer this question: `{user_question}`\n{instructions}\n\nDDL statements:\n{create_table_statements}\n<|eot_id|>"
 
 
 
 
 
 
 
 
 
 
12
 
13
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
14
  outputs = model.generate(**inputs, max_length=150)
@@ -17,6 +27,9 @@ def generate_sql(user_question, create_table_statements, instructions=""):
17
  return sql_query
18
 
19
 
 
 
 
20
  demo = gr.Interface(
21
  fn=generate_sql,
22
  inputs=[
 
8
  model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto") # 降低内存占用
9
 
10
  def generate_sql(user_question, create_table_statements, instructions=""):
11
+ prompt = f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
12
+
13
+ Generate a SQL query to answer this question: `{user_question}`
14
+ {instructions}
15
+
16
+ DDL statements:
17
+ {create_table_statements}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
18
+
19
+ The following SQL query best answers the question `{user_question}`:
20
+ ```sql
21
+ """
22
 
23
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
24
  outputs = model.generate(**inputs, max_length=150)
 
27
  return sql_query
28
 
29
 
30
+
31
+
32
+
33
  demo = gr.Interface(
34
  fn=generate_sql,
35
  inputs=[