boris committed
Commit: c3e93df (unverified)
Parent: 79a3849

feat: support LR offset (#174)

Files changed (1)
  1. tools/train/train.py +17 -16
tools/train/train.py CHANGED

@@ -119,7 +119,7 @@ class ModelArguments:
             ), "Restoring state only available with W&B artifact reference"
 
     def get_metadata(self):
-        if self.restore_state:
+        if ":" in self.model_name_or_path:
             if jax.process_index() == 0:
                 artifact = wandb.run.use_artifact(self.model_name_or_path)
             else:
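Note: W&B artifact references carry a version or alias after a colon (for example "model:v10"), while plain Hugging Face model ids and local paths do not, which is what the new check relies on. A minimal sketch of that heuristic follows; the example names are made up, not taken from the repository.

# Minimal sketch of the check above, assuming W&B artifact references of the
# form "entity/project/artifact:version-or-alias"; names are illustrative.
def is_wandb_artifact(model_name_or_path: str) -> bool:
    # Artifact references include a ":version" / ":alias" suffix, while plain
    # model ids and local paths contain no ":".
    return ":" in model_name_or_path

print(is_wandb_artifact("dalle-mini/dalle-mini/model-run:v10"))  # True
print(is_wandb_artifact("some-org/some-model"))                  # False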
@@ -413,11 +413,9 @@ class TrainingArguments:
             "help": "Whether to use staircase or continuous learning rate when using exponential decay."
         },
     )
-    lr_resume_offset: bool = field(
-        default=False,
-        metadata={
-            "help": "Whether to offset the learning rate function with current step when resuming a run."
-        },
+    lr_offset: int = field(
+        default=0,
+        metadata={"help": "Number of steps to offset learning rate and keep it at 0."},
     )
     logging_steps: int = field(
         default=40, metadata={"help": "Log every X updates steps."}
@@ -796,14 +794,14 @@ def main():
             end_value=training_args.learning_rate,
             transition_steps=training_args.warmup_steps + 1,  # ensure not 0
         )
-        # offset step when resuming
         last_boundary = training_args.warmup_steps
-        if model_metadata.get("step", 0) and training_args.lr_resume_offset:
+        # offset step when resuming
+        if training_args.lr_offset:
             warmup_fn = optax.join_schedules(
                 schedules=[optax.constant_schedule(0.0), warmup_fn],
-                boundaries=[model_metadata["step"]],
+                boundaries=[training_args.lr_offset],
             )
-            last_boundary += model_metadata["step"]
+            last_boundary += training_args.lr_offset
         if training_args.lr_decay is None:
             return warmup_fn
         elif training_args.lr_decay == "linear":
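For context, a minimal, self-contained sketch of the offset technique used in this hunk, with illustrative values for learning_rate, warmup_steps and lr_offset (none of them come from the repository): optax.join_schedules holds the rate at 0.0 until the lr_offset boundary, then runs the warmup with its step count restarted there, which is also why last_boundary is shifted by the same amount before any decay schedule is attached.

# Sketch only, not the project's full schedule; values are illustrative.
import optax

learning_rate = 3e-4
warmup_steps = 500
lr_offset = 1000  # e.g. the step a resumed run starts from

warmup_fn = optax.linear_schedule(
    init_value=0.0,
    end_value=learning_rate,
    transition_steps=warmup_steps + 1,  # ensure not 0
)

# join_schedules switches schedules at `boundaries`; later schedules see a
# step count restarted at their boundary, so the warmup begins at lr_offset.
schedule_fn = optax.join_schedules(
    schedules=[optax.constant_schedule(0.0), warmup_fn],
    boundaries=[lr_offset],
)

last_boundary = warmup_steps + lr_offset  # where a decay schedule would be joined

print(schedule_fn(lr_offset - 1))             # 0.0, still inside the offset window
print(schedule_fn(lr_offset + warmup_steps))  # ~learning_rate, warmup finished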
@@ -1005,6 +1003,14 @@ def main():
 
     with mesh:
         logger.info(" Creating state")
+
+        # restore metadata
+        attr_state = {}
+        keys = ["train_time", "train_samples"]
+        if model_args.restore_state:
+            keys += ["step", "epoch"]
+        attr_state = {k: v for k, v in model_metadata.items() if k in keys}
+
         if not model_args.restore_state:
 
             def init_state(params):
@@ -1013,6 +1019,7 @@
                     tx=optimizer,
                     params=maybe_init_params(params),
                     dropout_rng=dropout_rng,
+                    **attr_state,
                 )
 
             state = pjit(
@@ -1028,12 +1035,6 @@
             # load opt_state
             opt_state = from_bytes(opt_state_shape, model_args.get_opt_state())
 
-            # restore other attributes
-            attr_state = {
-                k: model_metadata[k]
-                for k in ["step", "epoch", "train_time", "train_samples"]
-            }
-
             def restore_state(params, opt_state):
                 return TrainState(
                     apply_fn=model.__call__,
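As a rough illustration of the metadata handling introduced above (the metadata values below are made up), train_time and train_samples are carried over whenever they appear in the model metadata, while step and epoch are only restored together with the optimizer state; the resulting dict is then unpacked into the state constructor.

# Rough illustration of the metadata filter above; values and the
# restore_state flag are made-up examples.
model_metadata = {
    "step": 15000,
    "epoch": 1,
    "train_time": 3600.0,
    "train_samples": 960000,
    "learning_rate": 0.0003,  # not in `keys`, so it is dropped
}

for restore_state in (False, True):
    keys = ["train_time", "train_samples"]
    if restore_state:
        keys += ["step", "epoch"]
    attr_state = {k: v for k, v in model_metadata.items() if k in keys}
    # attr_state is then passed as keyword arguments when building the state,
    # e.g. TrainState.create(..., **attr_state) in the fresh-state branch.
    print(restore_state, attr_state)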
 