zetavg committed
Commit: 38fb491
Parent: 00263ef

support resume_from_checkpoint

llama_lora/lib/finetune.py CHANGED
@@ -33,7 +33,7 @@ def train(
     num_train_epochs: int = 3,
     learning_rate: float = 3e-4,
     cutoff_len: int = 256,
-    val_set_size: int = 2000, # TODO: use percentage
+    val_set_size: int = 2000,
     # lora hyperparams
     lora_r: int = 8,
     lora_alpha: int = 16,
@@ -46,7 +46,7 @@ def train(
     train_on_inputs: bool = True, # if False, masks out inputs in loss
     group_by_length: bool = False, # faster, but produces an odd training loss curve
     # either training checkpoint or final adapter
-    resume_from_checkpoint: str = None,
+    resume_from_checkpoint = None,
     save_steps: int = 200,
     save_total_limit: int = 3,
     logging_steps: int = 10,
@@ -68,6 +68,7 @@ def train(
         'num_train_epochs': num_train_epochs,
         'learning_rate': learning_rate,
         'cutoff_len': cutoff_len,
+        'val_set_size': val_set_size,
         'lora_r': lora_r,
         'lora_alpha': lora_alpha,
         'lora_dropout': lora_dropout,
@@ -78,7 +79,12 @@ def train(
         'save_total_limit': save_total_limit,
         'logging_steps': logging_steps,
     }
+    if val_set_size and val_set_size > 0:
+        finetune_args['val_set_size'] = val_set_size
+    if resume_from_checkpoint:
+        finetune_args['resume_from_checkpoint'] = resume_from_checkpoint
 
+    wandb = None
     if wandb_api_key:
         os.environ["WANDB_API_KEY"] = wandb_api_key
 
@@ -220,7 +226,7 @@ def train(
             adapters_weights = torch.load(checkpoint_name)
             model = set_peft_model_state_dict(model, adapters_weights)
         else:
-            print(f"Checkpoint {checkpoint_name} not found")
+            raise ValueError(f"Checkpoint {checkpoint_name} not found")
 
     # Be more transparent about the % of trainable params.
     model.print_trainable_parameters()
@@ -315,4 +321,7 @@ def train(
     with open(os.path.join(output_dir, "train_output.json"), 'w') as train_output_json_file:
         json.dump(train_output, train_output_json_file, indent=2)
 
+    if use_wandb and wandb:
+        wandb.finish()
+
     return train_output
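Note on the checkpoint-loading hunk (@@ -220,7 +226,7 @@): the diff only shows the load and the failure path, which now raises instead of printing a warning. Below is a standalone sketch of that behaviour; how checkpoint_name is derived is not part of this diff, so the wrapper function and the adapter_model.bin naming are assumptions for illustration only.

import os
import torch
from peft import set_peft_model_state_dict

def load_adapter_checkpoint(model, checkpoint_name):
    """Load saved LoRA adapter weights into a PEFT model (illustrative sketch)."""
    # checkpoint_name is assumed to point at a saved adapter state file,
    # e.g. <resume_from_checkpoint>/adapter_model.bin.
    if os.path.exists(checkpoint_name):
        adapters_weights = torch.load(checkpoint_name)
        model = set_peft_model_state_dict(model, adapters_weights)
    else:
        # Before this commit a missing checkpoint was only reported with
        # print(); it now aborts the run.
        raise ValueError(f"Checkpoint {checkpoint_name} not found")
    return model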
llama_lora/ui/finetune_ui.py CHANGED
@@ -306,6 +306,17 @@ def do_train(
 ):
     try:
         base_model_name = Global.base_model_name
+
+        resume_from_checkpoint = None
+        if continue_from_model == "-" or continue_from_model == "None":
+            continue_from_model = None
+        if continue_from_checkpoint == "-" or continue_from_checkpoint == "None":
+            continue_from_checkpoint = None
+        if continue_from_model:
+            resume_from_checkpoint = os.path.join(Global.data_dir, "lora_models", continue_from_model)
+        if continue_from_checkpoint:
+            resume_from_checkpoint = os.path.join(resume_from_checkpoint, continue_from_checkpoint)
+
         output_dir = os.path.join(Global.data_dir, "lora_models", model_name)
         if os.path.exists(output_dir):
             if (not os.path.isdir(output_dir)) or os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
@@ -376,6 +387,8 @@ Train options: {json.dumps({
     'lora_dropout': lora_dropout,
     'lora_target_modules': lora_target_modules,
     'model_name': model_name,
+    'continue_from_model': continue_from_model,
+    'continue_from_checkpoint': continue_from_checkpoint,
 }, indent=2)}
 
 Train data (first 10):
@@ -386,7 +399,7 @@ Train data (first 10):
             return message
 
         if not should_training_progress_track_tqdm:
-            progress(0, desc="Preparing model for training...")
+            progress(0, desc=f"Preparing model {base_model_name} for training...")
 
         log_history = []
 
@@ -461,6 +474,10 @@ Train data (first 10):
                 # 'lora_dropout': lora_dropout,
                 # 'lora_target_modules': lora_target_modules,
             }
+            if continue_from_model:
+                info['continued_from_model'] = continue_from_model
+            if continue_from_checkpoint:
+                info['continued_from_checkpoint'] = continue_from_checkpoint
             json.dump(info, info_json_file, indent=2)
 
         if not should_training_progress_track_tqdm:
@@ -490,7 +507,7 @@ Train data (first 10):
             lora_target_modules, # lora_target_modules
             train_on_inputs, # train_on_inputs
             False, # group_by_length
-            None, # resume_from_checkpoint
+            resume_from_checkpoint, # resume_from_checkpoint
             save_steps, # save_steps
             save_total_limit, # save_total_limit
             logging_steps, # logging_steps
@@ -582,6 +599,8 @@ def handle_load_params_from_model(
                 cutoff_len = value
             elif key == "evaluate_data_count":
                 evaluate_data_count = value
+            elif key == "val_set_size":
+                evaluate_data_count = value
             elif key == "micro_batch_size":
                 micro_batch_size = value
             elif key == "gradient_accumulation_steps":
@@ -610,6 +629,8 @@ def handle_load_params_from_model(
                 logging_steps = value
             elif key == "group_by_length":
                 pass
+            elif key == "resume_from_checkpoint":
+                pass
             else:
                 unknown_keys.append(key)
     except Exception as e:
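For reference, a worked example of the path resolution do_train now performs before handing resume_from_checkpoint to the finetune call. The model and checkpoint names are made up, and Global here is a stand-in for the app's real Global config object.

import os

class Global:  # stand-in for the application's Global config (assumption)
    data_dir = "./data"

continue_from_model = "alpaca-lora-7b"        # hypothetical prior LoRA model
continue_from_checkpoint = "checkpoint-600"   # hypothetical checkpoint folder

resume_from_checkpoint = None
if continue_from_model in ("-", "None"):
    continue_from_model = None
if continue_from_checkpoint in ("-", "None"):
    continue_from_checkpoint = None
if continue_from_model:
    resume_from_checkpoint = os.path.join(
        Global.data_dir, "lora_models", continue_from_model)
if continue_from_checkpoint:
    resume_from_checkpoint = os.path.join(
        resume_from_checkpoint, continue_from_checkpoint)

print(resume_from_checkpoint)  # ./data/lora_models/alpaca-lora-7b/checkpoint-600

Picking "-" or "None" in the UI dropdowns therefore falls back to a fresh run (resume_from_checkpoint stays None), selecting only a model resumes from that model's final adapter, and selecting a checkpoint as well resumes from that specific training checkpoint.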