fix(seq2seq): opt_state from ckpt + limit cache
dev/seq2seq/run_seq2seq_flax.py
CHANGED
@@ -20,10 +20,6 @@ Script adapted from run_summarization_flax.py
 # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
 
 import os
-# set a common huggingface cache folder (used with datasets and transformers) and wandb cache folder (used with artifacts)
-os.environ['HF_HOME'] = '/data/huggingface/'  # required before importing transformers & datasets
-os.environ['WANDB_CACHE_DIR'] = '/data/wandb/'  # required before importing wandb
-
 import logging as pylogging  # To avoid collision with transformers.utils.logging
 import sys
 import time
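The removed lines hard-coded the Hugging Face and Weights & Biases cache locations into the training script; both HF_HOME and WANDB_CACHE_DIR only take effect if they are set before transformers/datasets and wandb are imported. A minimal sketch of getting the same redirection without baking a machine-specific path into the script (the /data root below is a placeholder, not part of this commit):

# Sketch: point the HF and W&B caches at a large volume, but only when it exists.
# Both environment variables are read at import time, so set them first.
import os
from pathlib import Path

_cache_root = Path('/data')  # placeholder path, not from the commit
if _cache_root.exists():
    os.environ.setdefault('HF_HOME', str(_cache_root / 'huggingface'))
    os.environ.setdefault('WANDB_CACHE_DIR', str(_cache_root / 'wandb'))

import datasets      # noqa: E402  -- imported only after the environment is configured
import transformers  # noqa: E402
import wandb         # noqa: E402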
@@ -442,6 +438,7 @@ def main():
         if (Path(artifact_dir) / 'opt_state.msgpack').exists():
             with (Path(artifact_dir) / 'opt_state.msgpack').open('rb') as f:
                 opt_state = from_bytes(state.opt_state, f.read())
+                state.replace(opt_state=opt_state)
 
         # restore steps
         if (Path(artifact_dir) / 'training_state.json').exists():
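For context, flax.serialization.from_bytes(target, data) deserializes into the pytree structure of target, and TrainState.replace() returns a new dataclass instance rather than mutating state in place, so the restored optimizer state has to be kept by reassigning. A minimal sketch of the restore, assuming state is a flax.training.train_state.TrainState and artifact_dir points at the downloaded checkpoint artifact as in the surrounding code:

# Sketch: restore a serialized optimizer state into an existing TrainState.
from pathlib import Path
from flax.serialization import from_bytes

ckpt = Path(artifact_dir) / 'opt_state.msgpack'
if ckpt.exists():
    with ckpt.open('rb') as f:
        # deserialize into the pytree structure of the current opt_state
        opt_state = from_bytes(state.opt_state, f.read())
    # replace() returns a new TrainState; keep the result instead of discarding it
    state = state.replace(opt_state=opt_state)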
@@ -836,6 +833,10 @@ def main():
             artifact.add_file(str(Path(training_args.output_dir) / 'training_state.json'))
             wandb.run.log_artifact(artifact)
 
+            # save some space
+            c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
+            c.cleanup(wandb.util.from_human_size("15GB"))
+
             # save to the hub
             if training_args.push_to_hub:
                 model.save_pretrained(
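get_artifacts_cache() and cleanup() are internal wandb APIs (under wandb.wandb_sdk.wandb_artifacts in the version this commit targets) and have moved in later releases, while wandb.util.from_human_size converts a size string such as "15GB" into bytes. A small defensive sketch, assuming only the calls used in the commit:

# Sketch: cap the local W&B artifacts cache after logging an artifact.
import wandb

def trim_wandb_artifact_cache(limit: str = "15GB") -> None:
    try:
        cache = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
        # remove cached artifact files until the cache fits under the limit
        cache.cleanup(wandb.util.from_human_size(limit))
    except AttributeError:
        # newer wandb releases expose the cache elsewhere; skip rather than crash
        pass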
@@ -866,7 +867,7 @@ def main():
                 # log metrics
                 wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
 
-                if global_step % training_args.eval_steps == 0:
+                if training_args.eval_steps and global_step % training_args.eval_steps == 0:
                     run_evaluation()
 
                 if global_step % data_args.save_model_steps == 0:
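The extra training_args.eval_steps check matters because global_step % 0 raises ZeroDivisionError; treating a falsy eval_steps (0 or None) as "periodic evaluation disabled" keeps the training loop alive. A tiny illustration (values made up):

# Sketch: the guarded form skips evaluation instead of crashing when eval_steps is 0 or None.
for eval_steps in (0, None, 50):
    global_step = 100
    if eval_steps and global_step % eval_steps == 0:
        print(f"eval_steps={eval_steps}: run evaluation")
    else:
        print(f"eval_steps={eval_steps}: skip")  # no ZeroDivisionError for 0/None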