boris committed on
Commit
d08bf8d
1 Parent(s): acc1a4a

feat(train): allow nesterov momentum

Browse files
Files changed (1) hide show
  1. tools/train/train.py +5 -1
tools/train/train.py CHANGED
@@ -365,6 +365,10 @@ class TrainingArguments:
365
  "help": "The type of grafting to use. Can be 'rmsprop_normalized' (default), 'rmsprop', 'adagrad', 'adagrad_normalized', 'sgd' or 'sqrt_n'"
366
  },
367
  )
 
 
 
 
368
  optim_quantized: bool = field(
369
  default=False,
370
  metadata={
@@ -857,7 +861,7 @@ def main():
857
  statistics_compute_steps=1,
858
  best_effort_shape_interpretation=True,
859
  graft_type=graft_type,
860
- nesterov=False,
861
  exponent_override=0,
862
  statistics_partition_spec=statistics_partition_spec,
863
  preconditioner_partition_spec=PartitionSpec(
 
365
  "help": "The type of grafting to use. Can be 'rmsprop_normalized' (default), 'rmsprop', 'adagrad', 'adagrad_normalized', 'sgd' or 'sqrt_n'"
366
  },
367
  )
368
+ nesterov: bool = field(
369
+ default=False,
370
+ metadata={"help": "Use Nesterov momentum for Distributed Shampoo."},
371
+ )
372
  optim_quantized: bool = field(
373
  default=False,
374
  metadata={
 
861
  statistics_compute_steps=1,
862
  best_effort_shape_interpretation=True,
863
  graft_type=graft_type,
864
+ nesterov=training_args.nesterov,
865
  exponent_override=0,
866
  statistics_partition_spec=statistics_partition_spec,
867
  preconditioner_partition_spec=PartitionSpec(