Commit
·
4e5c598
1
Parent(s):
fb5ea5a
fix training script
Browse files
- added_tokens.json +0 -1
- run.sh +2 -2
- run_speech_recognition_ctc.py +13 -5
- vocab.json +1 -1
added_tokens.json
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
{"<s>": 33, "</s>": 34}
|
|
|
|
run.sh
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
python run_speech_recognition_ctc.py \
|
2 |
-
--dataset_name="mozilla-foundation/
|
3 |
--model_name_or_path="KBLab/wav2vec2-large-voxrex" \
|
4 |
--dataset_config_name="sv-SE,distant_channel" \
|
5 |
-
--train_split_name="
|
6 |
--eval_split_name="test,None" \
|
7 |
--output_dir="./" \
|
8 |
--overwrite_output_dir \
|
|
|
1 |
python run_speech_recognition_ctc.py \
|
2 |
+
--dataset_name="mozilla-foundation/common_voice_8_0,marinone94/nst_sv" \
|
3 |
--model_name_or_path="KBLab/wav2vec2-large-voxrex" \
|
4 |
--dataset_config_name="sv-SE,distant_channel" \
|
5 |
+
--train_split_name="train+validation,train" \
|
6 |
--eval_split_name="test,None" \
|
7 |
--output_dir="./" \
|
8 |
--overwrite_output_dir \
|
run_speech_recognition_ctc.py
CHANGED
@@ -28,6 +28,7 @@ from typing import Dict, List, Optional, Union
|
|
28 |
|
29 |
import datasets
|
30 |
import numpy as np
|
|
|
31 |
import torch
|
32 |
import wandb
|
33 |
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_metric
|
@@ -376,7 +377,7 @@ def main():
|
|
376 |
wandb.login()
|
377 |
training_args.report_to = ["wandb"]
|
378 |
training_args.run_name = run_name
|
379 |
-
wandb.init()
|
380 |
except:
|
381 |
pass
|
382 |
|
@@ -480,6 +481,11 @@ def main():
|
|
480 |
other_columns_train = [col for col in raw_datasets["train"].column_names if col not in min_columns_train]
|
481 |
raw_datasets["train"].remove_columns(other_columns_train)
|
482 |
|
|
|
|
|
|
|
|
|
|
|
483 |
if training_args.do_eval:
|
484 |
# Multiple datasets might need to be loaded from HF
|
485 |
# It assumes they all follow the common voice format
|
@@ -520,6 +526,11 @@ def main():
|
|
520 |
other_columns_eval = [col for col in raw_datasets["eval"].column_names if col not in min_columns_eval]
|
521 |
raw_datasets["eval"].remove_columns(other_columns_eval)
|
522 |
|
|
|
|
|
|
|
|
|
|
|
523 |
# 2. We remove some special characters from the datasets
|
524 |
# that make training complicated and do not help in transcribing the speech
|
525 |
# E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
|
@@ -755,15 +766,12 @@ def main():
|
|
755 |
if data_args.dataset_seed is not None:
|
756 |
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(seed=data_args.dataset_seed)
|
757 |
|
758 |
-
# Log sample of datasets
|
759 |
pd_train = vectorized_datasets["train"].select(range(10)).to_pandas()
|
760 |
pd_eval = vectorized_datasets["eval"].select(range(10)).to_pandas()
|
761 |
# wandb.log({"train_sample": pd_train})
|
762 |
# wandb.log({"eval_sample": pd_eval})
|
763 |
|
764 |
-
print(pd_train)
|
765 |
-
print(pd_eval)
|
766 |
-
|
767 |
# for large datasets it is advised to run the preprocessing on a
|
768 |
# single machine first with ``args.preprocessing_only`` since there will mostly likely
|
769 |
# be a timeout when running the script in distributed mode.
|
|
|
28 |
|
29 |
import datasets
|
30 |
import numpy as np
|
31 |
+
import pandas as pd
|
32 |
import torch
|
33 |
import wandb
|
34 |
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_metric
|
|
|
377 |
wandb.login()
|
378 |
training_args.report_to = ["wandb"]
|
379 |
training_args.run_name = run_name
|
380 |
+
# wandb.init()
|
381 |
except:
|
382 |
pass
|
383 |
|
|
|
481 |
other_columns_train = [col for col in raw_datasets["train"].column_names if col not in min_columns_train]
|
482 |
raw_datasets["train"].remove_columns(other_columns_train)
|
483 |
|
484 |
+
# pd_train_head = raw_datasets["train"].select(range(10)).to_pandas()
|
485 |
+
# pd_train_tail = raw_datasets["train"].select(range(raw_datasets["train"].num_rows-10, raw_datasets["train"].num_rows)).to_pandas()
|
486 |
+
# pd_train = pd.concat([pd_train_head, pd_train_tail])
|
487 |
+
# print(pd_train["audio"])
|
488 |
+
|
489 |
if training_args.do_eval:
|
490 |
# Multiple datasets might need to be loaded from HF
|
491 |
# It assumes they all follow the common voice format
|
|
|
526 |
other_columns_eval = [col for col in raw_datasets["eval"].column_names if col not in min_columns_eval]
|
527 |
raw_datasets["eval"].remove_columns(other_columns_eval)
|
528 |
|
529 |
+
# pd_eval_head = raw_datasets["eval"].select(range(10)).to_pandas()
|
530 |
+
# pd_eval_tail = raw_datasets["eval"].select(range(raw_datasets["eval"].num_rows-10, raw_datasets["eval"].num_rows)).to_pandas()
|
531 |
+
# pd_eval = pd.concat([pd_eval_head, pd_eval_tail])
|
532 |
+
# print(pd_eval["audio"])
|
533 |
+
|
534 |
# 2. We remove some special characters from the datasets
|
535 |
# that make training complicated and do not help in transcribing the speech
|
536 |
# E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
|
|
|
766 |
if data_args.dataset_seed is not None:
|
767 |
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(seed=data_args.dataset_seed)
|
768 |
|
769 |
+
# TODO: Log sample of datasets in the right way (see wandb docs)
|
770 |
pd_train = vectorized_datasets["train"].select(range(10)).to_pandas()
|
771 |
pd_eval = vectorized_datasets["eval"].select(range(10)).to_pandas()
|
772 |
# wandb.log({"train_sample": pd_train})
|
773 |
# wandb.log({"eval_sample": pd_eval})
|
774 |
|
|
|
|
|
|
|
775 |
# for large datasets it is advised to run the preprocessing on a
|
776 |
# single machine first with ``args.preprocessing_only`` since there will mostly likely
|
777 |
# be a timeout when running the script in distributed mode.
|
vocab.json
CHANGED
@@ -1 +1 @@
|
|
1 |
-
{"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6, "g": 7, "h": 8, "i": 9, "j": 10, "k": 11, "l": 12, "m": 13, "n": 14, "o": 15, "p": 16, "q": 17, "r": 18, "s": 19, "t": 20, "u": 21, "v": 22, "w": 23, "x": 24, "y": 25, "z": 26, "
|
|
|
1 |
+
{"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6, "g": 7, "h": 8, "i": 9, "j": 10, "k": 11, "l": 12, "m": 13, "n": 14, "o": 15, "p": 16, "q": 17, "r": 18, "s": 19, "t": 20, "u": 21, "v": 22, "w": 23, "x": 24, "y": 25, "z": 26, "\u00e4": 27, "\u00e5": 28, "\u00f6": 29, "|": 0, "[UNK]": 30, "[PAD]": 31}
|