Commit 6b84155
boris committed
Parent: f3a8cbb

feat(train): use new HF _do_init api
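For orientation, a minimal sketch (not part of the diff) of the `_do_init` API this commit adopts: with `_do_init=False`, `from_pretrained` skips random weight initialization and returns the module together with the loaded parameters instead of storing them on the model. The checkpoint id below is a placeholder, and the import path is assumed to match this repo's layout.

```python
# minimal sketch; checkpoint id is a placeholder, import assumed from this repo
import jax.numpy as jnp
from dalle_mini.model import DalleBart

model, params = DalleBart.from_pretrained(
    "path/to/checkpoint",   # placeholder
    dtype=jnp.float32,
    _do_init=False,         # no random init; params are returned separately
)
params_shape = model.params_shape_tree  # shape-only pytree, no weights allocated
```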
src/dalle_mini/model/modeling.py CHANGED
@@ -1330,10 +1330,11 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
 
     config_class = DalleBartConfig
 
-    @property
-    def num_params(self):
+    def num_params(self, params=None):
+        if params is None:
+            params = self.params
         num_params = jax.tree_map(
-            lambda param: param.size, flatten_dict(unfreeze(self.params))
+            lambda param: param.size, flatten_dict(unfreeze(params))
         ).values()
         return sum(list(num_params))
 
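For context (not part of the commit), a minimal sketch of how the reworked method is used: since it now takes an explicit tree, it can count parameters from the shape-only `params_shape_tree` of a model created with `_do_init=False`, relying on `jax.ShapeDtypeStruct` leaves exposing `.size`. Here `params` is assumed to be the tree returned by `from_pretrained(..., _do_init=False)`.

```python
# minimal sketch, assuming `model` was created with _do_init=False
n_from_shapes = model.num_params(model.params_shape_tree)  # no weights materialized
n_from_weights = model.num_params(params)                  # real loaded weights
print(f"{n_from_shapes:,} parameters")
```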
tools/train/train.py CHANGED
@@ -672,12 +672,12 @@ def main():
 
     # Load or create new model
     if model_args.model_name_or_path:
-        model = DalleBart.from_pretrained(
+        model, params = DalleBart.from_pretrained(
             model_args.model_name_or_path,
             config=config,
             seed=training_args.seed_model,
             dtype=getattr(jnp, model_args.dtype),
-            abstract_init=True,  # we overwrite them with loaded checkpoint
+            _do_init=False,  # we overwrite them with loaded checkpoint
             gradient_checkpointing=training_args.gradient_checkpointing,
         )
     else:
@@ -685,17 +685,19 @@
             config,
             seed=training_args.seed_model,
             dtype=getattr(jnp, model_args.dtype),
-            abstract_init=True,
+            _do_init=False,
         )
+        params = None
+    params_shape = model.params_shape_tree
 
     # get model metadata
     model_metadata = model_args.get_metadata()
 
     # get PartitionSpec for model params (required to be a dict)
-    param_spec = set_partitions(model.params, model.config.use_scan)
-
-    # convert params to frozen dict
-    model._params = freeze(model.params)
+    param_spec = set_partitions(params_shape, model.config.use_scan)
+    params_shape = freeze(params_shape)
+    if params is not None:
+        params = freeze(params)
 
     # Load tokenizer
     tokenizer = DalleBartTokenizer.from_pretrained(
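A minimal sketch (with a toy Flax module, not DalleBart) of what the shape-only tree above looks like; transformers builds `params_shape_tree` in essentially this way, by running the init function under `jax.eval_shape`, so helpers like `set_partitions` and `split_params` only ever need shapes and tree structure.

```python
# minimal sketch with a toy Flax module standing in for DalleBart
import jax
import jax.numpy as jnp
import flax.linen as nn

module = nn.Dense(features=16)
shape_tree = jax.eval_shape(
    lambda rng: module.init(rng, jnp.ones((1, 8))),
    jax.random.PRNGKey(0),
)
# leaves are jax.ShapeDtypeStruct objects: shape/dtype only, no device memory
print(jax.tree_util.tree_map(lambda x: x.shape, shape_tree))
```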
@@ -736,7 +738,7 @@ def main():
     num_train_steps = (
         steps_per_epoch * num_epochs if steps_per_epoch is not None else None
     )
-    num_params = model.num_params
+    num_params = model.num_params(params_shape)
 
     logger.info("***** Running training *****")
     logger.info(f"  Num examples = {len_train_dataset}")
@@ -875,7 +877,7 @@ def main():
 
         optimizer = {}
         opt_fn = {}
-        for k, p in split_params(model.params).items():
+        for k, p in split_params(params_shape).items():
             if "scanned" in k:
                 p = jax.eval_shape(lambda x: jax.tree_map(lambda y: y[0], x), p)
             optimizer[k] = opt.init(p)
@@ -891,7 +893,7 @@ def main():
             b2=training_args.beta2,
             eps=training_args.adam_epsilon,
         )
-        optimizer = {k: optimizer for k in split_params(model.params)}
+        optimizer = {k: optimizer for k in split_params(params_shape)}
 
     elif training_args.optim == "adafactor":
         # We use the default parameters here to initialize adafactor,
@@ -900,13 +902,13 @@ def main():
             learning_rate=learning_rate_fn,
             clipping_threshold=training_args.max_grad_norm,
         )
-        optimizer = {k: optimizer for k in split_params(model.params)}
+        optimizer = {k: optimizer for k in split_params(params_shape)}
 
     # get PartitionSpec for optimizer state
     def get_opt_state_spec_and_shape():
         # get opt_state shape without actual init
         opt_state_shape = {}
-        for k, p in split_params(model.params).items():
+        for k, p in split_params(params_shape).items():
             if "scanned" not in k:
                 opt_state_shape[k] = jax.eval_shape(optimizer[k].init, p)
             else:
@@ -914,7 +916,7 @@ def main():
 
         if training_args.optim == "adafactor":
             # factorized state must be replicated (rank different than params)
-            opt_state_spec = {k: None for k in split_params(model.params)}
+            opt_state_spec = {k: None for k in split_params(params_shape)}
 
         elif training_args.optim in ["adam", "distributed_shampoo"]:
 
@@ -926,9 +928,9 @@ def main():
                 # other variables such as count
                 return None
 
-            split_spec = split_params(set_partitions(model.params, False))
+            split_spec = split_params(set_partitions(params_shape, False))
             opt_state_spec = {}
-            for k, p in split_params(model.params).items():
+            for k, p in split_params(params_shape).items():
                 if "scanned" in k:
                     p = jax.eval_shape(lambda x: jax.tree_map(lambda y: y[0], x), p)
                 if training_args.optim == "adam":
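A minimal sketch of the `jax.eval_shape` trick used in the hunks above to size the optimizer state without allocating it; `optax.adamw` here is just an illustrative stand-in, not necessarily the optimizer the script configures.

```python
# minimal sketch: shape-only params in, shape-only optimizer state out
import jax
import jax.numpy as jnp
import optax

params_shape = {"w": jax.ShapeDtypeStruct((1024, 1024), jnp.float32)}
opt = optax.adamw(learning_rate=1e-3)
opt_state_shape = jax.eval_shape(opt.init, params_shape)
# Adam's mu/nu moments and step count come back as ShapeDtypeStructs
```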
@@ -982,12 +984,12 @@ def main():
 
     # init params if not available yet
     def maybe_init_params(params):
-        if model_args.model_name_or_path:
+        if params is not None:
             # model params are correctly loaded
             return params
         else:
             # params have not been initialized yet
-            return model.init_weights()
+            return model.init_weights(model.key, model.input_shape)
 
     with mesh:
         logger.info(" Creating state")
@@ -1008,7 +1010,7 @@ def main():
                 else None,
                 out_axis_resources=state_spec,
                 donate_argnums=(0,),
-            )(model.params if model_args.model_name_or_path else None)
+            )(params)
 
         else:
             # load opt_state
@@ -1038,13 +1040,13 @@ def main():
                 ),
                 out_axis_resources=state_spec,
                 donate_argnums=(0, 1),
-            )(model.params, opt_state)
+            )(params, opt_state)
 
             # remove opt_state from CPU
             del opt_state
 
     # free CPU memory
-    del model._params, opt_state_spec, opt_state_shape
+    del params, opt_state_spec, opt_state_shape
 
     # define batch specs
     batch_spec = PartitionSpec("dp")
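As a side note (an illustration, not from the commit), the `donate_argnums=(0,)` / `del params` pattern above hands the host copy of the parameters to the compiled initializer so its buffers can be reused for the output state; a tiny sketch with `jax.jit`, which shares `donate_argnums` with `pjit`:

```python
# minimal sketch of buffer donation; on some backends (e.g. CPU) donation is a no-op
import jax
import jax.numpy as jnp

def to_state(params):
    # stand-in for the real init_state: wrap params into a training-state pytree
    return {"params": params, "step": jnp.zeros((), jnp.int32)}

to_state_jit = jax.jit(to_state, donate_argnums=(0,))
params = {"w": jnp.ones((1024, 1024))}
state = to_state_jit(params)
del params  # drop the host reference; the donated buffer may now back `state`
```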