# BEANs/gemma_Ko_coffee.py
import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, TrainingArguments
from peft import LoraConfig, PeftModel
from trl import SFTTrainer
import warnings
# Suppress a specific FutureWarning emitted by huggingface_hub downloads
warnings.filterwarnings("ignore", category=FutureWarning, module="huggingface_hub.file_download")
# Device setup (CPU here; MPS would also work on Apple silicon)
device = torch.device("cpu")
# Load the fine-tuning dataset
data_path = "./data_finetunned/coffee_finetuning_20240914_witi_total.jsonl"
dataset = Dataset.from_json(data_path)
print("๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ ์™„๋ฃŒ")
# Load the model and tokenizer
BASE_MODEL = "beomi/gemma-ko-2b"
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
model.to(device)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)  # add_special_tokens is an encode-time flag, not a from_pretrained argument
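# Optional sanity check (assumes Gemma's turn markers are in the vocabulary of
# beomi/gemma-ko-2b): confirm the template tokens map to real, non-UNK ids.
for tok in ("<start_of_turn>", "<end_of_turn>"):
    print(tok, "->", tokenizer.convert_tokens_to_ids(tok))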
# Prompt-building function (builds prompts from instruction/response pairs)
def generate_prompt(example):
    prompt_list = []
    for i in range(len(example['instruction'])):
        prompt_list.append(f"""<bos><start_of_turn>user
{example['instruction'][i]}<end_of_turn>
<start_of_turn>model
{example['response'][i]}<end_of_turn><eos>""")
    return prompt_list
# Use the full dataset as training data
train_data = dataset
# Inspect the prompt built from the first example
print(generate_prompt(train_data[:1])[0])
# LoRA configuration
lora_config = LoraConfig(
    r=6,
    lora_alpha=8,
    lora_dropout=0.05,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)
# Trainer setup
trainer = SFTTrainer(
    model=model,
    train_dataset=train_data,
    max_seq_length=512,
    args=TrainingArguments(
        output_dir="outputs",
        max_steps=3000,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        optim="adamw_torch",
        warmup_ratio=0.03,  # warmup_steps expects an integer step count; 0.03 is a ratio
        learning_rate=2e-4,
        fp16=False,
        logging_steps=100,
        push_to_hub=False,
        report_to='none',
        use_mps_device=False  # train on CPU
    ),
    peft_config=lora_config,
    formatting_func=generate_prompt,  # apply the prompt-formatting function above
)
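# Optional: with a peft_config supplied, SFTTrainer wraps the model in a LoRA
# adapter, so the trainable-parameter summary can be printed as a quick check.
if hasattr(trainer.model, "print_trainable_parameters"):
    trainer.model.print_trainable_parameters()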
# Start training
trainer.train()
# Save the LoRA adapter
ADAPTER_MODEL = "lora_adapter"
trainer.model.save_pretrained(ADAPTER_MODEL)
# Merge the adapter into the base model and save the final model
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map='auto', torch_dtype=torch.float16)
model = PeftModel.from_pretrained(model, ADAPTER_MODEL, device_map='auto', torch_dtype=torch.float16)
model = model.merge_and_unload()
model.save_pretrained('./gemma_outputs/gemma-ko-2b-beans-20240915-01')
print("๋ชจ๋ธ ์ €์žฅ ์™„๋ฃŒ")
# ํ† ํฌ๋‚˜์ด์ €๋ฅผ ์ €์žฅํ•ฉ๋‹ˆ๋‹ค.
tokenizer.save_pretrained('./gemma_outputs/gemma-ko-2b-beans-20240915-01')
print("tokenizer ์ €์žฅ ์™„๋ฃŒ")