zetavg commited on
Commit
68255ee
·
1 Parent(s): 0537112
llama_lora/lib/finetune.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import sys
 
3
  from typing import Any, List
4
 
5
  import json
@@ -54,16 +55,38 @@ def train(
54
  # wandb params
55
  wandb_api_key = None,
56
  wandb_project: str = "",
 
57
  wandb_run_name: str = "",
 
58
  wandb_watch: str = "false", # options: false | gradients | all
59
  wandb_log_model: str = "true", # options: false | true
60
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  if wandb_api_key:
62
  os.environ["WANDB_API_KEY"] = wandb_api_key
63
- if wandb_project:
64
- os.environ["WANDB_PROJECT"] = wandb_project
65
- if wandb_run_name:
66
- os.environ["WANDB_RUN_NAME"] = wandb_run_name
 
 
67
  if wandb_watch:
68
  os.environ["WANDB_WATCH"] = wandb_watch
69
  if wandb_log_model:
@@ -73,6 +96,18 @@ def train(
73
  )
74
  if use_wandb:
75
  os.environ['WANDB_MODE'] = "online"
 
 
 
 
 
 
 
 
 
 
 
 
76
  else:
77
  os.environ['WANDB_MODE'] = "disabled"
78
 
@@ -243,24 +278,8 @@ def train(
243
  os.makedirs(output_dir)
244
  with open(os.path.join(output_dir, "trainer_args.json"), 'w') as trainer_args_json_file:
245
  json.dump(trainer.args.to_dict(), trainer_args_json_file, indent=2)
246
- with open(os.path.join(output_dir, "finetune_params.json"), 'w') as finetune_params_json_file:
247
- finetune_params = {
248
- 'micro_batch_size': micro_batch_size,
249
- 'gradient_accumulation_steps': gradient_accumulation_steps,
250
- 'num_train_epochs': num_train_epochs,
251
- 'learning_rate': learning_rate,
252
- 'cutoff_len': cutoff_len,
253
- 'lora_r': lora_r,
254
- 'lora_alpha': lora_alpha,
255
- 'lora_dropout': lora_dropout,
256
- 'lora_target_modules': lora_target_modules,
257
- 'train_on_inputs': train_on_inputs,
258
- 'group_by_length': group_by_length,
259
- 'save_steps': save_steps,
260
- 'save_total_limit': save_total_limit,
261
- 'logging_steps': logging_steps,
262
- }
263
- json.dump(finetune_params, finetune_params_json_file, indent=2)
264
 
265
  # Not working, will only give us ["prompt", "completion", "input_ids", "attention_mask", "labels"]
266
  # if train_data:
 
1
  import os
2
  import sys
3
+ import importlib
4
  from typing import Any, List
5
 
6
  import json
 
55
  # wandb params
56
  wandb_api_key = None,
57
  wandb_project: str = "",
58
+ wandb_group = None,
59
  wandb_run_name: str = "",
60
+ wandb_tags: List[str] = [],
61
  wandb_watch: str = "false", # options: false | gradients | all
62
  wandb_log_model: str = "true", # options: false | true
63
  ):
64
+ # for logging
65
+ finetune_args = {
66
+ 'micro_batch_size': micro_batch_size,
67
+ 'gradient_accumulation_steps': gradient_accumulation_steps,
68
+ 'num_train_epochs': num_train_epochs,
69
+ 'learning_rate': learning_rate,
70
+ 'cutoff_len': cutoff_len,
71
+ 'lora_r': lora_r,
72
+ 'lora_alpha': lora_alpha,
73
+ 'lora_dropout': lora_dropout,
74
+ 'lora_target_modules': lora_target_modules,
75
+ 'train_on_inputs': train_on_inputs,
76
+ 'group_by_length': group_by_length,
77
+ 'save_steps': save_steps,
78
+ 'save_total_limit': save_total_limit,
79
+ 'logging_steps': logging_steps,
80
+ }
81
+
82
  if wandb_api_key:
83
  os.environ["WANDB_API_KEY"] = wandb_api_key
84
+
85
+ # wandb: WARNING Changes to your `wandb` environment variables will be ignored because your `wandb` session has already started. For more information on how to modify your settings with `wandb.init()` arguments, please refer to https://wandb.me/wandb-init.
86
+ # if wandb_project:
87
+ # os.environ["WANDB_PROJECT"] = wandb_project
88
+ # if wandb_run_name:
89
+ # os.environ["WANDB_RUN_NAME"] = wandb_run_name
90
  if wandb_watch:
91
  os.environ["WANDB_WATCH"] = wandb_watch
92
  if wandb_log_model:
 
96
  )
97
  if use_wandb:
98
  os.environ['WANDB_MODE'] = "online"
99
+ wandb = importlib.import_module("wandb")
100
+ wandb.init(
101
+ project=wandb_project,
102
+ resume="auto",
103
+ group=wandb_group,
104
+ name=wandb_run_name,
105
+ tags=wandb_tags,
106
+ reinit=True,
107
+ magic=True,
108
+ config={'finetune_args': finetune_args},
109
+ # id=None # used for resuming
110
+ )
111
  else:
112
  os.environ['WANDB_MODE'] = "disabled"
113
 
 
278
  os.makedirs(output_dir)
279
  with open(os.path.join(output_dir, "trainer_args.json"), 'w') as trainer_args_json_file:
280
  json.dump(trainer.args.to_dict(), trainer_args_json_file, indent=2)
281
+ with open(os.path.join(output_dir, "finetune_args.json"), 'w') as finetune_args_json_file:
282
+ json.dump(finetune_args, finetune_args_json_file, indent=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
  # Not working, will only give us ["prompt", "completion", "input_ids", "attention_mask", "labels"]
285
  # if train_data:
llama_lora/ui/finetune_ui.py CHANGED
@@ -415,6 +415,12 @@ Train data (first 10):
415
  if not should_training_progress_track_tqdm:
416
  progress(0, desc="Train starting...")
417
 
 
 
 
 
 
 
418
  train_output = Global.train_fn(
419
  base_model, # base_model
420
  tokenizer, # tokenizer
@@ -440,7 +446,9 @@ Train data (first 10):
440
  training_callbacks, # callbacks
441
  Global.wandb_api_key, # wandb_api_key
442
  Global.default_wandb_project if Global.enable_wandb else None, # wandb_project
443
- model_name # wandb_run_name
 
 
444
  )
445
 
446
  logs_str = "\n".join([json.dumps(log)
 
415
  if not should_training_progress_track_tqdm:
416
  progress(0, desc="Train starting...")
417
 
418
+ wandb_group = template
419
+ wandb_tags = [f"template:{template}"]
420
+ if load_dataset_from == "Data Dir" and dataset_from_data_dir:
421
+ wandb_group += f"/{dataset_from_data_dir}"
422
+ wandb_tags.append(f"dataset:{dataset_from_data_dir}")
423
+
424
  train_output = Global.train_fn(
425
  base_model, # base_model
426
  tokenizer, # tokenizer
 
446
  training_callbacks, # callbacks
447
  Global.wandb_api_key, # wandb_api_key
448
  Global.default_wandb_project if Global.enable_wandb else None, # wandb_project
449
+ wandb_group, # wandb_group
450
+ model_name, # wandb_run_name
451
+ wandb_tags # wandb_tags
452
  )
453
 
454
  logs_str = "\n".join([json.dumps(log)