marinone94 committed
Commit 3870354
Parent(s): 7d30caa

add more debugging, shuffle dataset max train

Changed files:
- join_datasets_asr_ctc.py +20 -4
- join_datasets_asr_ctc_run.sh +1 -0
- vocab.json +1 -1
join_datasets_asr_ctc.py
CHANGED
@@ -479,6 +479,11 @@ def load_raw_datasets(training_args, data_args):
         f"{', '.join(raw_datasets['train'].column_names)}."
     )

+    # If dataset_seed is set, shuffle train
+    if data_args.dataset_seed is not None:
+        raw_datasets["train"] = raw_datasets["train"].shuffle(seed=data_args.dataset_seed)
+
+
     if data_args.max_train_samples is not None:
         raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
     other_columns_train = [col for col in raw_datasets["train"].column_names if col not in min_columns_train]
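Note: the shuffle now runs before the max_train_samples cap. With 🤗 Datasets, select(range(n)) keeps the first n rows, so shuffling first makes a capped run train on a random (but seed-reproducible) subset instead of the same leading examples every time. A minimal standalone sketch on toy data (not from this repo) illustrating the difference:

from datasets import Dataset

ds = Dataset.from_dict({"text": [f"utterance {i}" for i in range(100)]})

# Without shuffling, the cap always keeps rows 0..4
first_five = ds.select(range(5))

# Shuffling with a fixed seed first yields a reproducible random subset
random_five = ds.shuffle(seed=42).select(range(5))

print(first_five["text"])
print(random_five["text"])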
@@ -771,10 +776,6 @@ def preprocess_audio_datasets(raw_datasets, tokenizer, feature_extractor, traini
         input_columns=["input_length"],
     )

-    # If dataset_seed is set, shuffle train
-    if data_args.dataset_seed is not None:
-        vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(seed=data_args.dataset_seed)
-
     # TODO: Log sample of datasets in the right way (see wandb docs)
     pd_train = vectorized_datasets["train"].select(range(10)).to_pandas()
     pd_eval = vectorized_datasets["eval"].select(range(10)).to_pandas()
@@ -872,6 +873,21 @@ def main():
         data_args=data_args
     )

+    # Inspect datasets
+    logger.info("Inspect datasets")
+    avg = []
+    std = []
+    import numpy as np
+    for input_ in vectorized_datasets["train"][:10]["input_values"]:
+        avg.append(np.average(input_))
+        std.append(np.std(input_))
+    for input_ in vectorized_datasets["eval"][:10]["input_values"]:
+        avg.append(np.average(input_))
+        std.append(np.std(input_))
+
+    logger.info(f"Average values: {avg}")
+    logger.info(f"Std values: {std}")
+
     # 7. Next, we can prepare the training.
     # Let's use word error rate (WER) as our evaluation metric,
     # instantiate a data collator and the trainer
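The new inspection block is a quick sanity check on feature extraction: Wav2Vec2-style feature extractors normalize each waveform to zero mean and unit variance when do_normalize=True, so the logged averages should sit near 0 and the standard deviations near 1. A hedged sketch of the same check factored into a reusable helper (the dataset and logger names are assumed from the diff):

import numpy as np

def inspect_input_values(dataset, n=10):
    # Return (mean, std) for the first n examples' input_values;
    # dataset[:n] yields a dict of columns, as in the commit above.
    stats = []
    for input_ in dataset[:n]["input_values"]:
        arr = np.asarray(input_)
        stats.append((float(np.average(arr)), float(np.std(arr))))
    return stats

# Usage, mirroring the commit:
# for mean, std in inspect_input_values(vectorized_datasets["train"]):
#     logger.info(f"mean={mean:.4f} std={std:.4f}")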
join_datasets_asr_ctc_run.sh
CHANGED
@@ -30,6 +30,7 @@ python join_datasets_asr_ctc.py \
     --mask_time_length="10" \
     --mask_feature_prob="0.25" \
     --mask_feature_length="64" \
+    --dataset_seed="42" \
     --gradient_checkpointing \
     --use_auth_token \
     --preprocessing_only \
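The new --dataset_seed flag presumably maps to a field on the script's data-arguments dataclass, parsed with HfArgumentParser as in other Hugging Face example scripts. The actual definition lives in join_datasets_asr_ctc.py and is not shown in this diff, so the following is only a sketch of the conventional pattern with assumed names:

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class DataTrainingArguments:
    # Hypothetical field matching the flag added above
    dataset_seed: Optional[int] = field(
        default=None,
        metadata={"help": "Seed for shuffling the train split before max_train_samples is applied."},
    )

parser = HfArgumentParser(DataTrainingArguments)
(data_args,) = parser.parse_args_into_dataclasses()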
vocab.json
CHANGED
@@ -1 +1 @@
-{"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6, "g": 7, "h": 8, "i": 9, "j": 10, "k": 11, "l": 12, "m": 13, "n": 14, "o": 15, "p": 16, "q": 17, "r": 18, "s": 19, "t": 20, "u": 21, "v": 22, "w": 23, "x": 24, "y": 25, "z": 26, "
+{"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6, "g": 7, "h": 8, "i": 9, "j": 10, "k": 11, "l": 12, "m": 13, "n": 14, "o": 15, "p": 16, "q": 17, "r": 18, "s": 19, "t": 20, "u": 21, "v": 22, "w": 23, "x": 24, "y": 25, "z": 26, "\u00e4": 27, "\u00e5": 28, "\u00f6": 29, "|": 0, "[UNK]": 30, "[PAD]": 31}
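The updated vocab covers the Swedish letters ä (\u00e4), å (\u00e5) and ö (\u00f6), maps the word delimiter "|" to id 0, and ends with [UNK] and [PAD], which is the layout Wav2Vec2CTCTokenizer expects for CTC fine-tuning. A short sketch (not part of the commit) of loading this file:

from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer(
    "vocab.json",
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="|",
)

# Spaces are encoded as the "|" delimiter token (id 0)
print(tokenizer("hej då").input_ids)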