import torch
# import os
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, TrainingArguments
from peft import LoraConfig, PeftModel
from trl import SFTTrainer
import warnings
# Suppress a specific FutureWarning from huggingface_hub downloads
warnings.filterwarnings("ignore", category=FutureWarning, module="huggingface_hub.file_download")
# Device setup (CPU or MPS)
device = torch.device("cpu")
# Load the dataset
data_path = "./data_finetunned/coffee_finetuning_20240914_witi_total.jsonl"
dataset = Dataset.from_json(data_path)
print("Dataset loaded")
# Load the base model and tokenizer
BASE_MODEL = "beomi/gemma-ko-2b"
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
model.to(device)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, add_special_tokens=True)
# Prompt-building function (formats instruction/response pairs into the Gemma chat template)
def generate_prompt(example):
    prompt_list = []
    for i in range(len(example['instruction'])):
        prompt_list.append(f"""<bos><start_of_turn>user
{example['instruction'][i]}<end_of_turn>
<start_of_turn>model
{example['response'][i]}<end_of_turn><eos>""")
    return prompt_list
# Use the dataset as the training split
train_data = dataset
# Check the formatted prompt for the first example
print(generate_prompt(train_data[:1])[0])
# LoRA configuration
lora_config = LoraConfig(
    r=6,
    lora_alpha=8,
    lora_dropout=0.05,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)
# Trainer setup (SFT with LoRA)
trainer = SFTTrainer(
    model=model,
    train_dataset=train_data,
    max_seq_length=512,
    args=TrainingArguments(
        output_dir="outputs",
        max_steps=3000,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        optim="adamw_torch",
        warmup_ratio=0.03,  # warmup_steps expects an integer; a 3% fraction belongs in warmup_ratio
        learning_rate=2e-4,
        fp16=False,
        logging_steps=100,
        push_to_hub=False,
        report_to='none',
        use_mps_device=False  # train on CPU
    ),
    peft_config=lora_config,
    formatting_func=generate_prompt,  # apply the prompt-formatting function defined above
)
# Start training
trainer.train()
# Save the LoRA adapter
ADAPTER_MODEL = "lora_adapter"
trainer.model.save_pretrained(ADAPTER_MODEL)
# Merge the adapter into the base model and save the final model
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map='auto', torch_dtype=torch.float16)
model = PeftModel.from_pretrained(model, ADAPTER_MODEL, device_map='auto', torch_dtype=torch.float16)
model = model.merge_and_unload()
model.save_pretrained('./gemma_outputs/gemma-ko-2b-beans-20240915-01')
print("Model saved")
# Save the tokenizer
tokenizer.save_pretrained('./gemma_outputs/gemma-ko-2b-beans-20240915-01')
print("Tokenizer saved")
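
# Quick sanity check (a sketch, not part of the original training run): load the merged
# model from the save path above with the `pipeline` helper imported at the top and
# generate a reply using the same chat template the training prompts used.
# The sample question below is a placeholder.
finetuned_path = './gemma_outputs/gemma-ko-2b-beans-20240915-01'
pipe = pipeline(
    "text-generation",
    model=finetuned_path,
    tokenizer=finetuned_path,
    device=-1,  # CPU
)
prompt = "<bos><start_of_turn>user\nWhat roast level suits a light, fruity coffee?<end_of_turn>\n<start_of_turn>model\n"
outputs = pipe(prompt, max_new_tokens=128, do_sample=True, temperature=0.7)
print(outputs[0]["generated_text"])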