boris commited on
Commit
49597a2
·
1 Parent(s): 0081723

feat(train): progress on pjit

Browse files
Files changed (2) hide show
  1. src/dalle_mini/data.py +0 -2
  2. tools/train/train.py +34 -31
src/dalle_mini/data.py CHANGED
@@ -191,7 +191,6 @@ class Dataset:
191
  lambda x: x.reshape((-1, per_device_batch_size) + x.shape[1:]),
192
  batch,
193
  )
194
- batch = shard(batch)
195
  yield batch
196
 
197
  def _dataloader_datasets_streaming(
@@ -232,7 +231,6 @@ class Dataset:
232
  ),
233
  batch,
234
  )
235
- batch = shard(batch)
236
  yield batch
237
  batch = {k: [] for k in keys}
238
  first_loop = False
 
191
  lambda x: x.reshape((-1, per_device_batch_size) + x.shape[1:]),
192
  batch,
193
  )
 
194
  yield batch
195
 
196
  def _dataloader_datasets_streaming(
 
231
  ),
232
  batch,
233
  )
 
234
  yield batch
235
  batch = {k: [] for k in keys}
236
  first_loop = False
tools/train/train.py CHANGED
@@ -34,13 +34,11 @@ import numpy as np
34
  import optax
35
  import transformers
36
  from datasets import Dataset
37
- from distributed_shampoo import GraftingType, distributed_shampoo, pad_matrix
38
- from flax import jax_utils, traverse_util
39
- from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
40
- from flax.jax_utils import unreplicate
41
  from flax.serialization import from_bytes, to_bytes
42
  from flax.training import train_state
43
- from flax.training.common_utils import get_metrics, onehot, shard_prng_key
44
  from jax.experimental import PartitionSpec, maps
45
  from jax.experimental.pjit import pjit
46
  from tqdm import tqdm
@@ -402,14 +400,14 @@ class MetricsLogger:
402
 
403
  def get_all_train_metrics(self, train_metrics, state):
404
  """Make a dict of training metrics to be logged"""
405
- metrics = unreplicate(train_metrics)
406
  # get state parameters
407
  state_dict = {
408
- k.split("_")[-1]: unreplicate(getattr(state, k))
409
  for k in ["epoch", "train_time", "train_samples"]
410
  }
411
  # timing metrics
412
- new_step = int(unreplicate(state.step))
413
  new_time = time.perf_counter()
414
  if new_step > self.step:
415
  time_per_step = (new_time - self.time) / (new_step - self.step)
@@ -551,7 +549,7 @@ def main():
551
 
552
  # Initialize our training
553
  rng = jax.random.PRNGKey(training_args.seed_model)
554
- rng, *dropout_rng = jax.random.split(rng, num=training_args.dp_devices + 1)
555
 
556
  # Store some constant
557
  num_epochs = training_args.num_train_epochs
@@ -681,34 +679,39 @@ def main():
681
  devices = np.asarray(jax.devices()).reshape(*mesh_shape)
682
  mesh = maps.Mesh(devices, ("batch", "mp"))
683
 
684
- # move params & init opt_state over specified devices
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
685
  with maps.mesh(mesh.devices, mesh.axis_names):
 
686
  params, opt_state = pjit(
687
  lambda x: (x, optimizer.init(x)),
688
  in_axis_resources=None,
689
  out_axis_resources=(param_spec, opt_state_spec),
690
  )(freeze(model.params))
691
-
692
- # Setup train state
693
- state = TrainState(
694
- apply_fn=model.__call__,
695
- params=params,
696
- opt_state=opt_state,
697
- tx=optimizer,
698
- dropout_rng=dropout_rng,
699
- step=0,
700
- )
701
-
702
- # create PartitionSpec for state
703
- state_spec = {
704
- "params": param_spec,
705
- "opt_state": opt_state_spec,
706
- "dropout_rng": PartitionSpec("batch", None),
707
- "epoch": None,
708
- "step": None,
709
- "train_samples": None,
710
- "train_time": None,
711
- }
712
 
713
  if training_args.resume_from_checkpoint is not None:
714
  # restore optimizer state and other parameters
 
34
  import optax
35
  import transformers
36
  from datasets import Dataset
37
+ from distributed_shampoo import GraftingType, distributed_shampoo
38
+ from flax.core.frozen_dict import freeze
 
 
39
  from flax.serialization import from_bytes, to_bytes
40
  from flax.training import train_state
41
+ from flax.training.common_utils import get_metrics, onehot
42
  from jax.experimental import PartitionSpec, maps
43
  from jax.experimental.pjit import pjit
44
  from tqdm import tqdm
 
400
 
401
  def get_all_train_metrics(self, train_metrics, state):
402
  """Make a dict of training metrics to be logged"""
403
+ metrics = train_metrics
404
  # get state parameters
405
  state_dict = {
406
+ k.split("_")[-1]: getattr(state, k)
407
  for k in ["epoch", "train_time", "train_samples"]
408
  }
409
  # timing metrics
410
+ new_step = int(state.step)
411
  new_time = time.perf_counter()
412
  if new_step > self.step:
413
  time_per_step = (new_time - self.time) / (new_step - self.step)
 
549
 
550
  # Initialize our training
551
  rng = jax.random.PRNGKey(training_args.seed_model)
552
+ rng, dropout_rng = jax.random.split(rng)
553
 
554
  # Store some constant
555
  num_epochs = training_args.num_train_epochs
 
679
  devices = np.asarray(jax.devices()).reshape(*mesh_shape)
680
  mesh = maps.Mesh(devices, ("batch", "mp"))
681
 
682
+ # Setup train state
683
+ def init_state(params, opt_state):
684
+ return TrainState(
685
+ apply_fn=model.__call__,
686
+ tx=optimizer,
687
+ params=params,
688
+ opt_state=opt_state,
689
+ dropout_rng=dropout_rng,
690
+ step=0,
691
+ )
692
+
693
+ state_spec = init_state(param_spec, opt_state_spec)
694
+ state_spec = state_spec.replace(
695
+ dropout_rng=None,
696
+ step=None,
697
+ epoch=None,
698
+ train_time=None,
699
+ train_samples=None,
700
+ )
701
+
702
  with maps.mesh(mesh.devices, mesh.axis_names):
703
+ # move params & init opt_state over specified devices
704
  params, opt_state = pjit(
705
  lambda x: (x, optimizer.init(x)),
706
  in_axis_resources=None,
707
  out_axis_resources=(param_spec, opt_state_spec),
708
  )(freeze(model.params))
709
+ # create training state
710
+ state = pjit(
711
+ init_state,
712
+ in_axis_resources=(param_spec, opt_state_spec),
713
+ out_axis_resources=state_spec,
714
+ )(params, opt_state)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
715
 
716
  if training_args.resume_from_checkpoint is not None:
717
  # restore optimizer state and other parameters