boris commited on
Commit
032f623
·
1 Parent(s): 8a9e367

feat(train): handle distributed_shampoo in pjit

Browse files
Files changed (1) hide show
  1. tools/train/train.py +40 -36
tools/train/train.py CHANGED
@@ -25,7 +25,7 @@ import sys
25
  import time
26
  from dataclasses import asdict, dataclass, field
27
  from pathlib import Path
28
- from typing import Callable, Optional
29
 
30
  import datasets
31
  import jax
@@ -36,7 +36,7 @@ import transformers
36
  import wandb
37
  from datasets import Dataset
38
  from distributed_shampoo import GraftingType, distributed_shampoo
39
- from flax.core.frozen_dict import freeze, unfreeze
40
  from flax.serialization import from_bytes, to_bytes
41
  from flax.training import train_state
42
  from flax.training.common_utils import onehot, stack_forest
@@ -523,6 +523,12 @@ def main():
523
  use_fast=True,
524
  )
525
 
 
 
 
 
 
 
526
  # Preprocessing the datasets.
527
  # We need to normalize and tokenize inputs and targets.
528
 
@@ -620,6 +626,13 @@ def main():
620
  precision=jax.lax.Precision.HIGHEST,
621
  best_effort_memory_usage_reduction=training_args.optim_quantized,
622
  )
 
 
 
 
 
 
 
623
 
624
  elif training_args.optim == "adam":
625
  optimizer = optax.adamw(
@@ -636,43 +649,40 @@ def main():
636
  clipping_threshold=training_args.max_grad_norm,
637
  )
638
 
639
- # get PartitionSpec for model params
640
- param_spec = set_partitions(model.params)
641
-
642
  # get PartitionSpec for optimizer state
643
  def get_opt_state_spec_and_shape(param_spec):
644
- if training_args.optim == "adam":
645
  # get opt_state shape without actual init
646
  opt_state_shape = jax.eval_shape(optimizer.init, model.params)
647
 
648
- def _opt_state_spec_per_leaf(x):
649
- if isinstance(x, dict):
650
- # variables with same structure as params
651
- return param_spec
652
- else:
653
- # other variables such as count
654
- return None
655
-
656
- opt_state_spec = jax.tree_map(
657
- _opt_state_spec_per_leaf,
658
- opt_state_shape,
659
- # return None spec for empty elements
660
- is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)),
661
- )
 
 
662
 
663
- elif training_args.optim == "adafactor":
664
- # factorized state must be replicated (rank different than params)
665
- opt_state_spec = None
666
 
667
  elif training_args.optim == "distributed_shampoo":
668
- # memory efficient in distributed_shampoo, fake init
669
- _opt_state = optimizer.init(model.params)
670
- opt_state_spec = _opt_state.pspec_fn(
671
  params=model.params,
672
- params_partition_spec=unfreeze(param_spec),
673
  partition_spec_for_statistics=PartitionSpec(None, "batch", None),
674
  )
675
- opt_state_shape = _opt_state.shape_and_dtype_fn(model.params)
676
  else:
677
  raise NotImplementedError
678
  return opt_state_spec, opt_state_shape
@@ -714,18 +724,12 @@ def main():
714
  in_axis_resources=(param_spec,),
715
  out_axis_resources=state_spec,
716
  donate_argnums=(0,),
717
- )(freeze(model.params))
718
 
719
  else:
720
  # restore opt_state
721
  with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
722
  opt_state = from_bytes(opt_state_shape, f.read())
723
- # need to freeze dict for pjit
724
- opt_state = jax.tree_map(
725
- lambda x: freeze(x) if isinstance(x, dict) else x,
726
- opt_state,
727
- is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)),
728
- )
729
 
730
  # restore other attributes
731
  with (Path(artifact_dir) / "training_state.json").open("r") as f:
@@ -746,7 +750,7 @@ def main():
746
  in_axis_resources=(param_spec, opt_state_spec),
747
  out_axis_resources=state_spec,
748
  donate_argnums=(0, 1),
749
- )(freeze(model.params), opt_state)
750
 
