feat(train): allow nesterov momentum
Browse files- tools/train/train.py +5 -1
tools/train/train.py
CHANGED
@@ -365,6 +365,10 @@ class TrainingArguments:
|
|
365 |
"help": "The type of grafting to use. Can be 'rmsprop_normalized' (default), 'rmsprop', 'adagrad', 'adagrad_normalized', 'sgd' or 'sqrt_n'"
|
366 |
},
|
367 |
)
|
|
|
|
|
|
|
|
|
368 |
optim_quantized: bool = field(
|
369 |
default=False,
|
370 |
metadata={
|
@@ -857,7 +861,7 @@ def main():
|
|
857 |
statistics_compute_steps=1,
|
858 |
best_effort_shape_interpretation=True,
|
859 |
graft_type=graft_type,
|
860 |
-
nesterov=
|
861 |
exponent_override=0,
|
862 |
statistics_partition_spec=statistics_partition_spec,
|
863 |
preconditioner_partition_spec=PartitionSpec(
|
|
|
365 |
"help": "The type of grafting to use. Can be 'rmsprop_normalized' (default), 'rmsprop', 'adagrad', 'adagrad_normalized', 'sgd' or 'sqrt_n'"
|
366 |
},
|
367 |
)
|
368 |
+
nesterov: bool = field(
|
369 |
+
default=False,
|
370 |
+
metadata={"help": "Use Nesterov momentum for Distributed Shampoo."},
|
371 |
+
)
|
372 |
optim_quantized: bool = field(
|
373 |
default=False,
|
374 |
metadata={
|
|
|
861 |
statistics_compute_steps=1,
|
862 |
best_effort_shape_interpretation=True,
|
863 |
graft_type=graft_type,
|
864 |
+
nesterov=training_args.nesterov,
|
865 |
exponent_override=0,
|
866 |
statistics_partition_spec=statistics_partition_spec,
|
867 |
preconditioner_partition_spec=PartitionSpec(
|