{ "cells": [ { "cell_type": "markdown", "id": "9e02c8d1-e653-41a5-a94f-e44c176dbcc5", "metadata": {}, "source": [ "# 1. Setting Up the Development Environment" ] }, { "cell_type": "markdown", "id": "9fa242e1-7689-4397-b410-d550e79246c3", "metadata": {}, "source": [ "### 1.1 Installing the Required Libraries" ] }, { "cell_type": "code", "execution_count": null, "id": "3d405d7a-f2c9-4416-bf88-880812a2b8b5", "metadata": {}, "outputs": [], "source": [ "!pip3 install -q -U transformers==4.38.2\n", "!pip3 install -q -U datasets==2.18.0\n", "!pip3 install -q -U bitsandbytes==0.42.0\n", "!pip3 install -q -U peft==0.9.0\n", "!pip3 install -q -U trl==0.7.11\n", "!pip3 install -q -U accelerate==0.27.2" ] }, { "cell_type": "markdown", "id": "13fa79b6-4720-43d1-baae-41d834011c2c", "metadata": {}, "source": [ "### 1.2 Importing Modules" ] }, { "cell_type": "code", "execution_count": null, "id": "1d7a17e3-b9a1-4a46-8f6e-7710a37a93bf", "metadata": {}, "outputs": [], "source": [ "import torch\n", "from datasets import Dataset, load_dataset\n", "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline, TrainingArguments\n", "from peft import LoraConfig, PeftModel\n", "from trl import SFTTrainer" ] }, { "cell_type": "markdown", "id": "5b7f30d7-bfdf-49c5-8c2c-701ad6f15a80", "metadata": {}, "source": [ "### 1.3 Hugging Face Login" ] }, { "cell_type": "code", "execution_count": null, "id": "6aa22976-7bdf-479d-8c5c-8ab890be537f", "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import notebook_login\n", "notebook_login()" ] }, { "cell_type": "markdown", "id": "98848a84-680e-4527-bdaf-f5cd7d635348", "metadata": {}, "source": [ "# 2. Loading and Preparing the Dataset" ] }, {
"cell_type": "markdown", "id": "ceaa6125-b440-4458-b3dc-142aa7668110", "metadata": {}, "source": [ "### 2.1 Loading the Dataset" ] }, { "cell_type": "code", "execution_count": null, "id": "9031d1af-d554-4852-bae8-006721468543", "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "dataset = load_dataset(\"daekeun-ml/naver-news-summarization-ko\")" ] }, { "cell_type": "markdown", "id": "9f89cfc2-2123-4e30-8440-c827c9705510", "metadata": {}, "source": [ "### 2.2 Exploring the Dataset" ] }, { "cell_type": "code", "execution_count": null, "id": "780a6768-c25e-4816-b944-52e95638ecb7", "metadata": {}, "outputs": [], "source": [ "dataset" ] }, { "cell_type": "markdown", "id": "4c59da51-bb41-44ea-bd62-9e9bcece871f", "metadata": {}, "source": [ "### 2.3 Dataset Example" ] }, { "cell_type": "code", "execution_count": null, "id": "95b66ad0-c0ab-4be4-8214-ad02f1b8ebc6", "metadata": {}, "outputs": [], "source": [ "dataset['train'][0]" ] }, { "cell_type": "markdown", "id": "745507f8-dda1-4f98-8814-0543af75401c", "metadata": {}, "source": [ "# 3. Testing Korean Summarization with Gemma" ] }, {
"cell_type": "markdown", "id": "7a1be307-f676-4f54-8c7a-894abadfe3be", "metadata": {}, "source": [ "### 3.1 Loading the Model" ] }, { "cell_type": "code", "execution_count": null, "id": "249d5ac1-78ed-48b3-a67a-402a45bc962c", "metadata": {}, "outputs": [], "source": [ "BASE_MODEL = \"google/gemma-2b-it\"\n", "\n", "model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map={\"\":0})\n", "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, add_special_tokens=True)" ] }, { "cell_type": "markdown", "id": "80ddcf5b-eaef-4852-9b9c-83799a08cc3e", "metadata": {}, "source": [ "### 3.2 The Gemma-it Prompt Format" ] }, { "cell_type": "code", "execution_count": null, "id": "42076cb8-3f57-476f-8fe9-2e454bbe4235", "metadata": {}, "outputs": [], "source": [ "doc = dataset['train']['document'][0]" ] }, { "cell_type": "code", "execution_count": null, "id": "b2f19d96-8aad-425c-9c4c-7f6420bd7849", "metadata": {}, "outputs": [], "source": [ "pipe = pipeline(\"text-generation\", model=model, tokenizer=tokenizer, max_new_tokens=512)" ] }, { "cell_type": "code", "execution_count": null, "id": "7dc8d3da-6060-4203-9346-953d8adfb680", "metadata": {}, "outputs": [], "source": [ "messages = [\n", " {\n", " \"role\": \"user\",\n", " \"content\": \"다음 글을 요약해주세요:\\n\\n{}\".format(doc)\n", " }\n", "]\n", "prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "9fa04590-01f2-4358-a68c-1eba8eeb5d3c", "metadata": {}, "outputs": [], "source": [ "prompt" ] }, { "cell_type": "markdown", "id": "666223ea-2308-4126-a56c-a57fcec65390", "metadata": {}, "source": [ "### 3.3 Gemma-it Inference" ] }, { "cell_type": "code", "execution_count": null, "id": "a61247af-ce20-47cb-ae80-5a3e40d299f1", "metadata": {}, "outputs": [], "source": [ "outputs = pipe(\n", " prompt,\n", " do_sample=True,\n", " temperature=0.2,\n", " top_k=50,\n", " top_p=0.95,\n", " add_special_tokens=False  # the chat-template prompt already begins with <bos>\n", ")" ] }, {
"cell_type": "code", "execution_count": null, "id": "df721816-9d14-4890-bc7f-a441b5c02481", "metadata": {}, "outputs": [], "source": [ "print(outputs[0][\"generated_text\"][len(prompt):])" ] }, { "cell_type": "markdown", "id": "187a1bfb-b47c-448e-8957-86c00cc1df02", "metadata": {}, "source": [ "# 4. Fine-Tuning Gemma" ] }, { "cell_type": "markdown", "id": "cc7b19a9-5a04-4d67-8004-de31fe0897a7", "metadata": {}, "source": [ "#### Note: Because of Colab's GPU memory limit, the memory used for inference in the previous section must be freed before fine-tuning can run.\n", "\n", "Restart the notebook runtime, re-run Section 1 and Section 2 through step 2.1 to reload the libraries and the dataset, and then continue with the steps below." ] },
{ "cell_type": "code", "execution_count": null, "id": "91bfe441-991f-4bb8-b9a3-a1d2e9fc509c", "metadata": {}, "outputs": [], "source": [ "!nvidia-smi" ] }, { "cell_type": "markdown", "id": "0a886413-a19c-4966-9e07-ca8cdb23aa16", "metadata": {}, "source": [ "### 4.1 Formatting the Training Prompts" ] }, { "cell_type": "code", "execution_count": null, "id": "a9e4cc4b-a094-4035-906e-3edface3a099", "metadata": {}, "outputs": [], "source": [ "def generate_prompt(example):\n", "    # Wrap each document/summary pair in the Gemma-it chat format used at inference time.\n", "    # The tokenizer prepends <bos> itself when the trainer tokenizes, so it is not repeated here.\n", "    prompt_list = []\n", "    for i in range(len(example['document'])):\n", "        prompt_list.append(r\"\"\"<start_of_turn>user\n", "다음 글을 요약해주세요:\n", "\n", "{}<end_of_turn>\n", "<start_of_turn>model\n", "{}<end_of_turn><eos>\"\"\".format(example['document'][i], example['summary'][i]))\n", "    return prompt_list" ] }, { "cell_type": "code", "execution_count": null, "id": "c45ab1ee-8146-4731-86ec-d673e9a67557", "metadata": {}, "outputs": [], "source": [ "train_data = dataset['train']\n", "print(generate_prompt(train_data[:1])[0])" ] }, { "cell_type": "markdown", "id": "1849b4c0-16f3-44f3-bb67-7022f226ec05", "metadata": {}, "source": [ "### 4.2 QLoRA Configuration" ] }, { "cell_type": "code", "execution_count": null, "id": "5c085b4b-a471-4c5a-afe3-81e8e0c37756", "metadata": {}, "outputs": [], "source": [ "lora_config = LoraConfig(\n", " r=6,\n", " target_modules=[\"q_proj\", \"o_proj\", \"k_proj\", \"v_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n", " task_type=\"CAUSAL_LM\",\n", ")\n", "\n", "bnb_config = BitsAndBytesConfig(\n", " load_in_4bit=True,\n", " bnb_4bit_quant_type=\"nf4\",\n", " bnb_4bit_compute_dtype=torch.float16\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "e10bfd65-00f8-49b6-933c-a27ed4385373", "metadata": {}, "outputs": [], "source": [ "BASE_MODEL = \"google/gemma-2b-it\"\n", "model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map=\"auto\", quantization_config=bnb_config)\n", "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, add_special_tokens=True)\n",
"tokenizer.padding_side = 'right'" ] }, { "cell_type": "markdown", "id": "90db62d4-05ef-41ad-ad7b-a9c734c1b67d", "metadata": {}, "source": [ "### 4.3 Running the Trainer" ] }, { "cell_type": "code", "execution_count": null, "id": "335301f3-c127-44e8-af43-1999e1844681", "metadata": { "scrolled": true }, "outputs": [], "source": [ "trainer = SFTTrainer(\n", " model=model,\n", " train_dataset=train_data,\n", " max_seq_length=512,\n", " args=TrainingArguments(\n", " output_dir=\"outputs\",\n", "# num_train_epochs = 1,\n", " max_steps=3000,\n", " per_device_train_batch_size=1,\n", " gradient_accumulation_steps=4,\n", " optim=\"paged_adamw_8bit\",\n", " warmup_ratio=0.03,\n", " learning_rate=2e-4,\n", " fp16=True,\n", " logging_steps=100,\n", " push_to_hub=False,\n", " report_to='none',\n", " ),\n", " peft_config=lora_config,\n", " formatting_func=generate_prompt,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "82fd7e65-334d-4052-9ab5-3c8e71bf09a5", "metadata": { "scrolled": true }, "outputs": [], "source": [ "trainer.train()" ] }, { "cell_type": "markdown", "id": "dca74e51-15ec-403a-90f1-4b7eeb2c723b", "metadata": {}, "source": [ "### 4.4 Saving the Fine-Tuned Model" ] }, { "cell_type": "code", "execution_count": null, "id": "f2bba87d-d95c-4a57-9eb1-c02d81ad7bfb", "metadata": {}, "outputs": [], "source": [ "ADAPTER_MODEL = \"lora_adapter\"\n", "\n", "trainer.model.save_pretrained(ADAPTER_MODEL)" ] }, { "cell_type": "code", "execution_count": null, "id": "6a9fcda0-1d7a-4443-9b1c-7d45490daafb", "metadata": {}, "outputs": [], "source": [ "!ls -alh lora_adapter" ] }, { "cell_type": "code", "execution_count": null, "id": "a9a2a6d7-ece4-472a-981f-fb6599d1d307", "metadata": {}, "outputs": [], "source": [ "model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map='auto', torch_dtype=torch.float16)\n", "model = PeftModel.from_pretrained(model, ADAPTER_MODEL, device_map='auto', torch_dtype=torch.float16)\n", "\n", "model = model.merge_and_unload()\n",
"model.save_pretrained('gemma-2b-it-sum-ko')" ] }, { "cell_type": "code", "execution_count": null, "id": "1a764bbc-069d-400c-bca4-09e799bf0fb0", "metadata": {}, "outputs": [], "source": [ "!ls -alh ./gemma-2b-it-sum-ko" ] }, { "cell_type": "markdown", "id": "84f2c237-71f4-47c2-bad4-181dadb6cc98", "metadata": {}, "source": [ "# 5. Inference with the Gemma Korean Summarization Model" ] }, { "cell_type": "markdown", "id": "8587dfc7-cf7c-4072-a8f7-6ceb1e90a532", "metadata": {}, "source": [ "#### Note: Likewise, because of Colab's GPU memory limit, the memory used during training must be freed before running inference.\n", "\n", "Restart the notebook runtime, re-run Section 1 and Section 2 through step 2.1 to reload the libraries and the dataset, and then continue with the steps below." ] },
{ "cell_type": "code", "execution_count": null, "id": "906ed4dd-270f-4000-84de-ede6885c0be5", "metadata": {}, "outputs": [], "source": [ "!nvidia-smi" ] }, { "cell_type": "markdown", "id": "78399236-63b5-41af-9cee-a7233e23a9db", "metadata": {}, "source": [ "### 5.1 Loading the Fine-Tuned Model" ] }, { "cell_type": "code", "execution_count": null, "id": "76d5ba97-91ca-48c3-b9a2-ba9bea6d7b09", "metadata": {}, "outputs": [], "source": [ "BASE_MODEL = \"google/gemma-2b-it\"\n", "FINETUNE_MODEL = \"./gemma-2b-it-sum-ko\"\n", "\n", "finetune_model = AutoModelForCausalLM.from_pretrained(FINETUNE_MODEL, device_map={\"\":0})\n", "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, add_special_tokens=True)" ] }, { "cell_type": "markdown", "id": "5c34718c-ce52-4d68-ac8c-c18b6483b15b", "metadata": {}, "source": [ "### 5.2 Running Inference with the Fine-Tuned Model" ] }, { "cell_type": "code", "execution_count": null, "id": "a0f0fc82-abaf-49df-9254-7ccee2e74d96", "metadata": { "scrolled": true }, "outputs": [], "source": [ "pipe_finetuned = pipeline(\"text-generation\", model=finetune_model, tokenizer=tokenizer, max_new_tokens=512)" ] }, { "cell_type": "code", "execution_count": null, "id": "2f915638-d859-446f-bc78-070650421ece", "metadata": {}, "outputs": [], "source": [ "doc = dataset['test']['document'][10]" ] }, { "cell_type": "code", "execution_count": null, "id": "396788e7-4b80-46d7-980f-38fcb892a94f", "metadata": {}, "outputs": [], "source": [ "messages = [\n", " {\n", " \"role\": \"user\",\n", " \"content\": \"다음 글을 요약해주세요:\\n\\n{}\".format(doc)\n", " }\n", "]\n", "prompt = pipe_finetuned.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "03f1f711-0ba7-4087-8317-b0e7f4246aee", "metadata": {}, "outputs": [], "source": [ "outputs = pipe_finetuned(\n", " prompt,\n", " do_sample=True,\n", " temperature=0.2,\n", " top_k=50,\n", " top_p=0.95,\n", " add_special_tokens=False  # the chat-template prompt already begins with <bos>\n", ")\n", "print(outputs[0][\"generated_text\"][len(prompt):])" ] }, { "cell_type": "code", "execution_count": null, "id": "73cb6b26-f1d1-4b7b-ba16-1ff62689fb94", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 5 }