Gamoooo committed
Commit cee8f38 · verified · 1 Parent(s): 0547dd0

Update README.md

Files changed (1): README.md (+229, -0)
README.md CHANGED
@@ -20,3 +20,232 @@ language:
This llama model was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth) and Huggingface's TRL library.

[<img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20made%20with%20love.png" width="200"/>](https://github.com/unslothai/unsloth)

```python
# Environment setup (Colab cell; lines starting with "!" run shell commands)
!pip uninstall unsloth -y
!pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --upgrade torch
!pip install --upgrade xformers

# Install Flash Attention 2 for softcapping support
import torch
if torch.cuda.get_device_capability()[0] >= 8:
    !pip install --no-deps packaging ninja einops "flash-attn>=2.6.3"
```
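
Not part of the original notebook, but a quick sanity check like the one below can confirm that the runtime actually sees a CUDA GPU and whether bfloat16 is available, which is what decides the fp16/bf16 flags used later during training:

```python
import torch

# Quick runtime check (hypothetical addition): confirm a GPU is visible before training.
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
    print("Compute capability:", torch.cuda.get_device_capability(0))
    print("bf16 supported:", torch.cuda.is_bf16_supported())
else:
    print("No CUDA device found; this notebook assumes a GPU runtime.")
```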

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from unsloth import FastLanguageModel
import torch

max_seq_length = 512
dtype = None          # None lets Unsloth pick the dtype for the GPU automatically
load_in_4bit = True   # 4-bit quantization so the 13B model fits in memory

model_id = "llm-jp/llm-jp-3-13b"
new_model_id = "llm-jp-3-13b-last"

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_id,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
    trust_remote_code=True,
)
```

```python
# Prepare the model for SFT (attach LoRA adapters)
model = FastLanguageModel.get_peft_model(
    model,
    r=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
    max_seq_length=max_seq_length,
)
```
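
As an optional check (not in the original code), the LoRA setup can be verified by counting trainable parameters; `print_trainable_parameters` is assumed to be available because `get_peft_model` returns a PEFT-wrapped model:

```python
# Hypothetical check: only the LoRA adapter weights should require gradients.
model.print_trainable_parameters()

# Manual fallback if the helper is unavailable:
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")
```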

```python
# Create an access token at https://huggingface.co/settings/tokens
HF_TOKEN = "your-token"  # @param {type:"string"}

from datasets import load_dataset, concatenate_datasets

# Load the datasets
ichikara_dataset = load_dataset("json", data_files="/content/ichikara-instruction-003-001-1.json")
elyza_dataset = load_dataset("elyza/ELYZA-tasks-100")

EOS_TOKEN = tokenizer.eos_token  # appended to every example so the model learns where to stop
```
78
+
79
+ # 学習時のプロンプトフォーマットの定義
80
+ prompt = """### 指示
81
+ {}
82
+ ### 回答
83
+ {}"""
84
+
85
+ """
86
+ formatting_prompts_func: 各データをプロンプトに合わせた形式に合わせる
87
+ """
88
+ def formatting_prompts_func(examples):
89
+ input = examples["text"]
90
+ output = examples["output"]
91
+ text = prompt.format(input, output) + EOS_TOKEN
92
+ return {"formatted_text": text}
93
+
94
+ # ichikara-instruction のデータフォーマット
95
+ ichikara_dataset = ichikara_dataset.map(
96
+ formatting_prompts_func,
97
+ num_proc=4,
98
+ )
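
To see what the trainer will actually consume, a small preview (not in the original) can render the template with a made-up instruction/answer pair:

```python
# Hypothetical preview of one formatted training example.
example = prompt.format("日本の首都はどこですか?", "日本の首都は東京です。") + EOS_TOKEN
print(example)
# ### 指示
# 日本の首都はどこですか?
# ### 回答
# 日本の首都は東京です。<EOS token of the tokenizer>
```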
99
+
100
+ # ELYZA-tasks-100 データセットのフォーマット関数
101
+ def elyza_formatting_prompts_func(examples):
102
+ input = examples["input"]
103
+ output = examples["output"]
104
+ text = prompt.format(input, output) + EOS_TOKEN
105
+ return {"formatted_text": text}
106
+
107
+ # ELYZA-tasks-100 のデータフォーマット
108
+ elyza_dataset = elyza_dataset.map(
109
+ elyza_formatting_prompts_func,
110
+ num_proc=4
111
+ )
112
+
113
+ from datasets import concatenate_datasets
114
+
115
+ # ichikara-instruction と ELYZA-tasks-100 を統合
116
+ combined_dataset = concatenate_datasets([
117
+ ichikara_dataset["train"],
118
+ elyza_dataset["test"]
119
+ ])
120
+
121
+ # データ品質チェック
122
+ # 1. ランダムサンプルを確認
123
+ import random
124
+ sample_indices = random.sample(range(len(combined_dataset)), 10)
125
+ for idx in sample_indices:
126
+ print(combined_dataset[idx]["formatted_text"])
127
+
128
+ # 2. 自動検査ルール
129
+ # 短すぎるデータをチェック(Noneチェックを追加)
130
+ short_data = combined_dataset.filter(
131
+ lambda x: x["input"] is not None and x["output"] is not None and (len(x["input"]) < 5 or len(x["output"]) < 5)
132
+ )
133
+ print(f"\n短すぎるデータ数: {len(short_data)}")
134
+
135
+ # 指示と回答が同一のデータ(Noneチェックを追加)
136
+ duplicate_data = combined_dataset.filter(
137
+ lambda x: x["input"] is not None and x["output"] is not None and x["input"].strip() == x["output"].strip()
138
+ )
139
+ print(f"\n指示と回答が同一のデータ数: {len(duplicate_data)}")
140
+
141
+ # 問題のあるデータをフィルタリング(Noneチェックを追加)
142
+ filtered_dataset = combined_dataset.filter(
143
+ lambda x: x["input"] is not None and x["output"] is not None and len(x["input"]) > 5 and len(x["output"]) > 5 and x["input"].strip() != x["output"].strip()
144
+ )
145
+
146
+ print(f"元のデータ数: {len(combined_dataset)}")
147
+ print(f"フィルタリング後のデータ数: {len(filtered_dataset)}")
148
+ print(f"除外されたデータ数: {len(combined_dataset) - len(filtered_dataset)}")
149
+
150
+ # フィルタリング後のデータの例を確認
151
+ print(filtered_dataset[0])
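
Because `max_seq_length` is 512, it may also be worth checking (this check is not in the original) how many formatted examples exceed that length after tokenization, since longer examples will simply be truncated by the trainer:

```python
# Hypothetical length check against max_seq_length.
lengths = [
    len(tokenizer(x["formatted_text"], add_special_tokens=False)["input_ids"])
    for x in filtered_dataset
]
too_long = sum(1 for n in lengths if n > max_seq_length)
print(f"max tokens: {max(lengths)}, over {max_seq_length}: {too_long}/{len(lengths)}")
```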
152
+
153
+ """
154
+ training_arguments: 学習の設定
155
+ """
156
+ from trl import SFTTrainer
157
+ from transformers import TrainingArguments
158
+ from unsloth import is_bfloat16_supported
159
+
160
+ trainer = SFTTrainer(
161
+ model=model,
162
+ tokenizer=tokenizer,
163
+ train_dataset=filtered_dataset,
164
+ max_seq_length=max_seq_length,
165
+ dataset_text_field="formatted_text",
166
+ packing=False,
167
+ args=TrainingArguments(
168
+ per_device_train_batch_size=2,
169
+ gradient_accumulation_steps=4,
170
+ num_train_epochs=3,
171
+ logging_steps=10,
172
+ warmup_steps=10,
173
+ save_steps=50,
174
+ save_total_limit=2,
175
+ max_steps=200,
176
+ learning_rate=2e-4,
177
+ fp16=not is_bfloat16_supported(),
178
+ bf16=is_bfloat16_supported(),
179
+ group_by_length=True,
180
+ seed=3407,
181
+ output_dir="outputs",
182
+ report_to="none",
183
+ ),
184
+ )
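
A quick back-of-the-envelope note on these settings: with `per_device_train_batch_size=2` and `gradient_accumulation_steps=4`, the effective batch size is 8, so `max_steps=200` (which overrides `num_train_epochs=3`) covers roughly 1,600 training examples:

```python
# Rough accounting of how much data this run sees (illustrative arithmetic only).
effective_batch_size = 2 * 4                # per_device_train_batch_size * gradient_accumulation_steps
examples_seen = effective_batch_size * 200  # * max_steps
print(effective_batch_size, examples_seen)  # 8 1600
```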

```python
#@title Run training
trainer_stats = trainer.train()
```
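
`trainer.train()` returns a `TrainOutput`, so the run can be summarized afterwards, for example (not in the original):

```python
# Inspect the training summary returned by trainer.train().
print("final loss:", trainer_stats.training_loss)
print("runtime (s):", trainer_stats.metrics.get("train_runtime"))
```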
188
+
189
+ import json
190
+ from datasets import load_dataset
191
+
192
+ dataset = load_dataset("json", data_files="/content/elyza-tasks-100-TV_0.jsonl", split="train")
193
+
194
+ datasets = []
195
+ with open("/content/elyza-tasks-100-TV_0.jsonl", "r", encoding="utf-8") as f:
196
+ item = ""
197
+ for line in f:
198
+ line = line.strip()
199
+ item += line
200
+ if item.endswith("}"):
201
+ datasets.append(json.loads(item))
202
+ item = ""
203
+

```python
from tqdm import tqdm

# Switch the model to inference mode
FastLanguageModel.for_inference(model)

results = []
for dt in tqdm(datasets):
    try:
        input_text = dt["input"]

        # Build the prompt. The Japanese block after the instruction asks the model to answer
        # concisely, use bullet points where helpful, and follow the instruction faithfully.
        prompt = f"### 指示\n{input_text}\n次の要件を満たしてください:\n1. 簡潔に回答する。\n2. 必要なら箇条書きを使用して要点を整理する。\n3. 指示された内容に忠実に答える。\n### 回答\n"

        # Tokenize
        inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

        # Generate
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            use_cache=True,
            do_sample=False,
            repetition_penalty=1.2,
        )
        # Keep only the text generated after the answer marker
        prediction = tokenizer.decode(outputs[0], skip_special_tokens=True).split('\n### 回答')[-1]

        # Store the result
        results.append({"task_id": dt["task_id"], "input": input_text, "output": prediction})
    except Exception as e:
        print(f"Error processing task_id {dt.get('task_id', 'Unknown')}: {e}")
        results.append({"task_id": dt.get("task_id", "Unknown"), "input": dt.get("input", ""), "output": "Error"})
```
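
Before writing everything to disk, it can help to eyeball one generation (a hypothetical extra step, not in the original):

```python
# Spot-check the first result.
print(results[0]["input"])
print(results[0]["output"])
```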

```python
# Save the results in JSONL format
output_file_jsonl = "/content/llm-jp-3-13b-last.jsonl"
with open(output_file_jsonl, "w", encoding="utf-8") as f:
    for result in results:
        f.write(json.dumps(result, ensure_ascii=False) + "\n")
```
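
A minimal sanity check of the written file (not in the original) is to re-read it and confirm that every line parses:

```python
# Hypothetical validation: every line of the JSONL output should parse back to a dict.
with open(output_file_jsonl, "r", encoding="utf-8") as f:
    reloaded = [json.loads(line) for line in f]
print(len(reloaded), "records written")
```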

```python
# Push only the LoRA adapter weights (save_method="lora") to the Hub as a private repo
model.push_to_hub_merged(
    new_model_id,
    tokenizer=tokenizer,
    save_method="lora",
    token=HF_TOKEN,
    private=True,
)
```
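
For later reuse, here is a minimal sketch (not part of the original) of loading the pushed adapter back with Unsloth; the repo id is a placeholder for wherever `new_model_id` ended up under your account, and it assumes Unsloth resolves the adapter's base model automatically from the pushed files:

```python
from unsloth import FastLanguageModel

# Hypothetical reload of the pushed LoRA adapter ("your-name/llm-jp-3-13b-last" is a placeholder repo id).
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="your-name/llm-jp-3-13b-last",
    max_seq_length=512,
    dtype=None,
    load_in_4bit=True,
    token=HF_TOKEN,  # needed while the repository is private
)
FastLanguageModel.for_inference(model)
```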