boris commited on
Commit
a2bf605
·
1 Parent(s): c91ceb7

feat(train): cleanup args

Browse files
Files changed (1) hide show
  1. tools/train/train.py +21 -17
tools/train/train.py CHANGED
@@ -199,8 +199,11 @@ class TrainingArguments:
199
  per_device_train_batch_size: int = field(
200
  default=8, metadata={"help": "Batch size per GPU/TPU/CPU for training."}
201
  )
202
- per_device_eval_batch_size: int = field(
203
- default=8, metadata={"help": "Batch size per GPU/TPU/CPU for evaluation."}
 
 
 
204
  )
205
 
206
  gradient_accumulation_steps: int = field(
@@ -252,6 +255,13 @@ class TrainingArguments:
252
  },
253
  )
254
 
 
 
 
 
 
 
 
255
  lr_decay: str = field(
256
  default=None,
257
  metadata={
@@ -277,13 +287,6 @@ class TrainingArguments:
277
  },
278
  )
279
 
280
- num_train_epochs: int = field(
281
- default=3, metadata={"help": "Total number of training epochs to perform."}
282
- )
283
- warmup_steps: int = field(
284
- default=0, metadata={"help": "Linear warmup over warmup_steps."}
285
- )
286
-
287
  logging_steps: int = field(
288
  default=40, metadata={"help": "Log every X updates steps."}
289
  )
@@ -334,6 +337,11 @@ class TrainingArguments:
334
  "adam",
335
  "adafactor",
336
  ], f"Selected optimizer not supported: {self.optim}"
 
 
 
 
 
337
  if (
338
  os.path.exists(self.output_dir)
339
  and os.listdir(self.output_dir)
@@ -623,9 +631,7 @@ def main():
623
  beta2=training_args.beta2,
624
  diagonal_epsilon=1e-10,
625
  matrix_epsilon=1e-8,
626
- weight_decay=training_args.weight_decay
627
- if training_args.weight_decay is not None
628
- else 0.0,
629
  start_preconditioning_step=training_args.warmup_steps,
630
  preconditioning_compute_steps=training_args.preconditioning_compute_steps,
631
  statistics_compute_steps=1,
@@ -648,9 +654,7 @@ def main():
648
  b1=training_args.beta1,
649
  b2=training_args.beta2,
650
  eps=training_args.adam_epsilon,
651
- weight_decay=training_args.weight_decay
652
- if training_args.weight_decay is not None
653
- else 0.0,
654
  mask=decay_mask_fn,
655
  )
656
  elif training_args.optim == "adafactor":
@@ -749,8 +753,8 @@ def main():
749
  return metrics
750
 
751
  # Create parallel version of the train and eval step
752
- p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, 1))
753
- p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(1,))
754
 
755
  logger.info("***** Running training *****")
756
  logger.info(f" Num examples = {len_train_dataset}")
 
199
  per_device_train_batch_size: int = field(
200
  default=8, metadata={"help": "Batch size per GPU/TPU/CPU for training."}
201
  )
202
+ per_device_eval_batch_size: Optional[int] = field(
203
+ default=None,
204
+ metadata={
205
+ "help": "Batch size per GPU/TPU/CPU for evaluation. Same as training batch size if not set."
206
+ },
207
  )
208
 
209
  gradient_accumulation_steps: int = field(
 
255
  },
256
  )
257
 
258
+ num_train_epochs: int = field(
259
+ default=3, metadata={"help": "Total number of training epochs to perform."}
260
+ )
261
+
262
+ warmup_steps: int = field(
263
+ default=0, metadata={"help": "Linear warmup over warmup_steps."}
264
+ )
265
  lr_decay: str = field(
266
  default=None,
267
  metadata={
 
287
  },
288
  )
289
 
 
 
 
 
 
 
 
290
  logging_steps: int = field(
291
  default=40, metadata={"help": "Log every X updates steps."}
292
  )
 
337
  "adam",
338
  "adafactor",
339
  ], f"Selected optimizer not supported: {self.optim}"
340
+ if self.per_device_eval_batch_size is None:
341
+ self.per_device_eval_batch_size = self.per_device_train_batch_size
342
+ if self.weight_decay is None:
343
+ if self.optim in ["distributed_shampoo", "adam"]:
344
+ self.weight_decay = 0.0
345
  if (
346
  os.path.exists(self.output_dir)
347
  and os.listdir(self.output_dir)
 
631
  beta2=training_args.beta2,
632
  diagonal_epsilon=1e-10,
633
  matrix_epsilon=1e-8,
634
+ weight_decay=training_args.weight_decay,
 
 
635
  start_preconditioning_step=training_args.warmup_steps,
636
  preconditioning_compute_steps=training_args.preconditioning_compute_steps,
637
  statistics_compute_steps=1,
 
654
  b1=training_args.beta1,
655
  b2=training_args.beta2,
656
  eps=training_args.adam_epsilon,
657
+ weight_decay=training_args.weight_decay,
 
 
658
  mask=decay_mask_fn,
659
  )
660
  elif training_args.optim == "adafactor":
 
753
  return metrics
754
 
755
  # Create parallel version of the train and eval step
756
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
757
+ p_eval_step = jax.pmap(eval_step, "batch")
758
 
759
  logger.info("***** Running training *****")
760
  logger.info(f" Num examples = {len_train_dataset}")