Update weights and scripts
Browse files
- flax_model.msgpack +1 -1
- flax_to_pt.py +3 -1
- opt_state.msgpack +1 -1
- pytorch_model.bin +1 -1
- run_t5.sh +2 -2
- run_t5_mlm_flax_custom_dataset.py +11 -6
- runs/Jul16_09-14-47_t1v-n-0e7426e8-w-0/events.out.tfevents.1626426893.t1v-n-0e7426e8-w-0.21179.3.v2 +3 -0
- runs/Jul16_11-53-22_t1v-n-0e7426e8-w-0/events.out.tfevents.1626436407.t1v-n-0e7426e8-w-0.23523.3.v2 +3 -0
- tf_model.h5 +1 -1
- training_state.json +1 -1
flax_model.msgpack
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 891548548
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4fc61ad7414d32c991fb5e568ee6d64eb92653a6f68e7386a7a1c1fd43973a45
|
3 |
size 891548548
|
flax_to_pt.py
CHANGED
@@ -1,4 +1,6 @@
|
|
1 |
-
from transformers import T5ForConditionalGeneration
|
2 |
|
3 |
pt_model = T5ForConditionalGeneration.from_pretrained(".", from_flax=True)
|
4 |
pt_model.save_pretrained(".")
|
|
|
|
|
|
1 |
+
from transformers import T5ForConditionalGeneration, TFT5ForConditionalGeneration
|
2 |
|
3 |
pt_model = T5ForConditionalGeneration.from_pretrained(".", from_flax=True)
|
4 |
pt_model.save_pretrained(".")
|
5 |
+
tf_model = TFT5ForConditionalGeneration.from_pretrained(".", from_pt=True)
|
6 |
+
tf_model.save_pretrained(".")
|
opt_state.msgpack
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 1985609
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3457cc629735027c0d39cfcc9c4978f8180617df4023786ca1c542c79c466335
|
3 |
size 1985609
|
pytorch_model.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 891650495
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d1b04c56abcc3a5bd4d7e871c7d017f44ab5b75af1c4adcc30c205da5fc5ede1
|
3 |
size 891650495
|
run_t5.sh
CHANGED
@@ -50,9 +50,9 @@ while true; do
|
|
50 |
--per_device_train_batch_size="16" \
|
51 |
--per_device_eval_batch_size="16" \
|
52 |
--dtype="bfloat16" \
|
53 |
-
--learning_rate="1e-
|
54 |
--overwrite_output_dir \
|
55 |
-
--num_train_epochs="
|
56 |
--logging_steps="50" \
|
57 |
--save_steps="500" \
|
58 |
--eval_steps="5000" \
|
|
|
50 |
--per_device_train_batch_size="16" \
|
51 |
--per_device_eval_batch_size="16" \
|
52 |
--dtype="bfloat16" \
|
53 |
+
--learning_rate="1e-3" \
|
54 |
--overwrite_output_dir \
|
55 |
+
--num_train_epochs="1" \
|
56 |
--logging_steps="50" \
|
57 |
--save_steps="500" \
|
58 |
--eval_steps="5000" \
|
run_t5_mlm_flax_custom_dataset.py
CHANGED
@@ -583,7 +583,7 @@ if __name__ == "__main__":
|
|
583 |
|
584 |
return train, val
|
585 |
|
586 |
-
train, val = train_val_files()
|
587 |
|
588 |
load_grouped = True
|
589 |
|
@@ -649,7 +649,7 @@ if __name__ == "__main__":
|
|
649 |
logger.info("Loading tokenized and grouped dataset")
|
650 |
tokenized_datasets = DatasetDict.load_from_disk("/home/yeb/grouped_datasets")
|
651 |
logger.info("Setting max validation examples to 500")
|
652 |
-
tokenized_datasets['validation'] = tokenized_datasets['validation'].select(range(
|
653 |
else:
|
654 |
if training_args.do_train:
|
655 |
column_names = datasets["train"].column_names
|
@@ -906,11 +906,16 @@ if __name__ == "__main__":
|
|
906 |
for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
|
907 |
cur_step = epoch * (num_train_samples // train_batch_size) + step
|
908 |
# skip to the step from which we are resuming
|
909 |
-
|
910 |
-
|
911 |
|
912 |
samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
|
913 |
-
|
|
|
|
|
|
|
|
|
|
|
914 |
|
915 |
# Model forward
|
916 |
model_inputs = shard(model_inputs.data)
|
@@ -926,7 +931,7 @@ if __name__ == "__main__":
|
|
926 |
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
927 |
|
928 |
epochs.write(
|
929 |
-
f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
|
930 |
)
|
931 |
|
932 |
train_metrics = []
|
|
|
583 |
|
584 |
return train, val
|
585 |
|
586 |
+
# train, val = train_val_files()
|
587 |
|
588 |
load_grouped = True
|
589 |
|
|
|
649 |
logger.info("Loading tokenized and grouped dataset")
|
650 |
tokenized_datasets = DatasetDict.load_from_disk("/home/yeb/grouped_datasets")
|
651 |
logger.info("Setting max validation examples to 500")
|
652 |
+
tokenized_datasets['validation'] = tokenized_datasets['validation'].select(range(1000))
|
653 |
else:
|
654 |
if training_args.do_train:
|
655 |
column_names = datasets["train"].column_names
|
|
|
906 |
for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
|
907 |
cur_step = epoch * (num_train_samples // train_batch_size) + step
|
908 |
# skip to the step from which we are resuming
|
909 |
+
if cur_step < resume_step:
|
910 |
+
continue
|
911 |
|
912 |
samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
|
913 |
+
try:
|
914 |
+
model_inputs = data_collator(samples)
|
915 |
+
except ValueError as e:
|
916 |
+
logger.warning(str(e))
|
917 |
+
logger.info(f"Continuing with the next batch")
|
918 |
+
continue
|
919 |
|
920 |
# Model forward
|
921 |
model_inputs = shard(model_inputs.data)
|
|
|
931 |
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
932 |
|
933 |
epochs.write(
|
934 |
+
f"Step... ({cur_step} ({cur_step+resume_step}| Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
|
935 |
)
|
936 |
|
937 |
train_metrics = []
|
runs/Jul16_09-14-47_t1v-n-0e7426e8-w-0/events.out.tfevents.1626426893.t1v-n-0e7426e8-w-0.21179.3.v2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b6bd1363991ec767bd11c50387d5c37ce29cd84fba0daa9bd9fbbe1bc246a5d6
|
3 |
+
size 865193
|
runs/Jul16_11-53-22_t1v-n-0e7426e8-w-0/events.out.tfevents.1626436407.t1v-n-0e7426e8-w-0.23523.3.v2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6246ff2cb2ae46428e8c6faadbbbedefd5e718271f8db555ce4bc45d1f5a8d0e
|
3 |
+
size 40
|
tf_model.h5
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 892067416
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7ca091f719f88d0c460cb709fead1521082e46ac9b1d9873a06e65bb0ca2d94c
|
3 |
size 892067416
|
training_state.json
CHANGED
@@ -1 +1 @@
|
|
1 |
-
{"step":
|
|
|
1 |
+
{"step": 54001}
|