751
  # remove opt_state from CPU
752
  del opt_state
 
25
  import time
26
  from dataclasses import asdict, dataclass, field
27
  from pathlib import Path
28
+ from typing import Any, Callable, NamedTuple, Optional
29
 
30
  import datasets
31
  import jax
 
36
  import wandb
37
  from datasets import Dataset
38
  from distributed_shampoo import GraftingType, distributed_shampoo
39
+ from flax.core.frozen_dict import FrozenDict, freeze
40
  from flax.serialization import from_bytes, to_bytes
41
  from flax.training import train_state
42
  from flax.training.common_utils import onehot, stack_forest
 
523
  use_fast=True,
524
  )
525
 
526
+ # get PartitionSpec for model params (required to be a dict)
527
+ param_spec = set_partitions(model.params)
528
+
529
+ # convert params to frozen dict
530
+ model._params = freeze(model.params)
531
+
532
  # Preprocessing the datasets.
533
  # We need to normalize and tokenize inputs and targets.
534
 
 
626
  precision=jax.lax.Precision.HIGHEST,
627
  best_effort_memory_usage_reduction=training_args.optim_quantized,
628
  )
629
+ # get the real optimizer and helper functions
630
+ update_fn = optimizer.update
631
+ optimizer = optimizer.init(model.params)
632
+ opt_fn = NamedTuple("opt_fn", pspec_fn=Any, shape_and_dtype_fn=Any)(
633
+ optimizer.pspec_fn, optimizer.shape_and_dtype_fn
634
+ )
635
+ optimizer = optax.GradientTransformation(optimizer.init_fn, update_fn)
636
 
637
  elif training_args.optim == "adam":
638
  optimizer = optax.adamw(
 
649
  clipping_threshold=training_args.max_grad_norm,
650
  )
651
 
 
 
 
652
  # get PartitionSpec for optimizer state
653
  def get_opt_state_spec_and_shape(param_spec):
654
+ if training_args.optim in ["adam", "adafactor"]:
655
  # get opt_state shape without actual init
656
  opt_state_shape = jax.eval_shape(optimizer.init, model.params)
657
 
658
+ if training_args.optim == "adam":
659
+
660
+ def _opt_state_spec_per_leaf(x):
661
+ if isinstance(x, FrozenDict):
662
+ # variables with same structure as params
663
+ return param_spec
664
+ else:
665
+ # other variables such as count
666
+ return None
667
+
668
+ opt_state_spec = jax.tree_map(
669
+ _opt_state_spec_per_leaf,
670
+ opt_state_shape,
671
+ # return None spec for empty elements
672
+ is_leaf=lambda x: isinstance(x, (FrozenDict, optax.EmptyState)),
673
+ )
674
 
675
+ elif training_args.optim == "adafactor":
676
+ # factorized state must be replicated (rank different than params)
677
+ opt_state_spec = None
678
 
679
  elif training_args.optim == "distributed_shampoo":
680
+ opt_state_spec = opt_fn.pspec_fn(
 
 
681
  params=model.params,
682
+ params_partition_spec=param_spec,
683
  partition_spec_for_statistics=PartitionSpec(None, "batch", None),
684
  )
685
+ opt_state_shape = opt_fn.shape_and_dtype_fn(model.params)
686
  else:
687
  raise NotImplementedError
688
  return opt_state_spec, opt_state_shape
 
724
  in_axis_resources=(param_spec,),
725
  out_axis_resources=state_spec,
726
  donate_argnums=(0,),
727
+ )(model.params)
728
 
729
  else:
730
  # restore opt_state
731
  with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
732
  opt_state = from_bytes(opt_state_shape, f.read())
 
 
 
 
 
 
733
 
734
  # restore other attributes
735
  with (Path(artifact_dir) / "training_state.json").open("r") as f:
 
750
  in_axis_resources=(param_spec, opt_state_spec),
751
  out_axis_resources=state_spec,
752
  donate_argnums=(0, 1),
753
+ )(model.params, opt_state)
754
 
755
  # remove opt_state from CPU
756
  del opt_state