{
"cells": [
{
"cell_type": "markdown",
"id": "9e02c8d1-e653-41a5-a94f-e44c176dbcc5",
"metadata": {},
"source": [
"# 1. 개발 환경 설정"
]
},
{
"cell_type": "markdown",
"id": "9fa242e1-7689-4397-b410-d550e79246c3",
"metadata": {},
"source": [
"### 1.1 필수 라이브러리 설치하기"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3d405d7a-f2c9-4416-bf88-880812a2b8b5",
"metadata": {},
"outputs": [],
"source": [
"!pip3 install -q -U transformers==4.38.2\n",
"!pip3 install -q -U datasets==2.18.0\n",
"!pip3 install -q -U bitsandbytes==0.42.0\n",
"!pip3 install -q -U peft==0.9.0\n",
"!pip3 install -q -U trl==0.7.11\n",
"!pip3 install -q -U accelerate==0.27.2"
]
},
{
"cell_type": "markdown",
"id": "13fa79b6-4720-43d1-baae-41d834011c2c",
"metadata": {},
"source": [
"### 1.2 Import modules"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1d7a17e3-b9a1-4a46-8f6e-7710a37a93bf",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from datasets import Dataset, load_dataset\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline, TrainingArguments\n",
"from peft import LoraConfig, PeftModel\n",
"from trl import SFTTrainer"
]
},
{
"cell_type": "markdown",
"id": "5b7f30d7-bfdf-49c5-8c2c-701ad6f15a80",
"metadata": {},
"source": [
"### 1.3 Huggingface 로그인"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6aa22976-7bdf-479d-8c5c-8ab890be537f",
"metadata": {},
"outputs": [],
"source": [
"from huggingface_hub import notebook_login\n",
"notebook_login()"
]
},
{
"cell_type": "markdown",
"id": "98848a84-680e-4527-bdaf-f5cd7d635348",
"metadata": {},
"source": [
"# 2. Dataset 생성 및 준비"
]
},
{
"cell_type": "markdown",
"id": "ceaa6125-b440-4458-b3dc-142aa7668110",
"metadata": {},
"source": [
"### 2.1 데이터셋 로드"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9031d1af-d554-4852-bae8-006721468543",
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"dataset = load_dataset(\"daekeun-ml/naver-news-summarization-ko\")"
]
},
{
"cell_type": "markdown",
"id": "9f89cfc2-2123-4e30-8440-c827c9705510",
"metadata": {},
"source": [
"### 2.2 데이터셋 탐색"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "780a6768-c25e-4816-b944-52e95638ecb7",
"metadata": {},
"outputs": [],
"source": [
"dataset"
]
},
{
"cell_type": "markdown",
"id": "4c59da51-bb41-44ea-bd62-9e9bcece871f",
"metadata": {},
"source": [
"### 2.3 데이터셋 예시"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95b66ad0-c0ab-4be4-8214-ad02f1b8ebc6",
"metadata": {},
"outputs": [],
"source": [
"dataset['train'][0]"
]
},
{
"cell_type": "markdown",
"id": "745507f8-dda1-4f98-8814-0543af75401c",
"metadata": {},
"source": [
"# 3. Gemma 모델의 한국어 요약 테스트"
]
},
{
"cell_type": "markdown",
"id": "7a1be307-f676-4f54-8c7a-894abadfe3be",
"metadata": {},
"source": [
"### 3.1 모델 로드"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "249d5ac1-78ed-48b3-a67a-402a45bc962c",
"metadata": {},
"outputs": [],
"source": [
"BASE_MODEL = \"google/gemma-2b-it\"\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map={\"\":0})\n",
"tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, add_special_tokens=True)"
]
},
{
"cell_type": "markdown",
"id": "80ddcf5b-eaef-4852-9b9c-83799a08cc3e",
"metadata": {},
"source": [
"### 3.2 Gemma-it의 프롬프트 형식"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "42076cb8-3f57-476f-8fe9-2e454bbe4235",
"metadata": {},
"outputs": [],
"source": [
"doc = dataset['train']['document'][0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2f19d96-8aad-425c-9c4c-7f6420bd7849",
"metadata": {},
"outputs": [],
"source": [
"pipe = pipeline(\"text-generation\", model=model, tokenizer=tokenizer, max_new_tokens=512)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7dc8d3da-6060-4203-9346-953d8adfb680",
"metadata": {},
"outputs": [],
"source": [
"messages = [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"다음 글을 요약해주세요 :\\n\\n{}\".format(doc)\n",
" }\n",
"]\n",
"prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9fa04590-01f2-4358-a68c-1eba8eeb5d3c",
"metadata": {},
"outputs": [],
"source": [
"prompt"
]
},
{
"cell_type": "markdown",
"id": "666223ea-2308-4126-a56c-a57fcec65390",
"metadata": {},
"source": [
"### 3.3 Gemma-it 추론"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a61247af-ce20-47cb-ae80-5a3e40d299f1",
"metadata": {},
"outputs": [],
"source": [
"outputs = pipe(\n",
" prompt,\n",
" do_sample=True,\n",
" temperature=0.2,\n",
" top_k=50,\n",
" top_p=0.95,\n",
" add_special_tokens=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "df721816-9d14-4890-bc7f-a441b5c02481",
"metadata": {},
"outputs": [],
"source": [
"print(outputs[0][\"generated_text\"][len(prompt):])"
]
},
{
"cell_type": "markdown",
"id": "187a1bfb-b47c-448e-8957-86c00cc1df02",
"metadata": {},
"source": [
"# 4. Gemma 파인튜닝"
]
},
{
"cell_type": "markdown",
"id": "cc7b19a9-5a04-4d67-8004-de31fe0897a7",
"metadata": {},
"source": [
"#### 주의: Colab GPU 메모리 한계로 이전장 추론에서 사용했던 메모리를 비워 줘야 파인튜닝을 진행 할 수 있습니다.
notebook 런타임 세션을 재시작 한 후 1번과 2번의 2.1 항목까지 다시 실행하여 로드 한 후 아래 과정을 진행합니다"
]
},
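{
"cell_type": "markdown",
"id": "f3a91c2e-0d4b-4e6a-9c1f-2b5d7e8a0c13",
"metadata": {},
"source": [
"If restarting the runtime is inconvenient, the next cell is a minimal sketch of how to release GPU memory in place. It is not part of the original tutorial: it assumes the `model`, `pipe`, and `outputs` objects from section 3 still exist in the session, and a full runtime restart remains the most reliable option."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c4e2b9a-1f3d-4a8e-b6c0-5d9f1e3a7b24",
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch (assumes `model`, `pipe`, and `outputs` from section 3 are still defined).\n",
"import gc\n",
"import torch\n",
"\n",
"del model, pipe, outputs   # drop Python references so the tensors become collectable\n",
"gc.collect()               # reclaim unreferenced objects\n",
"torch.cuda.empty_cache()   # return cached CUDA blocks to the driver"
]
},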
{
"cell_type": "code",
"execution_count": null,
"id": "91bfe441-991f-4bb8-b9a3-a1d2e9fc509c",
"metadata": {},
"outputs": [],
"source": [
"!nvidia-smi"
]
},
{
"cell_type": "markdown",
"id": "0a886413-a19c-4966-9e07-ca8cdb23aa16",
"metadata": {},
"source": [
"### 4.1 학습용 프롬프트 조정"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a9e4cc4b-a094-4035-906e-3edface3a099",
"metadata": {},
"outputs": [],
"source": [
"def generate_prompt(example):\n",
" prompt_list = []\n",
" for i in range(len(example['document'])):\n",
" prompt_list.append(r\"\"\"user\n",
"다음 글을 요약해주세요:\n",
"\n",
"{}\n",
"model\n",
"{}\"\"\".format(example['document'][i], example['summary'][i]))\n",
" return prompt_list"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c45ab1ee-8146-4731-86ec-d673e9a67557",
"metadata": {},
"outputs": [],
"source": [
"train_data = dataset['train']\n",
"print(generate_prompt(train_data[:1])[0])"
]
},
{
"cell_type": "markdown",
"id": "1849b4c0-16f3-44f3-bb67-7022f226ec05",
"metadata": {},
"source": [
"### 4.2 QLoRA 설정"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5c085b4b-a471-4c5a-afe3-81e8e0c37756",
"metadata": {},
"outputs": [],
"source": [
"lora_config = LoraConfig(\n",
" r=6,\n",
" target_modules=[\"q_proj\", \"o_proj\", \"k_proj\", \"v_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
" task_type=\"CAUSAL_LM\",\n",
")\n",
"\n",
"bnb_config = BitsAndBytesConfig(\n",
" load_in_4bit=True,\n",
" bnb_4bit_quant_type=\"nf4\",\n",
" bnb_4bit_compute_dtype=torch.float16\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e10bfd65-00f8-49b6-933c-a27ed4385373",
"metadata": {},
"outputs": [],
"source": [
"BASE_MODEL = \"google/gemma-2b-it\"\n",
"model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map=\"auto\", quantization_config=bnb_config)\n",
"tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, add_special_tokens=True)\n",
"tokenizer.padding_side = 'right'"
]
},
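{
"cell_type": "markdown",
"id": "2e8d4f6a-9b1c-4d3e-a7f5-8c0b2d4e6f35",
"metadata": {},
"source": [
"Before launching training, it can help to check how many tokens the formatted prompts consume, since the trainer below truncates at `max_seq_length=512`. This is an optional sketch, not part of the original tutorial, using the tokenizer loaded above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9a1c3e5b-7d2f-4b6a-8e0c-1f3b5d7a9c46",
"metadata": {},
"outputs": [],
"source": [
"# Optional sketch: token counts of a few formatted training prompts.\n",
"# Prompts longer than max_seq_length (512 in the trainer below) will be truncated.\n",
"sample_prompts = generate_prompt(dataset['train'][:5])\n",
"for p in sample_prompts:\n",
"    print(len(tokenizer(p)['input_ids']))"
]
},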
{
"cell_type": "markdown",
"id": "90db62d4-05ef-41ad-ad7b-a9c734c1b67d",
"metadata": {},
"source": [
"### 4.3 Trainer 실행"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "335301f3-c127-44e8-af43-1999e1844681",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"trainer = SFTTrainer(\n",
" model=model,\n",
" train_dataset=train_data,\n",
" max_seq_length=512,\n",
" args=TrainingArguments(\n",
" output_dir=\"outputs\",\n",
"# num_train_epochs = 1,\n",
" max_steps=3000,\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=4,\n",
" optim=\"paged_adamw_8bit\",\n",
" warmup_steps=0.03,\n",
" learning_rate=2e-4,\n",
" fp16=True,\n",
" logging_steps=100,\n",
" push_to_hub=False,\n",
" report_to='none',\n",
" ),\n",
" peft_config=lora_config,\n",
" formatting_func=generate_prompt,\n",
")"
]
},
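{
"cell_type": "markdown",
"id": "4c6e8a0b-2d4f-4e1a-9b3c-5d7f9a1c3e57",
"metadata": {},
"source": [
"Because `peft_config` was passed in, `SFTTrainer` wraps the base model in a PEFT model, so only the LoRA adapter weights are trained. The optional check below (not in the original tutorial) prints how small that trainable fraction is:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1e3a5c7d-9f0b-4c2e-8a4b-6d8f0a2c4e68",
"metadata": {},
"outputs": [],
"source": [
"# Optional check: with LoRA, only a small fraction of parameters should be trainable.\n",
"trainer.model.print_trainable_parameters()"
]
},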
{
"cell_type": "code",
"execution_count": null,
"id": "82fd7e65-334d-4052-9ab5-3c8e71bf09a5",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"trainer.train()"
]
},
{
"cell_type": "markdown",
"id": "dca74e51-15ec-403a-90f1-4b7eeb2c723b",
"metadata": {},
"source": [
"### 4.4 Finetuned Model 저장"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f2bba87d-d95c-4a57-9eb1-c02d81ad7bfb",
"metadata": {},
"outputs": [],
"source": [
"ADAPTER_MODEL = \"lora_adapter\"\n",
"\n",
"trainer.model.save_pretrained(ADAPTER_MODEL)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a9fcda0-1d7a-4443-9b1c-7d45490daafb",
"metadata": {},
"outputs": [],
"source": [
"!ls -alh lora_adapter"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a9a2a6d7-ece4-472a-981f-fb6599d1d307",
"metadata": {},
"outputs": [],
"source": [
"model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map='auto', torch_dtype=torch.float16)\n",
"model = PeftModel.from_pretrained(model, ADAPTER_MODEL, device_map='auto', torch_dtype=torch.float16)\n",
"\n",
"model = model.merge_and_unload()\n",
"model.save_pretrained('gemma-2b-it-sum-ko')"
]
},
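{
"cell_type": "markdown",
"id": "6a8c0e2b-4d6f-4a3c-9e5d-7f9b1d3f5a79",
"metadata": {},
"source": [
"The merged directory above holds only the model weights. Optionally, saving the tokenizer into the same directory makes `./gemma-2b-it-sum-ko` self-contained; this step is an addition, not part of the original tutorial:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3c5e7a9b-1d3f-4e6a-8c0b-2d4f6a8c0e80",
"metadata": {},
"outputs": [],
"source": [
"# Optional: store the tokenizer alongside the merged weights\n",
"tokenizer.save_pretrained('gemma-2b-it-sum-ko')"
]
},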
{
"cell_type": "code",
"execution_count": null,
"id": "1a764bbc-069d-400c-bca4-09e799bf0fb0",
"metadata": {},
"outputs": [],
"source": [
"!ls -alh ./gemma-2b-it-sum-ko"
]
},
{
"cell_type": "markdown",
"id": "84f2c237-71f4-47c2-bad4-181dadb6cc98",
"metadata": {},
"source": [
"# 5. Gemma 한국어 요약 모델 추론"
]
},
{
"cell_type": "markdown",
"id": "8587dfc7-cf7c-4072-a8f7-6ceb1e90a532",
"metadata": {},
"source": [
"#### 주의: 마찬가지로 Colab GPU 메모리 한계로 학습 시 사용했던 메모리를 비워 줘야 파인튜닝을 진행 할 수 있습니다.
notebook 런타임 세션을 재시작 한 후 1번과 2번의 2.1 항목까지 다시 실행하여 로드 한 후 아래 과정을 진행합니다"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "906ed4dd-270f-4000-84de-ede6885c0be5",
"metadata": {},
"outputs": [],
"source": [
"!nvidia-smi"
]
},
{
"cell_type": "markdown",
"id": "78399236-63b5-41af-9cee-a7233e23a9db",
"metadata": {},
"source": [
"### 5.1 Fine-tuned 모델 로드"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "76d5ba97-91ca-48c3-b9a2-ba9bea6d7b09",
"metadata": {},
"outputs": [],
"source": [
"BASE_MODEL = \"google/gemma-2b-it\"\n",
"FINETUNE_MODEL = \"./gemma-2b-it-sum-ko\"\n",
"\n",
"finetune_model = AutoModelForCausalLM.from_pretrained(FINETUNE_MODEL, device_map={\"\":0})\n",
"tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, add_special_tokens=True)"
]
},
{
"cell_type": "markdown",
"id": "5c34718c-ce52-4d68-ac8c-c18b6483b15b",
"metadata": {},
"source": [
"### 5.2 Fine-tuned 모델 추론"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0f0fc82-abaf-49df-9254-7ccee2e74d96",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"pipe_finetuned = pipeline(\"text-generation\", model=finetune_model, tokenizer=tokenizer, max_new_tokens=512)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2f915638-d859-446f-bc78-070650421ece",
"metadata": {},
"outputs": [],
"source": [
"doc = dataset['test']['document'][10]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "396788e7-4b80-46d7-980f-38fcb892a94f",
"metadata": {},
"outputs": [],
"source": [
"messages = [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"다음 글을 요약해주세요:\\n\\n{}\".format(doc)\n",
" }\n",
"]\n",
"prompt = pipe_finetuned.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "03f1f711-0ba7-4087-8317-b0e7f4246aee",
"metadata": {},
"outputs": [],
"source": [
"outputs = pipe_finetuned(\n",
" prompt,\n",
" do_sample=True,\n",
" temperature=0.2,\n",
" top_k=50,\n",
" top_p=0.95,\n",
" add_special_tokens=True\n",
")\n",
"print(outputs[0][\"generated_text\"][len(prompt):])"
]
},
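{
"cell_type": "markdown",
"id": "8e0a2c4d-6f8b-4d5e-a1c3-9b1d3f5a7c91",
"metadata": {},
"source": [
"As a quick sanity check, you can compare the generated summary with the dataset's reference summary for the same article. This optional comparison is not part of the original tutorial:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5a7c9e1b-3d5f-4f2a-8e6c-0b2d4f6a8ca2",
"metadata": {},
"outputs": [],
"source": [
"# Reference summary from the dataset for the same test article\n",
"print(dataset['test']['summary'][10])"
]
},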
{
"cell_type": "code",
"execution_count": null,
"id": "73cb6b26-f1d1-4b7b-ba16-1ff62689fb94",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}