Safetensors
Korean
gemma
sususupa commited on
Commit
aa37d24
โ€ข
1 Parent(s): eb1ddc7

Upload gemma-2b-it-sum-ko.ipynb

Browse files
Files changed (1) hide show
  1. gemma-2b-it-sum-ko.ipynb +609 -0
gemma-2b-it-sum-ko.ipynb ADDED
@@ -0,0 +1,609 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "9e02c8d1-e653-41a5-a94f-e44c176dbcc5",
6
+ "metadata": {},
7
+ "source": [
8
+ "# 1. ๊ฐœ๋ฐœ ํ™˜๊ฒฝ ์„ค์ •"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "9fa242e1-7689-4397-b410-d550e79246c3",
14
+ "metadata": {},
15
+ "source": [
16
+ "### 1.1 ํ•„์ˆ˜ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ ์„ค์น˜ํ•˜๊ธฐ"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": null,
22
+ "id": "3d405d7a-f2c9-4416-bf88-880812a2b8b5",
23
+ "metadata": {},
24
+ "outputs": [],
25
+ "source": [
26
+ "!pip3 install -q -U transformers==4.38.2\n",
27
+ "!pip3 install -q -U datasets==2.18.0\n",
28
+ "!pip3 install -q -U bitsandbytes==0.42.0\n",
29
+ "!pip3 install -q -U peft==0.9.0\n",
30
+ "!pip3 install -q -U trl==0.7.11\n",
31
+ "!pip3 install -q -U accelerate==0.27.2"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "markdown",
36
+ "id": "13fa79b6-4720-43d1-baae-41d834011c2c",
37
+ "metadata": {},
38
+ "source": [
39
+ "### 1.2 Import modules"
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": null,
45
+ "id": "1d7a17e3-b9a1-4a46-8f6e-7710a37a93bf",
46
+ "metadata": {},
47
+ "outputs": [],
48
+ "source": [
49
+ "import torch\n",
50
+ "from datasets import Dataset, load_dataset\n",
51
+ "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline, TrainingArguments\n",
52
+ "from peft import LoraConfig, PeftModel\n",
53
+ "from trl import SFTTrainer"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "markdown",
58
+ "id": "5b7f30d7-bfdf-49c5-8c2c-701ad6f15a80",
59
+ "metadata": {},
60
+ "source": [
61
+ "### 1.3 Huggingface ๋กœ๊ทธ์ธ"
62
+ ]
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "execution_count": null,
67
+ "id": "6aa22976-7bdf-479d-8c5c-8ab890be537f",
68
+ "metadata": {},
69
+ "outputs": [],
70
+ "source": [
71
+ "from huggingface_hub import notebook_login\n",
72
+ "notebook_login()"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "markdown",
77
+ "id": "98848a84-680e-4527-bdaf-f5cd7d635348",
78
+ "metadata": {},
79
+ "source": [
80
+ "# 2. Dataset ์ƒ์„ฑ ๋ฐ ์ค€๋น„"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "markdown",
85
+ "id": "ceaa6125-b440-4458-b3dc-142aa7668110",
86
+ "metadata": {},
87
+ "source": [
88
+ "### 2.1 ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": null,
94
+ "id": "9031d1af-d554-4852-bae8-006721468543",
95
+ "metadata": {},
96
+ "outputs": [],
97
+ "source": [
98
+ "from datasets import load_dataset\n",
99
+ "dataset = load_dataset(\"daekeun-ml/naver-news-summarization-ko\")"
100
+ ]
101
+ },
102
+ {
103
+ "cell_type": "markdown",
104
+ "id": "9f89cfc2-2123-4e30-8440-c827c9705510",
105
+ "metadata": {},
106
+ "source": [
107
+ "### 2.2 ๋ฐ์ดํ„ฐ์…‹ ํƒ์ƒ‰"
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "code",
112
+ "execution_count": null,
113
+ "id": "780a6768-c25e-4816-b944-52e95638ecb7",
114
+ "metadata": {},
115
+ "outputs": [],
116
+ "source": [
117
+ "dataset"
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "markdown",
122
+ "id": "4c59da51-bb41-44ea-bd62-9e9bcece871f",
123
+ "metadata": {},
124
+ "source": [
125
+ "### 2.3 ๋ฐ์ดํ„ฐ์…‹ ์˜ˆ์‹œ"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": null,
131
+ "id": "95b66ad0-c0ab-4be4-8214-ad02f1b8ebc6",
132
+ "metadata": {},
133
+ "outputs": [],
134
+ "source": [
135
+ "dataset['train'][0]"
136
+ ]
137
+ },
138
+ {
139
+ "cell_type": "markdown",
140
+ "id": "745507f8-dda1-4f98-8814-0543af75401c",
141
+ "metadata": {},
142
+ "source": [
143
+ "# 3. Gemma ๋ชจ๋ธ์˜ ํ•œ๊ตญ์–ด ์š”์•ฝ ํ…Œ์ŠคํŠธ"
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "markdown",
148
+ "id": "7a1be307-f676-4f54-8c7a-894abadfe3be",
149
+ "metadata": {},
150
+ "source": [
151
+ "### 3.1 ๋ชจ๋ธ ๋กœ๋“œ"
152
+ ]
153
+ },
154
+ {
155
+ "cell_type": "code",
156
+ "execution_count": null,
157
+ "id": "249d5ac1-78ed-48b3-a67a-402a45bc962c",
158
+ "metadata": {},
159
+ "outputs": [],
160
+ "source": [
161
+ "BASE_MODEL = \"google/gemma-2b-it\"\n",
162
+ "\n",
163
+ "model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map={\"\":0})\n",
164
+ "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, add_special_tokens=True)"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "markdown",
169
+ "id": "80ddcf5b-eaef-4852-9b9c-83799a08cc3e",
170
+ "metadata": {},
171
+ "source": [
172
+ "### 3.2 Gemma-it์˜ ํ”„๋กฌํ”„ํŠธ ํ˜•์‹"
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "code",
177
+ "execution_count": null,
178
+ "id": "42076cb8-3f57-476f-8fe9-2e454bbe4235",
179
+ "metadata": {},
180
+ "outputs": [],
181
+ "source": [
182
+ "doc = dataset['train']['document'][0]"
183
+ ]
184
+ },
185
+ {
186
+ "cell_type": "code",
187
+ "execution_count": null,
188
+ "id": "b2f19d96-8aad-425c-9c4c-7f6420bd7849",
189
+ "metadata": {},
190
+ "outputs": [],
191
+ "source": [
192
+ "pipe = pipeline(\"text-generation\", model=model, tokenizer=tokenizer, max_new_tokens=512)"
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "code",
197
+ "execution_count": null,
198
+ "id": "7dc8d3da-6060-4203-9346-953d8adfb680",
199
+ "metadata": {},
200
+ "outputs": [],
201
+ "source": [
202
+ "messages = [\n",
203
+ " {\n",
204
+ " \"role\": \"user\",\n",
205
+ " \"content\": \"๋‹ค์Œ ๊ธ€์„ ์š”์•ฝํ•ด์ฃผ์„ธ์š” :\\n\\n{}\".format(doc)\n",
206
+ " }\n",
207
+ "]\n",
208
+ "prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)"
209
+ ]
210
+ },
211
+ {
212
+ "cell_type": "code",
213
+ "execution_count": null,
214
+ "id": "9fa04590-01f2-4358-a68c-1eba8eeb5d3c",
215
+ "metadata": {},
216
+ "outputs": [],
217
+ "source": [
218
+ "prompt"
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "markdown",
223
+ "id": "666223ea-2308-4126-a56c-a57fcec65390",
224
+ "metadata": {},
225
+ "source": [
226
+ "### 3.3 Gemma-it ์ถ”๋ก "
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "code",
231
+ "execution_count": null,
232
+ "id": "a61247af-ce20-47cb-ae80-5a3e40d299f1",
233
+ "metadata": {},
234
+ "outputs": [],
235
+ "source": [
236
+ "outputs = pipe(\n",
237
+ " prompt,\n",
238
+ " do_sample=True,\n",
239
+ " temperature=0.2,\n",
240
+ " top_k=50,\n",
241
+ " top_p=0.95,\n",
242
+ " add_special_tokens=True\n",
243
+ ")"
244
+ ]
245
+ },
246
+ {
247
+ "cell_type": "code",
248
+ "execution_count": null,
249
+ "id": "df721816-9d14-4890-bc7f-a441b5c02481",
250
+ "metadata": {},
251
+ "outputs": [],
252
+ "source": [
253
+ "print(outputs[0][\"generated_text\"][len(prompt):])"
254
+ ]
255
+ },
256
+ {
257
+ "cell_type": "markdown",
258
+ "id": "187a1bfb-b47c-448e-8957-86c00cc1df02",
259
+ "metadata": {},
260
+ "source": [
261
+ "# 4. Gemma ํŒŒ์ธํŠœ๋‹"
262
+ ]
263
+ },
264
+ {
265
+ "cell_type": "markdown",
266
+ "id": "cc7b19a9-5a04-4d67-8004-de31fe0897a7",
267
+ "metadata": {},
268
+ "source": [
269
+ "#### ์ฃผ์˜: Colab GPU ๋ฉ”๋ชจ๋ฆฌ ํ•œ๊ณ„๋กœ ์ด์ „์žฅ ์ถ”๋ก ์—์„œ ์‚ฌ์šฉํ–ˆ๋˜ ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ๋น„์›Œ ์ค˜์•ผ ํŒŒ์ธํŠœ๋‹์„ ์ง„ํ–‰ ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. <br>ย notebook ๋Ÿฐํƒ€์ž„ ์„ธ์…˜์„ ์žฌ์‹œ์ž‘ ํ•œ ํ›„ 1๋ฒˆ๊ณผ 2๋ฒˆ์˜ 2.1 ํ•ญ๋ชฉ๊นŒ์ง€ ๋‹ค์‹œ ์‹คํ–‰ํ•˜์—ฌ ๋กœ๋“œ ํ•œ ํ›„ ์•„๋ž˜ ๊ณผ์ •์„ ์ง„ํ–‰ํ•ฉ๋‹ˆ๋‹ค"
270
+ ]
271
+ },
272
+ {
273
+ "cell_type": "code",
274
+ "execution_count": null,
275
+ "id": "91bfe441-991f-4bb8-b9a3-a1d2e9fc509c",
276
+ "metadata": {},
277
+ "outputs": [],
278
+ "source": [
279
+ "!nvidia-smi"
280
+ ]
281
+ },
282
+ {
283
+ "cell_type": "markdown",
284
+ "id": "0a886413-a19c-4966-9e07-ca8cdb23aa16",
285
+ "metadata": {},
286
+ "source": [
287
+ "### 4.1 ํ•™์Šต์šฉ ํ”„๋กฌํ”„ํŠธ ์กฐ์ •"
288
+ ]
289
+ },
290
+ {
291
+ "cell_type": "code",
292
+ "execution_count": null,
293
+ "id": "a9e4cc4b-a094-4035-906e-3edface3a099",
294
+ "metadata": {},
295
+ "outputs": [],
296
+ "source": [
297
+ "def generate_prompt(example):\n",
298
+ " prompt_list = []\n",
299
+ " for i in range(len(example['document'])):\n",
300
+ " prompt_list.append(r\"\"\"<bos><start_of_turn>user\n",
301
+ "๋‹ค์Œ ๊ธ€์„ ์š”์•ฝํ•ด์ฃผ์„ธ์š”:\n",
302
+ "\n",
303
+ "{}<end_of_turn>\n",
304
+ "<start_of_turn>model\n",
305
+ "{}<end_of_turn><eos>\"\"\".format(example['document'][i], example['summary'][i]))\n",
306
+ " return prompt_list"
307
+ ]
308
+ },
309
+ {
310
+ "cell_type": "code",
311
+ "execution_count": null,
312
+ "id": "c45ab1ee-8146-4731-86ec-d673e9a67557",
313
+ "metadata": {},
314
+ "outputs": [],
315
+ "source": [
316
+ "train_data = dataset['train']\n",
317
+ "print(generate_prompt(train_data[:1])[0])"
318
+ ]
319
+ },
320
+ {
321
+ "cell_type": "markdown",
322
+ "id": "1849b4c0-16f3-44f3-bb67-7022f226ec05",
323
+ "metadata": {},
324
+ "source": [
325
+ "### 4.2 QLoRA ์„ค์ •"
326
+ ]
327
+ },
328
+ {
329
+ "cell_type": "code",
330
+ "execution_count": null,
331
+ "id": "5c085b4b-a471-4c5a-afe3-81e8e0c37756",
332
+ "metadata": {},
333
+ "outputs": [],
334
+ "source": [
335
+ "lora_config = LoraConfig(\n",
336
+ " r=6,\n",
337
+ " target_modules=[\"q_proj\", \"o_proj\", \"k_proj\", \"v_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
338
+ " task_type=\"CAUSAL_LM\",\n",
339
+ ")\n",
340
+ "\n",
341
+ "bnb_config = BitsAndBytesConfig(\n",
342
+ " load_in_4bit=True,\n",
343
+ " bnb_4bit_quant_type=\"nf4\",\n",
344
+ " bnb_4bit_compute_dtype=torch.float16\n",
345
+ ")"
346
+ ]
347
+ },
348
+ {
349
+ "cell_type": "code",
350
+ "execution_count": null,
351
+ "id": "e10bfd65-00f8-49b6-933c-a27ed4385373",
352
+ "metadata": {},
353
+ "outputs": [],
354
+ "source": [
355
+ "BASE_MODEL = \"google/gemma-2b-it\"\n",
356
+ "model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map=\"auto\", quantization_config=bnb_config)\n",
357
+ "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, add_special_tokens=True)\n",
358
+ "tokenizer.padding_side = 'right'"
359
+ ]
360
+ },
361
+ {
362
+ "cell_type": "markdown",
363
+ "id": "90db62d4-05ef-41ad-ad7b-a9c734c1b67d",
364
+ "metadata": {},
365
+ "source": [
366
+ "### 4.3 Trainer ์‹คํ–‰"
367
+ ]
368
+ },
369
+ {
370
+ "cell_type": "code",
371
+ "execution_count": null,
372
+ "id": "335301f3-c127-44e8-af43-1999e1844681",
373
+ "metadata": {
374
+ "scrolled": true
375
+ },
376
+ "outputs": [],
377
+ "source": [
378
+ "trainer = SFTTrainer(\n",
379
+ " model=model,\n",
380
+ " train_dataset=train_data,\n",
381
+ " max_seq_length=512,\n",
382
+ " args=TrainingArguments(\n",
383
+ " output_dir=\"outputs\",\n",
384
+ "# num_train_epochs = 1,\n",
385
+ " max_steps=3000,\n",
386
+ " per_device_train_batch_size=1,\n",
387
+ " gradient_accumulation_steps=4,\n",
388
+ " optim=\"paged_adamw_8bit\",\n",
389
+ " warmup_steps=0.03,\n",
390
+ " learning_rate=2e-4,\n",
391
+ " fp16=True,\n",
392
+ " logging_steps=100,\n",
393
+ " push_to_hub=False,\n",
394
+ " report_to='none',\n",
395
+ " ),\n",
396
+ " peft_config=lora_config,\n",
397
+ " formatting_func=generate_prompt,\n",
398
+ ")"
399
+ ]
400
+ },
401
+ {
402
+ "cell_type": "code",
403
+ "execution_count": null,
404
+ "id": "82fd7e65-334d-4052-9ab5-3c8e71bf09a5",
405
+ "metadata": {
406
+ "scrolled": true
407
+ },
408
+ "outputs": [],
409
+ "source": [
410
+ "trainer.train()"
411
+ ]
412
+ },
413
+ {
414
+ "cell_type": "markdown",
415
+ "id": "dca74e51-15ec-403a-90f1-4b7eeb2c723b",
416
+ "metadata": {},
417
+ "source": [
418
+ "### 4.4 Finetuned Model ์ €์žฅ"
419
+ ]
420
+ },
421
+ {
422
+ "cell_type": "code",
423
+ "execution_count": null,
424
+ "id": "f2bba87d-d95c-4a57-9eb1-c02d81ad7bfb",
425
+ "metadata": {},
426
+ "outputs": [],
427
+ "source": [
428
+ "ADAPTER_MODEL = \"lora_adapter\"\n",
429
+ "\n",
430
+ "trainer.model.save_pretrained(ADAPTER_MODEL)"
431
+ ]
432
+ },
433
+ {
434
+ "cell_type": "code",
435
+ "execution_count": null,
436
+ "id": "6a9fcda0-1d7a-4443-9b1c-7d45490daafb",
437
+ "metadata": {},
438
+ "outputs": [],
439
+ "source": [
440
+ "!ls -alh lora_adapter"
441
+ ]
442
+ },
443
+ {
444
+ "cell_type": "code",
445
+ "execution_count": null,
446
+ "id": "a9a2a6d7-ece4-472a-981f-fb6599d1d307",
447
+ "metadata": {},
448
+ "outputs": [],
449
+ "source": [
450
+ "model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map='auto', torch_dtype=torch.float16)\n",
451
+ "model = PeftModel.from_pretrained(model, ADAPTER_MODEL, device_map='auto', torch_dtype=torch.float16)\n",
452
+ "\n",
453
+ "model = model.merge_and_unload()\n",
454
+ "model.save_pretrained('gemma-2b-it-sum-ko')"
455
+ ]
456
+ },
457
+ {
458
+ "cell_type": "code",
459
+ "execution_count": null,
460
+ "id": "1a764bbc-069d-400c-bca4-09e799bf0fb0",
461
+ "metadata": {},
462
+ "outputs": [],
463
+ "source": [
464
+ "!ls -alh ./gemma-2b-it-sum-ko"
465
+ ]
466
+ },
467
+ {
468
+ "cell_type": "markdown",
469
+ "id": "84f2c237-71f4-47c2-bad4-181dadb6cc98",
470
+ "metadata": {},
471
+ "source": [
472
+ "# 5. Gemma ํ•œ๊ตญ์–ด ์š”์•ฝ ๋ชจ๋ธ ์ถ”๋ก "
473
+ ]
474
+ },
475
+ {
476
+ "cell_type": "markdown",
477
+ "id": "8587dfc7-cf7c-4072-a8f7-6ceb1e90a532",
478
+ "metadata": {},
479
+ "source": [
480
+ "#### ์ฃผ์˜: ๋งˆ์ฐฌ๊ฐ€์ง€๋กœ Colab GPU ๋ฉ”๋ชจ๋ฆฌ ํ•œ๊ณ„๋กœ ํ•™์Šต ์‹œ ์‚ฌ์šฉํ–ˆ๋˜ ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ๋น„์›Œ ์ค˜์•ผ ํŒŒ์ธํŠœ๋‹์„ ์ง„ํ–‰ ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. <br>ย notebook ๋Ÿฐํƒ€์ž„ ์„ธ์…˜์„ ์žฌ์‹œ์ž‘ ํ•œ ํ›„ 1๋ฒˆ๊ณผ 2๋ฒˆ์˜ 2.1 ํ•ญ๋ชฉ๊นŒ์ง€ ๋‹ค์‹œ ์‹คํ–‰ํ•˜์—ฌ ๋กœ๋“œ ํ•œ ํ›„ ์•„๋ž˜ ๊ณผ์ •์„ ์ง„ํ–‰ํ•ฉ๋‹ˆ๋‹ค"
481
+ ]
482
+ },
483
+ {
484
+ "cell_type": "code",
485
+ "execution_count": null,
486
+ "id": "906ed4dd-270f-4000-84de-ede6885c0be5",
487
+ "metadata": {},
488
+ "outputs": [],
489
+ "source": [
490
+ "!nvidia-smi"
491
+ ]
492
+ },
493
+ {
494
+ "cell_type": "markdown",
495
+ "id": "78399236-63b5-41af-9cee-a7233e23a9db",
496
+ "metadata": {},
497
+ "source": [
498
+ "### 5.1 Fine-tuned ๋ชจ๋ธ ๋กœ๋“œ"
499
+ ]
500
+ },
501
+ {
502
+ "cell_type": "code",
503
+ "execution_count": null,
504
+ "id": "76d5ba97-91ca-48c3-b9a2-ba9bea6d7b09",
505
+ "metadata": {},
506
+ "outputs": [],
507
+ "source": [
508
+ "BASE_MODEL = \"google/gemma-2b-it\"\n",
509
+ "FINETUNE_MODEL = \"./gemma-2b-it-sum-ko\"\n",
510
+ "\n",
511
+ "finetune_model = AutoModelForCausalLM.from_pretrained(FINETUNE_MODEL, device_map={\"\":0})\n",
512
+ "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, add_special_tokens=True)"
513
+ ]
514
+ },
515
+ {
516
+ "cell_type": "markdown",
517
+ "id": "5c34718c-ce52-4d68-ac8c-c18b6483b15b",
518
+ "metadata": {},
519
+ "source": [
520
+ "### 5.2 Fine-tuned ๋ชจ๋ธ ์ถ”๋ก "
521
+ ]
522
+ },
523
+ {
524
+ "cell_type": "code",
525
+ "execution_count": null,
526
+ "id": "a0f0fc82-abaf-49df-9254-7ccee2e74d96",
527
+ "metadata": {
528
+ "scrolled": true
529
+ },
530
+ "outputs": [],
531
+ "source": [
532
+ "pipe_finetuned = pipeline(\"text-generation\", model=finetune_model, tokenizer=tokenizer, max_new_tokens=512)"
533
+ ]
534
+ },
535
+ {
536
+ "cell_type": "code",
537
+ "execution_count": null,
538
+ "id": "2f915638-d859-446f-bc78-070650421ece",
539
+ "metadata": {},
540
+ "outputs": [],
541
+ "source": [
542
+ "doc = dataset['test']['document'][10]"
543
+ ]
544
+ },
545
+ {
546
+ "cell_type": "code",
547
+ "execution_count": null,
548
+ "id": "396788e7-4b80-46d7-980f-38fcb892a94f",
549
+ "metadata": {},
550
+ "outputs": [],
551
+ "source": [
552
+ "messages = [\n",
553
+ " {\n",
554
+ " \"role\": \"user\",\n",
555
+ " \"content\": \"๋‹ค์Œ ๊ธ€์„ ์š”์•ฝํ•ด์ฃผ์„ธ์š”:\\n\\n{}\".format(doc)\n",
556
+ " }\n",
557
+ "]\n",
558
+ "prompt = pipe_finetuned.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)"
559
+ ]
560
+ },
561
+ {
562
+ "cell_type": "code",
563
+ "execution_count": null,
564
+ "id": "03f1f711-0ba7-4087-8317-b0e7f4246aee",
565
+ "metadata": {},
566
+ "outputs": [],
567
+ "source": [
568
+ "outputs = pipe_finetuned(\n",
569
+ " prompt,\n",
570
+ " do_sample=True,\n",
571
+ " temperature=0.2,\n",
572
+ " top_k=50,\n",
573
+ " top_p=0.95,\n",
574
+ " add_special_tokens=True\n",
575
+ ")\n",
576
+ "print(outputs[0][\"generated_text\"][len(prompt):])"
577
+ ]
578
+ },
579
+ {
580
+ "cell_type": "code",
581
+ "execution_count": null,
582
+ "id": "73cb6b26-f1d1-4b7b-ba16-1ff62689fb94",
583
+ "metadata": {},
584
+ "outputs": [],
585
+ "source": []
586
+ }
587
+ ],
588
+ "metadata": {
589
+ "kernelspec": {
590
+ "display_name": "Python 3 (ipykernel)",
591
+ "language": "python",
592
+ "name": "python3"
593
+ },
594
+ "language_info": {
595
+ "codemirror_mode": {
596
+ "name": "ipython",
597
+ "version": 3
598
+ },
599
+ "file_extension": ".py",
600
+ "mimetype": "text/x-python",
601
+ "name": "python",
602
+ "nbconvert_exporter": "python",
603
+ "pygments_lexer": "ipython3",
604
+ "version": "3.8.10"
605
+ }
606
+ },
607
+ "nbformat": 4,
608
+ "nbformat_minor": 5
609
+ }