boris committed
Commit: acc1a4a
Parent: b4bb5b9

feat: allow weight decay

Files changed (1)
  1. tools/train/train.py +8 -0
tools/train/train.py CHANGED
@@ -331,6 +331,9 @@ class TrainingArguments:
             "help": 'The optimizer to use. Can be "distributed_shampoo" (default), "adam" or "adafactor"'
         },
     )
+    weight_decay: float = field(
+        default=0.0, metadata={"help": "Weight decay applied to parameters."}
+    )
     beta1: float = field(
         default=0.9,
         metadata={"help": "Beta1 for Adam & Distributed Shampoo."},
@@ -489,6 +492,8 @@ class TrainingArguments:
             "adam",
             "adafactor",
         ], f"Selected optimizer not supported: {self.optim}"
+        if self.optim == "adafactor" and self.weight_decay == 0:
+            self.weight_decay = None
         assert self.graft_type in [
             "rmsprop_normalized",
             "rmsprop",
@@ -844,6 +849,7 @@ def main():
             beta2=training_args.beta2,
             diagonal_epsilon=1e-10,
             matrix_epsilon=1e-6,
+            weight_decay=training_args.weight_decay,
             start_preconditioning_step=max(
                 training_args.preconditioning_compute_steps + 1, 101
             ),
@@ -892,6 +898,7 @@ def main():
             b1=training_args.beta1,
             b2=training_args.beta2,
             eps=training_args.adam_epsilon,
+            weight_decay=training_args.weight_decay,
         )
         optimizer = {k: optimizer for k in split_params(params_shape)}
 
@@ -901,6 +908,7 @@ def main():
         optimizer = optax.adafactor(
             learning_rate=learning_rate_fn,
             clipping_threshold=training_args.max_grad_norm,
+            weight_decay_rate=training_args.weight_decay,
         )
         optimizer = {k: optimizer for k in split_params(params_shape)}
 
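
Note on the adafactor branch: `__post_init__` converts a zero `weight_decay` to `None` before it reaches `optax.adafactor`, presumably because that API treats `weight_decay_rate=None` (rather than `0.0`) as "no decay term". The sketch below is not part of the commit; the schedule and hyper-parameter values are placeholders, and the adam branch's function name is not shown in this diff, though its keywords (`b1`, `b2`, `eps`, `weight_decay`) match `optax.adamw`. The distributed_shampoo branch simply forwards `weight_decay` to its own factory and is omitted here.

import optax

# Placeholder values standing in for TrainingArguments and the real LR schedule.
weight_decay = 0.01
learning_rate_fn = optax.constant_schedule(1e-4)

# "adam" branch: decoupled weight decay via optax.adamw (assumed from the keywords).
optimizer_adam = optax.adamw(
    learning_rate=learning_rate_fn,
    b1=0.9,
    b2=0.999,
    eps=1e-8,
    weight_decay=weight_decay,
)

# "adafactor" branch: decay is passed as weight_decay_rate, and a value of 0.0
# is mapped to None (no decay), mirroring the __post_init__ conversion above.
optimizer_adafactor = optax.adafactor(
    learning_rate=learning_rate_fn,
    clipping_threshold=1.0,
    weight_decay_rate=weight_decay if weight_decay else None,
)

With `weight_decay` defaulting to 0.0, existing configurations keep their previous behaviour; only runs that set the new flag pick up the extra decay term.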