Shahrokhpk commited on
Commit
32c1ac9
·
verified ·
1 Parent(s): dbeca7d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -6
app.py CHANGED
@@ -1,8 +1,54 @@
1
- # Use a pipeline as a high-level helper
2
- from transformers import pipeline
3
 
4
- messages = [
5
- {"role": "user", "content": "Who are you?"},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  ]
7
- pipe = pipeline("text-generation", model="defog/llama-3-sqlcoder-8b")
8
- pipe(messages)
 
 
 
1
 
2
+ nl2sqlite_template_cn = """You are a SQLite expert. Now you need to read and understand the following [database schema] description,
3
+ as well as the [reference information] that may be used, and use SQLite knowledge to generate SQL statements to answer [user questions].
4
+ [User question]
5
+ {question}
6
+
7
+ [Database schema]
8
+ {db_schema}
9
+
10
+ [Reference information]
11
+ {evidence}
12
+
13
+ [User question]
14
+ {question}
15
+
16
+ ```sql"""
17
+
18
+ import torch
19
+ from transformers import AutoModelForCausalLM, AutoTokenizer
20
+
21
+ model_name = "XGenerationLab/XiYanSQL-QwenCoder-3B-2502"
22
+ model = AutoModelForCausalLM.from_pretrained(
23
+ model_name,
24
+ torch_dtype=torch.bfloat16,
25
+ device_map="auto"
26
+ )
27
+
28
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
29
+
30
+ ## dialects -> ['SQLite', 'PostgreSQL', 'MySQL']
31
+ prompt = nl2sqlite_template_cn.format(dialect="", db_schema="", question="", evidence="")
32
+ message = [{'role': 'user', 'content': prompt}]
33
+
34
+ text = tokenizer.apply_chat_template(
35
+ message,
36
+ tokenize=False,
37
+ add_generation_prompt=True
38
+ )
39
+ model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
40
+
41
+ generated_ids = model.generate(
42
+ **model_inputs,
43
+ pad_token_id=tokenizer.pad_token_id,
44
+ eos_token_id=tokenizer.eos_token_id,
45
+ max_new_tokens=1024,
46
+ temperature=0.1,
47
+ top_p=0.8,
48
+ do_sample=True,
49
+ )
50
+ generated_ids = [
51
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
52
  ]
53
+ response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
54
+