feat: allow weight decay
tools/train/train.py  +8 −0
tools/train/train.py
CHANGED
@@ -331,6 +331,9 @@ class TrainingArguments:
             "help": 'The optimizer to use. Can be "distributed_shampoo" (default), "adam" or "adafactor"'
         },
     )
+    weight_decay: float = field(
+        default=0.0, metadata={"help": "Weight decay applied to parameters."}
+    )
     beta1: float = field(
         default=0.9,
         metadata={"help": "Beta1 for Adam & Distributed Shampoo."},
@@ -489,6 +492,8 @@ class TrainingArguments:
             "adam",
             "adafactor",
         ], f"Selected optimizer not supported: {self.optim}"
+        if self.optim == "adafactor" and self.weight_decay == 0:
+            self.weight_decay = None
         assert self.graft_type in [
             "rmsprop_normalized",
             "rmsprop",
@@ -844,6 +849,7 @@ def main():
             beta2=training_args.beta2,
             diagonal_epsilon=1e-10,
             matrix_epsilon=1e-6,
+            weight_decay=training_args.weight_decay,
             start_preconditioning_step=max(
                 training_args.preconditioning_compute_steps + 1, 101
             ),
@@ -892,6 +898,7 @@ def main():
             b1=training_args.beta1,
             b2=training_args.beta2,
             eps=training_args.adam_epsilon,
+            weight_decay=training_args.weight_decay,
         )
         optimizer = {k: optimizer for k in split_params(params_shape)}
 
@@ -901,6 +908,7 @@ def main():
         optimizer = optax.adafactor(
             learning_rate=learning_rate_fn,
             clipping_threshold=training_args.max_grad_norm,
+            weight_decay_rate=training_args.weight_decay,
        )
         optimizer = {k: optimizer for k in split_params(params_shape)}
 
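The single `weight_decay` training argument is threaded through all three optimizer branches, but each backend names the knob differently: the vendored Distributed Shampoo takes `weight_decay`, the Adam branch (which, judging by the `b1`/`b2`/`eps`/`weight_decay` keywords, is built with `optax.adamw`) also takes `weight_decay`, while `optax.adafactor` expects `weight_decay_rate`. A minimal sketch of the two optax branches, with illustrative placeholder hyperparameters rather than the repository's actual schedule and defaults:

```python
# Illustrative sketch only -- the schedule and values below are placeholders,
# not the training script's real configuration.
import optax

weight_decay = 0.01                               # stands in for TrainingArguments.weight_decay
learning_rate_fn = optax.constant_schedule(1e-4)  # stands in for the real learning-rate schedule

# "adam" branch: optax.adamw exposes decoupled weight decay as `weight_decay`.
adam_opt = optax.adamw(
    learning_rate=learning_rate_fn,
    b1=0.9,
    b2=0.999,
    eps=1e-8,
    weight_decay=weight_decay,
)

# "adafactor" branch: the same knob is called `weight_decay_rate`.
adafactor_opt = optax.adafactor(
    learning_rate=learning_rate_fn,
    clipping_threshold=1.0,
    weight_decay_rate=weight_decay,
)
```

The Distributed Shampoo branch is left out of the sketch because it is not part of optax; the diff itself shows it accepts the decay as a `weight_decay` keyword.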
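The `__post_init__` special case that turns a weight decay of 0 into `None` for Adafactor matters because `optax.adafactor` only appends its `add_decayed_weights` transform when `weight_decay_rate` is not `None`; passing `None` keeps the no-decay optimizer chain identical to what the script built before this change, instead of chaining an extra zero-rate transform. A hypothetical check (not repository code) illustrating that the two settings produce the same update, assuming the installed optax version follows this behaviour:

```python
import jax.numpy as jnp
import optax

params = {"w": jnp.ones((4, 4))}
grads = {"w": jnp.full((4, 4), 0.1)}

def one_update(weight_decay_rate):
    # Build an Adafactor optimizer with or without the weight-decay transform.
    opt = optax.adafactor(learning_rate=1e-3, weight_decay_rate=weight_decay_rate)
    state = opt.init(params)
    updates, _ = opt.update(grads, state, params)
    return updates

u_none = one_update(None)  # no add_decayed_weights transform in the chain
u_zero = one_update(0.0)   # transform present, but decays by zero
print(jnp.allclose(u_none["w"], u_zero["w"]))  # expected: True
```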