ryefoxlime committed
Commit dbebf53 · 1 Parent(s): e0e167f

Updated Hyperparams and dataset

Files changed (2):
  1. Gemma2_2B/finetune.ipynb +147 -345
  2. Gemma2_2B/hyperparams.yaml +13 -7
Gemma2_2B/finetune.ipynb CHANGED
@@ -2,7 +2,7 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 1,
6
  "metadata": {},
7
  "outputs": [],
8
  "source": [
@@ -15,87 +15,84 @@
15
  "login(token=os.getenv(\"HUGGINGFACE_TOKEN\"))"
16
  ]
17
  },
18
  {
19
  "cell_type": "code",
20
- "execution_count": 10,
21
  "metadata": {},
22
- "outputs": [
23
- {
24
- "data": {
25
- "application/vnd.jupyter.widget-view+json": {
26
- "model_id": "a39e6120cbea4462999cfa5f887a8015",
27
- "version_major": 2,
28
- "version_minor": 0
29
- },
30
- "text/plain": [
31
- "README.md: 0%| | 0.00/288 [00:00<?, ?B/s]"
32
- ]
33
- },
34
- "metadata": {},
35
- "output_type": "display_data"
36
- },
37
- {
38
- "name": "stderr",
39
- "output_type": "stream",
40
- "text": [
41
- "f:\\TADBot\\.venv\\Lib\\site-packages\\huggingface_hub\\file_download.py:139: UserWarning: `huggingface_hub` cache-system uses symlinks by default to efficiently store duplicated files but your machine does not support them in C:\\Users\\Nitin Kausik Remella\\.cache\\huggingface\\hub\\datasets--ai-bites--databricks-mini. Caching files will still work but in a degraded version that might require more space on your disk. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For more details, see https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations.\n",
42
- "To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development\n",
43
- " warnings.warn(message)\n"
44
- ]
45
- },
46
- {
47
- "data": {
48
- "application/vnd.jupyter.widget-view+json": {
49
- "model_id": "de15e48751c34c36b5d02c2449380d06",
50
- "version_major": 2,
51
- "version_minor": 0
52
- },
53
- "text/plain": [
54
- "dolly-mini-train.jsonl: 0%| | 0.00/5.24M [00:00<?, ?B/s]"
55
- ]
56
- },
57
- "metadata": {},
58
- "output_type": "display_data"
59
- },
60
- {
61
- "data": {
62
- "application/vnd.jupyter.widget-view+json": {
63
- "model_id": "d4094fd4af084a77a5bc3904b5db4197",
64
- "version_major": 2,
65
- "version_minor": 0
66
- },
67
- "text/plain": [
68
- "Generating train split: 0%| | 0/10544 [00:00<?, ? examples/s]"
69
- ]
70
- },
71
- "metadata": {},
72
- "output_type": "display_data"
73
- },
74
- {
75
- "data": {
76
- "text/plain": [
77
- "Dataset({\n",
78
- " features: ['text'],\n",
79
- " num_rows: 1000\n",
80
- "})"
81
- ]
82
- },
83
- "execution_count": 10,
84
- "metadata": {},
85
- "output_type": "execute_result"
86
- }
87
- ],
88
  "source": [
89
  "from datasets import load_dataset\n",
90
- "dataset_name = \"ai-bites/databricks-mini\"\n",
91
- "dataset = load_dataset(dataset_name, split=\"train[0:1000]\", cache_dir=\".cache/\")\n",
92
  "\n",
93
- "dataset"
94
  ]
95
  },
96
  {
97
  "cell_type": "code",
98
- "execution_count": 11,
99
  "metadata": {},
100
  "outputs": [],
101
  "source": [
@@ -109,12 +106,12 @@
109
  " logging,\n",
110
  ")\n",
111
  "from peft import LoraConfig, PeftModel\n",
112
- "from trl import SFTTrainer"
113
  ]
114
  },
115
  {
116
  "cell_type": "code",
117
- "execution_count": 30,
118
  "metadata": {},
119
  "outputs": [],
120
  "source": [
@@ -125,7 +122,7 @@
125
  },
126
  {
127
  "cell_type": "code",
128
- "execution_count": 31,
129
  "metadata": {},
130
  "outputs": [],
131
  "source": [
@@ -141,17 +138,9 @@
141
  },
142
  {
143
  "cell_type": "code",
144
- "execution_count": 32,
145
  "metadata": {},
146
- "outputs": [
147
- {
148
- "name": "stdout",
149
- "output_type": "stream",
150
- "text": [
151
- "Setting BF16 to True\n"
152
- ]
153
- }
154
- ],
155
  "source": [
156
  "# Check GPU compatibility with bfloat16\n",
157
  "if compute_dtype == torch.float16 and hyperparams['use_4bit']:\n",
@@ -165,24 +154,9 @@
165
  },
166
  {
167
  "cell_type": "code",
168
- "execution_count": 33,
169
  "metadata": {},
170
- "outputs": [
171
- {
172
- "data": {
173
- "application/vnd.jupyter.widget-view+json": {
174
- "model_id": "9ab84ef6c43249de9726940a78f2717f",
175
- "version_major": 2,
176
- "version_minor": 0
177
- },
178
- "text/plain": [
179
- "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
180
- ]
181
- },
182
- "metadata": {},
183
- "output_type": "display_data"
184
- }
185
- ],
186
  "source": [
187
  "model = AutoModelForCausalLM.from_pretrained(\n",
188
  " hyperparams['model_name'],\n",
@@ -201,7 +175,7 @@
201
  },
202
  {
203
  "cell_type": "code",
204
- "execution_count": 34,
205
  "metadata": {},
206
  "outputs": [],
207
  "source": [
@@ -218,275 +192,103 @@
218
  },
219
  {
220
  "cell_type": "code",
221
- "execution_count": 39,
222
  "metadata": {},
223
- "outputs": [
224
- {
225
- "data": {
226
- "text/plain": [
227
- "TrainingArguments(\n",
228
- "_n_gpu=1,\n",
229
- "accelerator_config={'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None, 'use_configured_state': False},\n",
230
- "adafactor=False,\n",
231
- "adam_beta1=0.9,\n",
232
- "adam_beta2=0.999,\n",
233
- "adam_epsilon=1e-08,\n",
234
- "auto_find_batch_size=False,\n",
235
- "average_tokens_across_devices=False,\n",
236
- "batch_eval_metrics=False,\n",
237
- "bf16=True,\n",
238
- "bf16_full_eval=False,\n",
239
- "data_seed=None,\n",
240
- "dataloader_drop_last=False,\n",
241
- "dataloader_num_workers=0,\n",
242
- "dataloader_persistent_workers=False,\n",
243
- "dataloader_pin_memory=True,\n",
244
- "dataloader_prefetch_factor=None,\n",
245
- "ddp_backend=None,\n",
246
- "ddp_broadcast_buffers=None,\n",
247
- "ddp_bucket_cap_mb=None,\n",
248
- "ddp_find_unused_parameters=None,\n",
249
- "ddp_timeout=1800,\n",
250
- "debug=[],\n",
251
- "deepspeed=None,\n",
252
- "disable_tqdm=False,\n",
253
- "dispatch_batches=None,\n",
254
- "do_eval=False,\n",
255
- "do_predict=False,\n",
256
- "do_train=False,\n",
257
- "eval_accumulation_steps=None,\n",
258
- "eval_delay=0,\n",
259
- "eval_do_concat_batches=True,\n",
260
- "eval_on_start=False,\n",
261
- "eval_steps=None,\n",
262
- "eval_strategy=IntervalStrategy.NO,\n",
263
- "eval_use_gather_object=False,\n",
264
- "evaluation_strategy=None,\n",
265
- "fp16=False,\n",
266
- "fp16_backend=auto,\n",
267
- "fp16_full_eval=False,\n",
268
- "fp16_opt_level=O1,\n",
269
- "fsdp=[],\n",
270
- "fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False},\n",
271
- "fsdp_min_num_params=0,\n",
272
- "fsdp_transformer_layer_cls_to_wrap=None,\n",
273
- "full_determinism=False,\n",
274
- "gradient_accumulation_steps=1,\n",
275
- "gradient_checkpointing=False,\n",
276
- "gradient_checkpointing_kwargs=None,\n",
277
- "greater_is_better=None,\n",
278
- "group_by_length=True,\n",
279
- "half_precision_backend=auto,\n",
280
- "hub_always_push=False,\n",
281
- "hub_model_id=None,\n",
282
- "hub_private_repo=False,\n",
283
- "hub_strategy=HubStrategy.EVERY_SAVE,\n",
284
- "hub_token=<HUB_TOKEN>,\n",
285
- "ignore_data_skip=False,\n",
286
- "include_for_metrics=[],\n",
287
- "include_inputs_for_metrics=False,\n",
288
- "include_num_input_tokens_seen=False,\n",
289
- "include_tokens_per_second=False,\n",
290
- "jit_mode_eval=False,\n",
291
- "label_names=None,\n",
292
- "label_smoothing_factor=0.0,\n",
293
- "learning_rate=0.0002,\n",
294
- "length_column_name=length,\n",
295
- "load_best_model_at_end=False,\n",
296
- "local_rank=0,\n",
297
- "log_level=passive,\n",
298
- "log_level_replica=warning,\n",
299
- "log_on_each_node=True,\n",
300
- "logging_dir=./results\\runs\\Nov15_13-14-10_FutureGadgetLab,\n",
301
- "logging_first_step=False,\n",
302
- "logging_nan_inf_filter=True,\n",
303
- "logging_steps=25,\n",
304
- "logging_strategy=IntervalStrategy.STEPS,\n",
305
- "lr_scheduler_kwargs={},\n",
306
- "lr_scheduler_type=SchedulerType.CONSTANT,\n",
307
- "max_grad_norm=0.3,\n",
308
- "max_steps=-1,\n",
309
- "metric_for_best_model=None,\n",
310
- "mp_parameters=,\n",
311
- "neftune_noise_alpha=None,\n",
312
- "no_cuda=False,\n",
313
- "num_train_epochs=1,\n",
314
- "optim=OptimizerNames.PAGED_ADAMW,\n",
315
- "optim_args=None,\n",
316
- "optim_target_modules=None,\n",
317
- "output_dir=./results,\n",
318
- "overwrite_output_dir=False,\n",
319
- "past_index=-1,\n",
320
- "per_device_eval_batch_size=8,\n",
321
- "per_device_train_batch_size=2,\n",
322
- "prediction_loss_only=False,\n",
323
- "push_to_hub=False,\n",
324
- "push_to_hub_model_id=None,\n",
325
- "push_to_hub_organization=None,\n",
326
- "push_to_hub_token=<PUSH_TO_HUB_TOKEN>,\n",
327
- "ray_scope=last,\n",
328
- "remove_unused_columns=True,\n",
329
- "report_to=['tensorboard'],\n",
330
- "restore_callback_states_from_checkpoint=False,\n",
331
- "resume_from_checkpoint=None,\n",
332
- "run_name=./results,\n",
333
- "save_on_each_node=False,\n",
334
- "save_only_model=False,\n",
335
- "save_safetensors=True,\n",
336
- "save_steps=25,\n",
337
- "save_strategy=IntervalStrategy.STEPS,\n",
338
- "save_total_limit=None,\n",
339
- "seed=42,\n",
340
- "skip_memory_metrics=True,\n",
341
- "split_batches=None,\n",
342
- "tf32=None,\n",
343
- "torch_compile=False,\n",
344
- "torch_compile_backend=None,\n",
345
- "torch_compile_mode=None,\n",
346
- "torch_empty_cache_steps=None,\n",
347
- "torchdynamo=None,\n",
348
- "tpu_metrics_debug=False,\n",
349
- "tpu_num_cores=None,\n",
350
- "use_cpu=False,\n",
351
- "use_ipex=False,\n",
352
- "use_legacy_prediction_loop=False,\n",
353
- "use_liger_kernel=False,\n",
354
- "use_mps_device=False,\n",
355
- "warmup_ratio=0.03,\n",
356
- "warmup_steps=0,\n",
357
- "weight_decay=0.001,\n",
358
- ")"
359
- ]
360
- },
361
- "execution_count": 39,
362
- "metadata": {},
363
- "output_type": "execute_result"
364
- }
365
- ],
366
  "source": [
367
  "# Set training parameters\n",
368
  "training_arguments = TrainingArguments(\n",
369
- " output_dir=hyperparams['output_dir'],\n",
370
- " num_train_epochs=hyperparams['num_train_epochs'],\n",
371
- " per_device_train_batch_size=hyperparams['per_device_train_batch_size'],\n",
372
- " gradient_accumulation_steps=hyperparams['gradient_accumulation_steps'],\n",
373
- " optim=hyperparams['optimizer'],\n",
374
- " save_steps=hyperparams['save_steps'],\n",
375
- " logging_steps=hyperparams['logging_steps'],\n",
376
- " learning_rate=float(hyperparams['learning_rate']),\n",
377
- " weight_decay=hyperparams['weight_decay'],\n",
378
- " fp16=hyperparams['fp16'],\n",
379
- " bf16=hyperparams['bf16'],\n",
380
- " max_grad_norm=hyperparams['max_grad_norm'],\n",
381
- " max_steps=hyperparams['max_steps'],\n",
382
- " warmup_ratio=hyperparams['warmup_ratio'],\n",
383
- " group_by_length=hyperparams['group_by_length'],\n",
384
- " lr_scheduler_type=hyperparams['lr_scheduler_type'],\n",
385
- " report_to=\"tensorboard\",\n",
386
  ")\n",
387
  "training_arguments"
388
  ]
389
  },
390
  {
391
  "cell_type": "code",
392
- "execution_count": 40,
393
  "metadata": {},
394
- "outputs": [
395
- {
396
- "name": "stderr",
397
- "output_type": "stream",
398
- "text": [
399
- "f:\\TADBot\\.venv\\Lib\\site-packages\\huggingface_hub\\utils\\_deprecation.py:100: FutureWarning: Deprecated argument(s) used in '__init__': dataset_text_field, max_seq_length, packing. Will not be supported from version '0.13.0'.\n",
400
- "\n",
401
- "Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.\n",
402
- " warnings.warn(message, FutureWarning)\n",
403
- "f:\\TADBot\\.venv\\Lib\\site-packages\\trl\\trainer\\sft_trainer.py:212: UserWarning: You passed a `packing` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.\n",
404
- " warnings.warn(\n",
405
- "f:\\TADBot\\.venv\\Lib\\site-packages\\trl\\trainer\\sft_trainer.py:300: UserWarning: You passed a `max_seq_length` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.\n",
406
- " warnings.warn(\n",
407
- "f:\\TADBot\\.venv\\Lib\\site-packages\\trl\\trainer\\sft_trainer.py:328: UserWarning: You passed a `dataset_text_field` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.\n",
408
- " warnings.warn(\n"
409
- ]
410
- }
411
- ],
412
  "source": [
413
  "trainer = SFTTrainer(\n",
414
  " model=model,\n",
415
- " train_dataset=dataset,\n",
416
  " peft_config=peft_config,\n",
417
  " dataset_text_field=\"text\",\n",
418
  " # formatting_func=format_prompts_fn,\n",
419
- " max_seq_length=hyperparams['max_seq_length'],\n",
420
  " tokenizer=tokenizer,\n",
421
  " args=training_arguments,\n",
422
- " packing=hyperparams['packing'],\n",
423
  ")"
424
  ]
425
  },
426
  {
427
  "cell_type": "code",
428
  "execution_count": null,
429
  "metadata": {},
430
- "outputs": [
431
- {
432
- "data": {
433
- "application/vnd.jupyter.widget-view+json": {
434
- "model_id": "0033f5bb31a7416facfd8a3fd3bd5ad1",
435
- "version_major": 2,
436
- "version_minor": 0
437
- },
438
- "text/plain": [
439
- " 0%| | 0/1340 [00:00<?, ?it/s]"
440
- ]
441
- },
442
- "metadata": {},
443
- "output_type": "display_data"
444
- },
445
- {
446
- "name": "stdout",
447
- "output_type": "stream",
448
- "text": [
449
- "{'loss': 3.8879, 'grad_norm': 18.030195236206055, 'learning_rate': 0.0002, 'epoch': 0.02}\n",
450
- "{'loss': 2.9569, 'grad_norm': 9.667036056518555, 'learning_rate': 0.0002, 'epoch': 0.04}\n",
451
- "{'loss': 2.6361, 'grad_norm': 9.089476585388184, 'learning_rate': 0.0002, 'epoch': 0.06}\n",
452
- "{'loss': 2.9523, 'grad_norm': 6.053662300109863, 'learning_rate': 0.0002, 'epoch': 0.07}\n",
453
- "{'loss': 2.8543, 'grad_norm': 7.764152526855469, 'learning_rate': 0.0002, 'epoch': 0.09}\n",
454
- "{'loss': 2.8802, 'grad_norm': 6.539248466491699, 'learning_rate': 0.0002, 'epoch': 0.11}\n",
455
- "{'loss': 2.7047, 'grad_norm': 5.485109329223633, 'learning_rate': 0.0002, 'epoch': 0.13}\n",
456
- "{'loss': 2.6576, 'grad_norm': 9.22624397277832, 'learning_rate': 0.0002, 'epoch': 0.15}\n",
457
- "{'loss': 2.7756, 'grad_norm': 6.477100372314453, 'learning_rate': 0.0002, 'epoch': 0.17}\n",
458
- "{'loss': 2.7012, 'grad_norm': 5.891603946685791, 'learning_rate': 0.0002, 'epoch': 0.19}\n",
459
- "{'loss': 2.5026, 'grad_norm': 5.75968599319458, 'learning_rate': 0.0002, 'epoch': 0.21}\n",
460
- "{'loss': 2.8085, 'grad_norm': 7.938610076904297, 'learning_rate': 0.0002, 'epoch': 0.22}\n",
461
- "{'loss': 2.5286, 'grad_norm': 5.600504398345947, 'learning_rate': 0.0002, 'epoch': 0.24}\n",
462
- "{'loss': 2.5495, 'grad_norm': 6.746212005615234, 'learning_rate': 0.0002, 'epoch': 0.26}\n",
463
- "{'loss': 2.7405, 'grad_norm': 3.8923749923706055, 'learning_rate': 0.0002, 'epoch': 0.28}\n",
464
- "{'loss': 2.5657, 'grad_norm': 5.949460506439209, 'learning_rate': 0.0002, 'epoch': 0.3}\n",
465
- "{'loss': 2.6052, 'grad_norm': 5.733223915100098, 'learning_rate': 0.0002, 'epoch': 0.32}\n",
466
- "{'loss': 2.673, 'grad_norm': 6.0587310791015625, 'learning_rate': 0.0002, 'epoch': 0.34}\n",
467
- "{'loss': 2.4631, 'grad_norm': 4.734077453613281, 'learning_rate': 0.0002, 'epoch': 0.35}\n",
468
- "{'loss': 2.7288, 'grad_norm': 6.7847700119018555, 'learning_rate': 0.0002, 'epoch': 0.37}\n",
469
- "{'loss': 2.7797, 'grad_norm': 5.118943214416504, 'learning_rate': 0.0002, 'epoch': 0.39}\n",
470
- "{'loss': 2.8644, 'grad_norm': 5.4167304039001465, 'learning_rate': 0.0002, 'epoch': 0.41}\n",
471
- "{'loss': 2.5779, 'grad_norm': 6.73247766494751, 'learning_rate': 0.0002, 'epoch': 0.43}\n",
472
- "{'loss': 2.6459, 'grad_norm': 4.644010066986084, 'learning_rate': 0.0002, 'epoch': 0.45}\n",
473
- "{'loss': 2.5321, 'grad_norm': 6.347738265991211, 'learning_rate': 0.0002, 'epoch': 0.47}\n",
474
- "{'loss': 2.6865, 'grad_norm': 5.185911655426025, 'learning_rate': 0.0002, 'epoch': 0.49}\n",
475
- "{'loss': 2.4668, 'grad_norm': 5.355742454528809, 'learning_rate': 0.0002, 'epoch': 0.5}\n",
476
- "{'loss': 2.8465, 'grad_norm': 5.4434380531311035, 'learning_rate': 0.0002, 'epoch': 0.52}\n",
477
- "{'loss': 2.7376, 'grad_norm': 4.8459882736206055, 'learning_rate': 0.0002, 'epoch': 0.54}\n",
478
- "{'loss': 2.5205, 'grad_norm': 5.886116981506348, 'learning_rate': 0.0002, 'epoch': 0.56}\n",
479
- "{'loss': 2.7473, 'grad_norm': 4.946981906890869, 'learning_rate': 0.0002, 'epoch': 0.58}\n",
480
- "{'loss': 2.6824, 'grad_norm': 6.349016189575195, 'learning_rate': 0.0002, 'epoch': 0.6}\n",
481
- "{'loss': 2.6485, 'grad_norm': 5.024953365325928, 'learning_rate': 0.0002, 'epoch': 0.62}\n",
482
- "{'loss': 2.7172, 'grad_norm': 5.583380222320557, 'learning_rate': 0.0002, 'epoch': 0.63}\n",
483
- "{'loss': 2.5879, 'grad_norm': 6.582890033721924, 'learning_rate': 0.0002, 'epoch': 0.65}\n"
484
- ]
485
- }
486
- ],
487
  "source": [
488
- "trainer.train()\n",
489
- "trainer.model.save_pretrained(hyperparams['new_model_name'])"
490
  ]
491
  }
492
  ],
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": null,
6
  "metadata": {},
7
  "outputs": [],
8
  "source": [
 
15
  "login(token=os.getenv(\"HUGGINGFACE_TOKEN\"))"
16
  ]
17
  },
18
+ {
19
+ "cell_type": "markdown",
20
+ "metadata": {},
21
+ "source": [
22
+ "# Dataset\n",
23
+ "Modifyify the dataset to fit the Gemma 2 prompt format"
24
+ ]
25
+ },
26
  {
27
  "cell_type": "code",
28
+ "execution_count": null,
29
  "metadata": {},
30
+ "outputs": [],
31
  "source": [
32
  "from datasets import load_dataset\n",
33
+ "dataset_name = \"nbertagnolli/counsel-chat\"\n",
34
+ "dataset = load_dataset(dataset_name, split=\"train\",cache_dir=\".cache/\")\n",
35
+ "\n",
36
+ "# Print the first example from the dataset\n",
37
+ "print(dataset[0])\n",
38
+ "print(f\"\\n {dataset}\")"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "gemma_prompt = \"\"\" \n",
48
+ "### System:\n",
49
+ "You are a Therapist Assistant, an LLM fine-tuned on Gemma 2 model by Google.\n",
50
+ "You provide safe and responsible support to users while encouraging them to visit a mental health professional if needed. \n",
51
+ "You are committed to promoting wellness, understanding, and support. Your responses should be clear, concise, and evidence-based, while maintaining a friendly and approachable tone.\n",
52
  "\n",
53
+ "### User:\n",
54
+ "{}\n",
55
+ "\n",
56
+ "### Response:\n",
57
+ "{}\n",
58
+ "\"\"\"\n",
59
+ "\n",
60
+ "def format_prompts_func(example):\n",
61
+ " \"\"\"Formats questionText and answerText into the Gemma 2 prompt format.\"\"\"\n",
62
+ " question_texts = example[\"questionText\"]\n",
63
+ " answer_texts = example[\"answerText\"]\n",
64
+ " texts = []\n",
65
+ " for q, a in zip(question_texts, answer_texts):\n",
66
+ " text = gemma_prompt.format(q, a)\n",
67
+ " texts.append(text)\n",
68
+ "\n",
69
+ " return {\"text\": texts}\n",
70
+ "pass\n",
71
+ "# Apply the formatting function to the dataset\n",
72
+ "formatted_dataset = dataset.map(format_prompts_func, batched=True)\n",
73
+ "print(formatted_dataset['text'][0])\n"
74
  ]
75
  },
76
  {
77
  "cell_type": "code",
78
+ "execution_count": null,
79
+ "metadata": {},
80
+ "outputs": [],
81
+ "source": [
82
+ "dataset = formatted_dataset.train_test_split(test_size=0.2, seed=42)\n",
83
+ "print(dataset['train'].shape, dataset['test'].shape)"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "markdown",
88
+ "metadata": {},
89
+ "source": [
90
+ "# Fine tuning hyperpterparameters"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": null,
96
  "metadata": {},
97
  "outputs": [],
98
  "source": [
 
106
  " logging,\n",
107
  ")\n",
108
  "from peft import LoraConfig, PeftModel\n",
109
+ "from trl import SFTTrainer\n"
110
  ]
111
  },
112
  {
113
  "cell_type": "code",
114
+ "execution_count": null,
115
  "metadata": {},
116
  "outputs": [],
117
  "source": [
 
122
  },
123
  {
124
  "cell_type": "code",
125
+ "execution_count": null,
126
  "metadata": {},
127
  "outputs": [],
128
  "source": [
 
138
  },
139
  {
140
  "cell_type": "code",
141
+ "execution_count": null,
142
  "metadata": {},
143
+ "outputs": [],
144
  "source": [
145
  "# Check GPU compatibility with bfloat16\n",
146
  "if compute_dtype == torch.float16 and hyperparams['use_4bit']:\n",
 
154
  },
155
  {
156
  "cell_type": "code",
157
+ "execution_count": null,
158
  "metadata": {},
159
+ "outputs": [],
160
  "source": [
161
  "model = AutoModelForCausalLM.from_pretrained(\n",
162
  " hyperparams['model_name'],\n",
 
175
  },
176
  {
177
  "cell_type": "code",
178
+ "execution_count": null,
179
  "metadata": {},
180
  "outputs": [],
181
  "source": [
 
192
  },
193
  {
194
  "cell_type": "code",
195
+ "execution_count": null,
196
  "metadata": {},
197
+ "outputs": [],
198
  "source": [
199
+ "import wandb\n",
200
+ "import time\n",
201
+ "wandb.login(key=os.getenv(\"WANDB_API_KEY\"))\n",
202
+ "run = wandb.init(\n",
203
+ " project='TADBot',\n",
204
+ " job_type=\"training\",\n",
205
+ " anonymous=\"allow\"\n",
206
+ ")\n",
207
+ "run_name = f\"{hyperparams['model_name']}--health-bot-{int(time.time())}\"\n",
208
+ "\n",
209
  "# Set training parameters\n",
210
  "training_arguments = TrainingArguments(\n",
211
+ " output_dir=f\"./outputs/{run_name}\",\n",
212
+ " per_device_train_batch_size=hyperparams[\"per_device_train_batch_size\"],\n",
213
+ " per_device_eval_batch_size=hyperparams[\"per_device_eval_batch_size\"],\n",
214
+ " gradient_accumulation_steps=hyperparams[\"gradient_accumulation_steps\"],\n",
215
+ " optim=hyperparams[\"optimizer\"],\n",
216
+ " num_train_epochs=hyperparams[\"num_train_epochs\"],\n",
217
+ " eval_steps=hyperparams[\"eval_steps\"],\n",
218
+ " eval_strategy=hyperparams[\"eval_strategy\"],\n",
219
+ " save_steps=hyperparams[\"save_steps\"],\n",
220
+ " logging_steps=hyperparams[\"logging_steps\"],\n",
221
+ " logging_strategy=hyperparams[\"logging_strategy\"],\n",
222
+ " warmup_steps=hyperparams[\"warmup_steps\"],\n",
223
+ " learning_rate=float(hyperparams[\"learning_rate\"]),\n",
224
+ " weight_decay=hyperparams[\"weight_decay\"],\n",
225
+ " fp16=hyperparams[\"fp16\"],\n",
226
+ " bf16=hyperparams[\"bf16\"],\n",
227
+ " max_grad_norm=hyperparams[\"max_grad_norm\"],\n",
228
+ " max_steps=hyperparams[\"max_steps\"],\n",
229
+ " group_by_length=hyperparams[\"group_by_length\"],\n",
230
+ " lr_scheduler_type=hyperparams[\"lr_scheduler_type\"],\n",
231
+ " logging_dir=f\"./outputs/{run_name}/logs\",\n",
232
+ " report_to=\"wandb\",\n",
233
+ " run_name=run_name\n",
234
  ")\n",
235
  "training_arguments"
236
  ]
237
  },
238
  {
239
  "cell_type": "code",
240
+ "execution_count": null,
241
  "metadata": {},
242
+ "outputs": [],
243
  "source": [
244
  "trainer = SFTTrainer(\n",
245
  " model=model,\n",
246
+ " train_dataset=dataset[\"train\"],\n",
247
+ " eval_dataset=dataset['test'],\n",
248
  " peft_config=peft_config,\n",
249
  " dataset_text_field=\"text\",\n",
250
  " # formatting_func=format_prompts_fn,\n",
251
+ " max_seq_length=hyperparams[\"max_seq_length\"],\n",
252
  " tokenizer=tokenizer,\n",
253
  " args=training_arguments,\n",
254
+ " packing=hyperparams[\"packing\"],\n",
255
  ")"
256
  ]
257
  },
258
+ {
259
+ "cell_type": "markdown",
260
+ "metadata": {},
261
+ "source": [
262
+ "# Fine tuning the model"
263
+ ]
264
+ },
265
+ {
266
+ "cell_type": "code",
267
+ "execution_count": null,
268
+ "metadata": {},
269
+ "outputs": [],
270
+ "source": [
271
+ "model.config.use_cache = False\n",
272
+ "trainer.train()"
273
+ ]
274
+ },
275
  {
276
  "cell_type": "code",
277
  "execution_count": null,
278
  "metadata": {},
279
+ "outputs": [],
280
+ "source": [
281
+ "wandb.finish()\n",
282
+ "model.config.use_cache = True\n",
283
+ "# Save the model\n",
284
+ "trainer.model.save_pretrained(hyperparams[\"new_model_name\"])"
285
+ ]
286
+ },
287
+ {
288
+ "cell_type": "markdown",
289
+ "metadata": {},
290
  "source": [
291
+ "%tensorboard --logdir Gemma2_2B\\\\results\\\\runs"
292
  ]
293
  }
294
  ],
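In short, the notebook now pulls nbertagnolli/counsel-chat instead of ai-bites/databricks-mini, formats each record with the Gemma prompt template, and splits the result 80/20 before handing it to SFTTrainer. A condensed, standalone sketch of that dataset step (field names, the split, and the template structure come from the diff above; the system prompt is abbreviated here):

from datasets import load_dataset

# Prompt template from the notebook, with the system text shortened
gemma_prompt = """
### System:
You are a Therapist Assistant providing safe, supportive, evidence-based responses.

### User:
{}

### Response:
{}
"""

def format_prompts_func(batch):
    # Map counsel-chat question/answer pairs into the prompt template
    return {"text": [gemma_prompt.format(q, a)
                     for q, a in zip(batch["questionText"], batch["answerText"])]}

dataset = load_dataset("nbertagnolli/counsel-chat", split="train", cache_dir=".cache/")
formatted_dataset = dataset.map(format_prompts_func, batched=True)
dataset = formatted_dataset.train_test_split(test_size=0.2, seed=42)
print(dataset["train"][0]["text"])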
Gemma2_2B/hyperparams.yaml CHANGED
@@ -1,34 +1,40 @@
1
  model_name: "google/gemma-2-2b-it"
2
  new_model_name: "gemma-2-2b-ft"
3
 
4
  lora_r: 64
5
  lora_alpha: 16
6
  lora_dropout: 0.1
7
 
8
  use_4bit: True
9
  bnb_4bit_compute_dtype: "float16"
10
  bnb_4bit_quant_type: "nf4"
11
  use_nested_quant: False
12
 
13
- output_dir: "./results"
14
- num_train_epochs: 2
15
  fp16: False
16
  bf16: False
17
  per_device_train_batch_size: 2
18
  per_device_eval_batch_size: 2
19
- gradient_accumulation_steps: 1
20
  gradient_checkpointing: True
21
  max_grad_norm: 0.3
22
- learning_rate: 2e-3
23
  weight_decay: 0.001
24
  optimizer: "paged_adamw_32bit"
25
  lr_scheduler_type: "constant"
26
  max_steps: -1
27
- warmup_ratio: 0.03
28
  group_by_length: True
29
- save_steps: 25
30
- logging_steps: 25
31
 
32
  max_seq_length: 128
33
  packing: True
34
  device_map: "auto"
 
1
  model_name: "google/gemma-2-2b-it"
2
  new_model_name: "gemma-2-2b-ft"
3
 
4
+ # LoRA Parameters
5
  lora_r: 64
6
  lora_alpha: 16
7
  lora_dropout: 0.1
8
 
9
+ #bitsandbytes parameters
10
  use_4bit: True
11
  bnb_4bit_compute_dtype: "float16"
12
  bnb_4bit_quant_type: "nf4"
13
  use_nested_quant: False
14
 
15
+ #Training Arguments
16
+ num_train_epochs: 1
17
  fp16: False
18
  bf16: False
19
  per_device_train_batch_size: 2
20
  per_device_eval_batch_size: 2
21
+ gradient_accumulation_steps: 2
22
  gradient_checkpointing: True
23
+ eval_strategy: "steps"
24
+ eval_steps: 0.2
25
  max_grad_norm: 0.3
26
+ learning_rate: 2e-4
27
  weight_decay: 0.001
28
  optimizer: "paged_adamw_32bit"
29
  lr_scheduler_type: "constant"
30
  max_steps: -1
31
+ warmup_steps: 5
32
  group_by_length: True
33
+ save_steps: 50
34
+ logging_steps: 50
35
+ logging_strategy: "steps"
36
 
37
+ #SFT Arguments
38
  max_seq_length: 128
39
  packing: True
40
  device_map: "auto"
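The notebook reads this file into the hyperparams dict it indexes throughout (hyperparams['use_4bit'], hyperparams['learning_rate'], and so on); the loading cell is unchanged and therefore not shown in this diff. A minimal sketch of that step, assuming PyYAML and the bitsandbytes keys above (the path and the BitsAndBytesConfig construction are illustrative, not taken from the commit):

import torch
import yaml
from transformers import BitsAndBytesConfig

# Load the YAML above into the dict the notebook indexes as hyperparams[...]
with open("Gemma2_2B/hyperparams.yaml") as f:
    hyperparams = yaml.safe_load(f)

# "float16" -> torch.float16; this feeds the bf16-support check in the notebook
compute_dtype = getattr(torch, hyperparams["bnb_4bit_compute_dtype"])

# Assumed 4-bit quantization config built from the keys above
bnb_config = BitsAndBytesConfig(
    load_in_4bit=hyperparams["use_4bit"],
    bnb_4bit_quant_type=hyperparams["bnb_4bit_quant_type"],
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=hyperparams["use_nested_quant"],
)

# PyYAML parses 2e-4 as a string (no decimal point), which is why the notebook
# wraps it in float(...) when building TrainingArguments
learning_rate = float(hyperparams["learning_rate"])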