feat(train): use new HF _do_init api
- src/dalle_mini/model/modeling.py +4 -3
- tools/train/train.py +22 -20
src/dalle_mini/model/modeling.py
CHANGED
@@ -1330,10 +1330,11 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
 
     config_class = DalleBartConfig
 
-    @property
-    def num_params(self):
+    def num_params(self, params=None):
+        if params is None:
+            params = self.params
         num_params = jax.tree_map(
-            lambda param: param.size, flatten_dict(unfreeze(self.params))
+            lambda param: param.size, flatten_dict(unfreeze(params))
         ).values()
         return sum(list(num_params))
 
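The reworked num_params only needs each leaf's .size, so it can count parameters from real weights or from a shape-only tree such as params_shape_tree. A minimal standalone sketch of the same counting logic (not the repo's code; the toy params below are made up):

    import jax
    import jax.numpy as jnp
    from flax.core.frozen_dict import freeze, unfreeze
    from flax.traverse_util import flatten_dict

    def num_params(params):
        # works for concrete arrays and for jax.ShapeDtypeStruct leaves alike,
        # since both expose a .size attribute
        sizes = jax.tree_util.tree_map(
            lambda leaf: leaf.size, flatten_dict(unfreeze(params))
        ).values()
        return sum(sizes)

    params = freeze({"dense": {"kernel": jnp.zeros((4, 8)), "bias": jnp.zeros((8,))}})
    params_shape = jax.eval_shape(lambda p: p, params)  # abstract shapes, no memory
    assert num_params(params) == num_params(params_shape) == 40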
tools/train/train.py
CHANGED
@@ -672,12 +672,12 @@ def main():
 
     # Load or create new model
     if model_args.model_name_or_path:
-        model = DalleBart.from_pretrained(
+        model, params = DalleBart.from_pretrained(
             model_args.model_name_or_path,
             config=config,
             seed=training_args.seed_model,
             dtype=getattr(jnp, model_args.dtype),
-
+            _do_init=False,  # we overwrite them with loaded checkpoint
             gradient_checkpointing=training_args.gradient_checkpointing,
         )
     else:
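For reference, the upstream behaviour this hunk relies on: with _do_init=False, Hugging Face Flax models skip random weight initialization and from_pretrained returns the checkpoint weights separately instead of storing them on the model. A hedged sketch with a generic HF model (the checkpoint name is only an example, unrelated to this commit):

    import jax
    from transformers import FlaxBertModel

    # returns (model, params): the model holds only a shape/dtype tree,
    # the checkpoint weights come back as a separate pytree
    model, params = FlaxBertModel.from_pretrained("bert-base-uncased", _do_init=False)

    # no weights are materialized on the model itself
    print(jax.tree_util.tree_map(lambda x: x.shape, model.params_shape_tree))

    # weights must now be passed explicitly, e.g. model(input_ids, params=params)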
@@ -685,17 +685,19 @@ def main():
             config,
             seed=training_args.seed_model,
             dtype=getattr(jnp, model_args.dtype),
-
+            _do_init=False,
         )
+        params = None
+    params_shape = model.params_shape_tree
 
     # get model metadata
     model_metadata = model_args.get_metadata()
 
     # get PartitionSpec for model params (required to be a dict)
-    param_spec = set_partitions(
-
-
-
+    param_spec = set_partitions(params_shape, model.config.use_scan)
+    params_shape = freeze(params_shape)
+    if params is not None:
+        params = freeze(params)
 
     # Load tokenizer
     tokenizer = DalleBartTokenizer.from_pretrained(
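params_shape_tree is the shape-only stand-in for the weights that the rest of the script now partitions and counts. An illustrative sketch of how such a tree can be produced with jax.eval_shape (toy flax module, not the repo's model):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    module = nn.Dense(features=16)
    init_fn = lambda rng: module.init(rng, jnp.ones((1, 8)))

    # runs init abstractly: returns ShapeDtypeStruct leaves, allocates no weights
    params_shape = jax.eval_shape(init_fn, jax.random.PRNGKey(0))
    print(jax.tree_util.tree_map(lambda x: (x.shape, x.dtype), params_shape))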
@@ -736,7 +738,7 @@ def main():
     num_train_steps = (
         steps_per_epoch * num_epochs if steps_per_epoch is not None else None
     )
-    num_params = model.num_params
+    num_params = model.num_params(params_shape)
 
     logger.info("***** Running training *****")
     logger.info(f" Num examples = {len_train_dataset}")
@@ -875,7 +877,7 @@ def main():
 
         optimizer = {}
         opt_fn = {}
-        for k, p in split_params(
+        for k, p in split_params(params_shape).items():
             if "scanned" in k:
                 p = jax.eval_shape(lambda x: jax.tree_map(lambda y: y[0], x), p)
             optimizer[k] = opt.init(p)
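The jax.eval_shape(lambda x: jax.tree_map(lambda y: y[0], x), p) pattern used for scanned layers strips the leading scan axis purely at the shape level, without touching real memory. A tiny sketch with a made-up leaf:

    import jax
    import jax.numpy as jnp

    scanned = {"layers": {"kernel": jax.ShapeDtypeStruct((12, 64, 64), jnp.float32)}}
    unscanned = jax.eval_shape(
        lambda x: jax.tree_util.tree_map(lambda y: y[0], x), scanned
    )
    print(jax.tree_util.tree_map(lambda x: x.shape, unscanned))
    # {'layers': {'kernel': (64, 64)}}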
@@ -891,7 +893,7 @@ def main():
             b2=training_args.beta2,
             eps=training_args.adam_epsilon,
         )
-        optimizer = {k: optimizer for k in split_params(
+        optimizer = {k: optimizer for k in split_params(params_shape)}
 
     elif training_args.optim == "adafactor":
         # We use the default parameters here to initialize adafactor,
@@ -900,13 +902,13 @@ def main():
             learning_rate=learning_rate_fn,
             clipping_threshold=training_args.max_grad_norm,
         )
-        optimizer = {k: optimizer for k in split_params(
+        optimizer = {k: optimizer for k in split_params(params_shape)}
 
     # get PartitionSpec for optimizer state
     def get_opt_state_spec_and_shape():
         # get opt_state shape without actual init
         opt_state_shape = {}
-        for k, p in split_params(
+        for k, p in split_params(params_shape).items():
             if "scanned" not in k:
                 opt_state_shape[k] = jax.eval_shape(optimizer[k].init, p)
             else:
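jax.eval_shape(optimizer[k].init, p) computes the optimizer state's structure and shapes from the parameter shape tree alone, so no optimizer buffers are allocated before pjit lays them out. A standalone sketch (optax adamw, made-up shapes):

    import jax
    import jax.numpy as jnp
    import optax

    opt = optax.adamw(learning_rate=1e-3)
    params_shape = {"w": jax.ShapeDtypeStruct((1024, 1024), jnp.float32)}
    opt_state_shape = jax.eval_shape(opt.init, params_shape)
    print(jax.tree_util.tree_map(lambda x: x.shape, opt_state_shape))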
@@ -914,7 +916,7 @@ def main():
 
         if training_args.optim == "adafactor":
             # factorized state must be replicated (rank different than params)
-            opt_state_spec = {k: None for k in split_params(
+            opt_state_spec = {k: None for k in split_params(params_shape)}
 
         elif training_args.optim in ["adam", "distributed_shampoo"]:
 
@@ -926,9 +928,9 @@ def main():
                     # other variables such as count
                     return None
 
-            split_spec = split_params(set_partitions(
+            split_spec = split_params(set_partitions(params_shape, False))
             opt_state_spec = {}
-            for k, p in split_params(
+            for k, p in split_params(params_shape).items():
                 if "scanned" in k:
                     p = jax.eval_shape(lambda x: jax.tree_map(lambda y: y[0], x), p)
                 if training_args.optim == "adam":
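Why the adam branch can reuse the params' partition specs while count-style leaves get None: optax's adam state stores the first/second moments with exactly the params' tree structure, plus unpartitioned scalars. A small sketch:

    import jax
    import jax.numpy as jnp
    import optax

    params = {"w": jnp.zeros((8, 4))}
    opt_state_shape = jax.eval_shape(optax.adam(1e-3).init, params)
    adam_state = opt_state_shape[0]  # ScaleByAdamState(count, mu, nu)
    print(adam_state.mu["w"].shape, adam_state.nu["w"].shape, adam_state.count.shape)
    # mu/nu can follow the params' PartitionSpec; the scalar count stays replicated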
@@ -982,12 +984,12 @@ def main():
 
     # init params if not available yet
     def maybe_init_params(params):
-        if
+        if params is not None:
             # model params are correctly loaded
             return params
         else:
             # params have not been initialized yet
-            return model.init_weights()
+            return model.init_weights(model.key, model.input_shape)
 
     with mesh:
         logger.info(" Creating state")
@@ -1008,7 +1010,7 @@ def main():
                 else None,
                 out_axis_resources=state_spec,
                 donate_argnums=(0,),
-            )(
+            )(params)
 
         else:
             # load opt_state
@@ -1038,13 +1040,13 @@ def main():
                 ),
                 out_axis_resources=state_spec,
                 donate_argnums=(0, 1),
-            )(
+            )(params, opt_state)
 
             # remove opt_state from CPU
             del opt_state
 
     # free CPU memory
-    del
+    del params, opt_state_spec, opt_state_shape
 
     # define batch specs
     batch_spec = PartitionSpec("dp")