feat: support LR offset (#174)
tools/train/train.py (+17 -16)
@@ -119,7 +119,7 @@ class ModelArguments:
         ), "Restoring state only available with W&B artifact reference"
 
     def get_metadata(self):
-        if self.
+        if ":" in self.model_name_or_path:
             if jax.process_index() == 0:
                 artifact = wandb.run.use_artifact(self.model_name_or_path)
             else:
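The new condition keys off the `:` in `model_name_or_path`: a W&B artifact reference carries a `:<version>` or `:<alias>` suffix, while a plain local path or model id does not, so run metadata is only fetched when an artifact is actually referenced. A minimal sketch of that heuristic (the reference strings below are made up for illustration):

```python
def is_wandb_artifact_reference(model_name_or_path: str) -> bool:
    # W&B artifact references carry a version or alias after a colon,
    # e.g. "entity/project/model-1abcd234:v4" or "model-1abcd234:latest".
    return ":" in model_name_or_path


# Hypothetical examples, not taken from the repository.
assert is_wandb_artifact_reference("dalle-mini/dalle-mini/model-1abcd234:latest")
assert not is_wandb_artifact_reference("./output/checkpoint-5000")
```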
@@ -413,11 +413,9 @@ class TrainingArguments:
             "help": "Whether to use staircase or continuous learning rate when using exponential decay."
         },
     )
-
-        default=
-        metadata={
-            "help": "Whether to offset the learning rate function with current step when resuming a run."
-        },
+    lr_offset: int = field(
+        default=0,
+        metadata={"help": "Number of steps to offset learning rate and keep it at 0."},
     )
     logging_steps: int = field(
         default=40, metadata={"help": "Log every X updates steps."}
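`lr_offset` replaces the old resume-time offset option with an explicit step count: the number of initial steps during which the learning rate is kept at 0. Since the training arguments are plain dataclass fields, the new option surfaces as a regular CLI flag. A self-contained sketch of that behaviour (the `MiniTrainingArguments` dataclass and the use of `transformers.HfArgumentParser` are illustrative assumptions, not the script's actual parsing code):

```python
# Stripped-down sketch: an int field with a default becomes a --lr_offset flag.
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class MiniTrainingArguments:
    lr_offset: int = field(
        default=0,
        metadata={"help": "Number of steps to offset learning rate and keep it at 0."},
    )


(args,) = HfArgumentParser(MiniTrainingArguments).parse_args_into_dataclasses(
    ["--lr_offset", "5000"]
)
assert args.lr_offset == 5000
```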
@@ -796,14 +794,14 @@ def main():
             end_value=training_args.learning_rate,
             transition_steps=training_args.warmup_steps + 1,  # ensure not 0
         )
-        # offset step when resuming
         last_boundary = training_args.warmup_steps
-
+        # offset step when resuming
+        if training_args.lr_offset:
             warmup_fn = optax.join_schedules(
                 schedules=[optax.constant_schedule(0.0), warmup_fn],
-            boundaries=[
+                boundaries=[training_args.lr_offset],
             )
-        last_boundary +=
+            last_boundary += training_args.lr_offset
         if training_args.lr_decay is None:
             return warmup_fn
         elif training_args.lr_decay == "linear":
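The schedule is now only wrapped when `lr_offset` is non-zero: `optax.join_schedules` holds the learning rate at 0 for the first `lr_offset` steps, then evaluates the warmup schedule at the offset-corrected step count, and `last_boundary` is shifted so any decay schedule joined later still starts right after warmup. A runnable sketch with made-up values (`learning_rate`, `warmup_steps`, and `lr_offset` below are illustrative, not the script's defaults):

```python
import optax

learning_rate, warmup_steps, lr_offset = 1e-4, 100, 50

warmup_fn = optax.linear_schedule(
    init_value=0.0,
    end_value=learning_rate,
    transition_steps=warmup_steps + 1,  # ensure not 0
)
last_boundary = warmup_steps
if lr_offset:
    # join_schedules evaluates the second schedule at (step - lr_offset),
    # so warmup effectively starts once the offset window is over.
    warmup_fn = optax.join_schedules(
        schedules=[optax.constant_schedule(0.0), warmup_fn],
        boundaries=[lr_offset],
    )
    last_boundary += lr_offset

print(warmup_fn(0))                         # 0.0: inside the offset window
print(warmup_fn(lr_offset + 1))             # warmup has just begun
print(warmup_fn(lr_offset + warmup_steps))  # ~learning_rate: warmup finished
```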
@@ -1005,6 +1003,14 @@ def main():
 
     with mesh:
         logger.info(" Creating state")
+
+        # restore metadata
+        attr_state = {}
+        keys = ["train_time", "train_samples"]
+        if model_args.restore_state:
+            keys += ["step", "epoch"]
+        attr_state = {k: v for k, v in model_metadata.items() if k in keys}
+
         if not model_args.restore_state:
 
             def init_state(params):
@@ -1013,6 +1019,7 @@ def main():
                 tx=optimizer,
                 params=maybe_init_params(params),
                 dropout_rng=dropout_rng,
+                **attr_state,
             )
 
         state = pjit(
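Metadata restoration now happens once, before the state is created: `train_time` and `train_samples` are carried over whenever they exist in the model metadata, `step` and `epoch` are only restored together with the optimizer state, and the resulting dict is splatted into the state via `**attr_state`. A self-contained sketch of that filtering, with a stand-in dataclass and made-up metadata in place of the real `TrainState` and W&B artifact metadata:

```python
# Stand-in container; the real TrainState accepts these counters via **attr_state.
from dataclasses import dataclass


@dataclass
class FakeTrainState:
    step: int = 0
    epoch: int = 0
    train_time: float = 0.0
    train_samples: int = 0


# Made-up metadata, as it might come from a previous run's artifact.
model_metadata = {"step": 1200, "epoch": 3, "train_time": 86_400.0, "train_samples": 4_000_000}
restore_state = False  # e.g. resuming weights only, without the optimizer state

keys = ["train_time", "train_samples"]
if restore_state:
    keys += ["step", "epoch"]
attr_state = {k: v for k, v in model_metadata.items() if k in keys}

state = FakeTrainState(**attr_state)
assert state.step == 0                    # training step restarts
assert state.train_samples == 4_000_000   # sample counter keeps accumulating
```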
@@ -1028,12 +1035,6 @@ def main():
             # load opt_state
             opt_state = from_bytes(opt_state_shape, model_args.get_opt_state())
 
-            # restore other attributes
-            attr_state = {
-                k: model_metadata[k]
-                for k in ["step", "epoch", "train_time", "train_samples"]
-            }
-
             def restore_state(params, opt_state):
                 return TrainState(
                     apply_fn=model.__call__,