ytyeung committed on
Commit 897bd83 · 1 Parent(s): 4563580

SFT training code

Files changed (1)
  1. sft_train.py +152 -0
sft_train.py ADDED
@@ -0,0 +1,152 @@
# %%
from datasets import Dataset
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer
from peft import LoraConfig, TaskType, get_peft_model, AutoPeftModelForCausalLM

# %%
df = pd.read_csv('data/riddles_data.csv')
df = df.sample(frac=1)  # shuffle the whole dataset
#df = df[:1000]

# %%
df.describe()

# %%
ds = Dataset.from_pandas(df)

# %%
ds[:3]

# %%
llm_model_name = "Qwen/Qwen1.5-0.5B-Chat"
model = AutoModelForCausalLM.from_pretrained(llm_model_name)
tokenizer = AutoTokenizer.from_pretrained(llm_model_name, trust_remote_code=True, pad_token='<|endoftext|>')

tokenizer

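# %%
# Quick check (illustrative addition, not in the original commit): Qwen1.5-Chat
# uses the ChatML format, which is why process_func below assembles the
# "<|im_start|>...<|im_end|>" turns by hand instead of calling apply_chat_template.
demo_messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "hello"},
]
print(tokenizer.apply_chat_template(demo_messages, tokenize=False, add_generation_prompt=True))
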
# %%
def process_func(example):
    MAX_LENGTH = 512
    input_ids, attention_mask, labels = [], [], []
    # Prompt (ChatML): the user turn asks, in Chinese, "Guess the riddle: <riddle> -- what is the answer?"
    instruction = tokenizer(f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n猜谜语:\n谜面:{example['riddle']}\n\n谜底是什么?<|im_end|>\n<|im_start|>assistant\n", add_special_tokens=False)  # add_special_tokens=False: do not prepend special tokens
    # Target completion: "The answer is: <label>"
    response = tokenizer(f"谜底是:{example['label']}", add_special_tokens=False)
    input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
    attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1]
    # Mask the prompt with -100 so the loss is computed only on the response tokens
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]
    if len(input_ids) > MAX_LENGTH:  # truncate over-long examples
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
        print(f"{tokenizer.decode(input_ids)} Too Long")
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

# %%
tokenized_id = ds.map(process_func, remove_columns=ds.column_names)
tokenized_id

# %%
tokenizer.decode(tokenized_id[0]['input_ids'])

# %%
# Decode only the unmasked label positions to verify that just the answer is supervised
tokenizer.decode(list(filter(lambda x: x != -100, tokenized_id[1]["labels"])))

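# %%
# Sanity check (an illustrative addition, not in the original commit): the
# DataCollatorForSeq2Seq used by the Trainer below pads "labels" with -100 and
# "input_ids" with the pad token, so padded positions never contribute to the loss.
check_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)
check_batch = check_collator([tokenized_id[0], tokenized_id[1]])
print(check_batch["input_ids"].shape, check_batch["labels"].shape)
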
# %%
config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    # adapt all attention and MLP projection matrices
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    #target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    inference_mode=False,
    r=32,
    lora_alpha=32,  # alpha/r = 1.0, so LoRA updates are applied at full scale
    lora_dropout=0.05
)

# %%
model = get_peft_model(model, config)
config

# %%
model.print_trainable_parameters()

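# %%
# Cross-check (illustrative assumption, not in the original commit): count the
# LoRA weights directly; this should match print_trainable_parameters() above.
lora_params = sum(p.numel() for n, p in model.named_parameters() if "lora_" in n)
print(f"LoRA parameters: {lora_params}")
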
# %%
args = TrainingArguments(
    output_dir="./Qwen1.5_0.5B_Chat_sft_full",
    logging_steps=10,
    num_train_epochs=2,
    save_steps=10,  # checkpoint every 10 steps
    learning_rate=1e-4,
    save_on_each_node=True,
    fp16=False
)

# %%
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_id,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)

# resume_from_checkpoint=True expects an existing checkpoint in output_dir;
# call trainer.train() instead for a fresh run
trainer.train(resume_from_checkpoint=True)

# %%
trainer.save_model("./qwen_sft_full")

# %%
llm_model_name = "Qwen/Qwen1.5-0.5B-Chat"

#model = AutoModelForCausalLM.from_pretrained(llm_model_name)
# Load the PEFT (LoRA) model on CPU
model = AutoPeftModelForCausalLM.from_pretrained(
    "Qwen1.5_0.5B_Chat_sft_full_ckpt_200_ok/checkpoint-210",
    #low_cpu_mem_usage=True,
)
# Merge LoRA into the base model and save
#merged_model = model.merge_and_unload()
#merged_model.save_pretrained("./qwen_sft", safe_serialization=False, max_shard_size="2GB")

tokenizer = AutoTokenizer.from_pretrained(llm_model_name, trust_remote_code=True, pad_token='<|endoftext|>')

# %%
# Chinese riddle prompt: "Riddle: 一生受用 (guess one character). What is the answer? Please explain."
prompt = "谜面:一生受用(猜一字)\n谜底是什么?请解释。"
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)

print(text)
model_inputs = tokenizer([text], return_tensors="pt").to("cpu")

generated_ids = model.generate(
    model_inputs.input_ids,
    max_new_tokens=128,
    do_sample=False  # greedy decoding; top_p is ignored when do_sample=False, so it is dropped here
)

# Keep only the newly generated tokens, stripping the prompt
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

# %%
response

# %%
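# A small helper (illustrative assumption, not part of the original commit) that
# wraps the prompt -> chat template -> generate -> decode steps above for quick tests.
def ask_riddle(riddle: str, max_new_tokens: int = 128) -> str:
    msgs = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": riddle},
    ]
    chat_text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([chat_text], return_tensors="pt").to("cpu")
    out = model.generate(inputs.input_ids, max_new_tokens=max_new_tokens, do_sample=False)
    return tokenizer.batch_decode(out[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)[0]

# Example: ask_riddle("谜面:一生受用(猜一字)\n谜底是什么?请解释。")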