upload train run_3 artifacts
Browse files- run_3/logs/train_20221221-180439.log +0 -0
- run_3/readme.md +24 -0
- run_3/src/bash_runners/eval_cv11_test.sh +12 -0
- run_3/src/bash_runners/eval_fleurs_test.sh +12 -0
- run_3/src/bash_runners/run_base.sh +44 -0
- run_3/src/bash_runners/run_small.sh +47 -0
- run_3/src/bash_runners/run_tiny_debug.sh +44 -0
- run_3/src/belarusian_text_normalizer.py +41 -0
- run_3/src/custom_trainer.py +53 -0
- run_3/src/readme.md +244 -0
- run_3/src/requirements.txt +9 -0
- run_3/src/run_eval_whisper_streaming.py +219 -0
- run_3/src/run_speech_recognition_seq2seq_streaming.py +774 -0
- run_3/src/setup_env.sh +25 -0
- run_3/tensorboard_logs/Dec21_18-04-41_129-146-110-116/1671647715.531084/events.out.tfevents.1671647715.129-146-110-116.757634.1 +3 -0
- run_3/tensorboard_logs/Dec21_18-04-41_129-146-110-116/events.out.tfevents.1671647715.129-146-110-116.757634.0 +3 -0
- run_3/tensorboard_logs/Dec21_18-04-41_129-146-110-116/events.out.tfevents.1671730045.129-146-110-116.757634.2 +3 -0
- run_3/trainer_state.json +805 -0
run_3/logs/train_20221221-180439.log
ADDED
The diff for this file is too large to render.
See raw diff
|
|
run_3/readme.md
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Fine-tuning run 3
|
2 |
+
|
3 |
+
Tried to improve model fine-tuned during run 1.
|
4 |
+
|
5 |
+
Checkpoint used: checkpoint-12000
|
6 |
+
|
7 |
+
* Trained for 6000 steps
|
8 |
+
* Used custom Learning Rate scheduler initialized in: `custom_trainer.Seq2SeqTrainerCustomLinearScheduler`:
|
9 |
+
* `--learning_rate="3e-5"`
|
10 |
+
* `--learning_rate_end="1e-5"`
|
11 |
+
* no warmup was used
|
12 |
+
* no WER improvements compared to checkpoint-12000 of run 1
|
13 |
+
* using `seed=43`
|
14 |
+
* do not upload checkpoints from that run
|
15 |
+
* uploading src, logs, tensorboard logs, trainer_state
|
16 |
+
|
17 |
+
## Advices
|
18 |
+
* I guess, we need to use warmup when resuming training and increasing LR compared to the last LR in previous run
|
19 |
+
* need to set number of steps > 6000. because model improved WER veeery slowly
|
20 |
+
* can use original Mozilla Common Voice dataset instead of a HuggingFace's one.<br>
|
21 |
+
the reason is that original contains multiple voicings of same sentence -
|
22 |
+
so there is at least twice as more data.<br>
|
23 |
+
to use this "additional" data, train, validation, test sets need to be enlarged using `validated` set -
|
24 |
+
the one that is absent in HuggingFace's CV11 dataset
|
run_3/src/bash_runners/eval_cv11_test.sh
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
python src/run_eval_whisper_streaming.py \
|
2 |
+
--model_id="$1" \
|
3 |
+
--language="be" \
|
4 |
+
--dataset="mozilla-foundation/common_voice_11_0" \
|
5 |
+
--config="be" \
|
6 |
+
--split="test" \
|
7 |
+
--text_column="sentence" \
|
8 |
+
--device="0" \
|
9 |
+
--batch_size="32" \
|
10 |
+
--streaming="True" \
|
11 |
+
--push_to_hub="True" \
|
12 |
+
--save_predictions="True"
|
run_3/src/bash_runners/eval_fleurs_test.sh
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
python src/run_eval_whisper_streaming.py \
|
2 |
+
--model_id="$1" \
|
3 |
+
--language="be" \
|
4 |
+
--dataset="google/fleurs" \
|
5 |
+
--config="be_by" \
|
6 |
+
--split="test" \
|
7 |
+
--text_column="raw_transcription" \
|
8 |
+
--device="0" \
|
9 |
+
--batch_size="16" \
|
10 |
+
--streaming="True" \
|
11 |
+
--push_to_hub="True" \
|
12 |
+
--save_predictions="True"
|
run_3/src/bash_runners/run_base.sh
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
python src/run_speech_recognition_seq2seq_streaming.py \
|
2 |
+
--model_name_or_path="openai/whisper-base" \
|
3 |
+
--dataset_name="mozilla-foundation/common_voice_11_0" \
|
4 |
+
--dataset_config_name="be" \
|
5 |
+
--language="be" \
|
6 |
+
--train_split_name="train" \
|
7 |
+
--eval_split_name="validation" \
|
8 |
+
--model_index_name="Whisper Base Belarusian" \
|
9 |
+
\
|
10 |
+
--max_steps="6000" \
|
11 |
+
--output_dir="./" \
|
12 |
+
--per_device_train_batch_size="64" \
|
13 |
+
--per_device_eval_batch_size="32" \
|
14 |
+
--logging_steps="50" \
|
15 |
+
--logging_first_step \
|
16 |
+
--learning_rate="1e-4" \
|
17 |
+
--warmup_steps="500" \
|
18 |
+
--evaluation_strategy="steps" \
|
19 |
+
--eval_steps="1000" \
|
20 |
+
--save_strategy="steps" \
|
21 |
+
--save_steps="1000" \
|
22 |
+
--gradient_checkpointing \
|
23 |
+
--fp16 \
|
24 |
+
\
|
25 |
+
--shuffle_buffer_size="500" \
|
26 |
+
--generation_max_length="225" \
|
27 |
+
--max_duration_in_seconds="30" \
|
28 |
+
--text_column_name="sentence" \
|
29 |
+
--freeze_feature_encoder="False" \
|
30 |
+
--report_to="tensorboard" \
|
31 |
+
--metric_for_best_model="wer" \
|
32 |
+
--greater_is_better="False" \
|
33 |
+
--load_best_model_at_end \
|
34 |
+
\
|
35 |
+
--do_train \
|
36 |
+
--do_eval \
|
37 |
+
--ignore_data_skip \
|
38 |
+
--predict_with_generate \
|
39 |
+
--do_normalize_eval \
|
40 |
+
--streaming_train="True" \
|
41 |
+
--streaming_eval="False" \
|
42 |
+
--use_auth_token \
|
43 |
+
--push_to_hub \
|
44 |
+
--hub_model_id="ales/whisper-base-belarusian"
|
run_3/src/bash_runners/run_small.sh
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
mkdir -p logs
|
2 |
+
|
3 |
+
python src/run_speech_recognition_seq2seq_streaming.py \
|
4 |
+
--model_name_or_path="ales/whisper-small-belarusian" \
|
5 |
+
--dataset_name="mozilla-foundation/common_voice_11_0" \
|
6 |
+
--dataset_config_name="be" \
|
7 |
+
--language="be" \
|
8 |
+
--train_split_name="train" \
|
9 |
+
--eval_split_name="validation" \
|
10 |
+
--model_index_name="Whisper Small Belarusian" \
|
11 |
+
\
|
12 |
+
--max_steps="6000" \
|
13 |
+
--output_dir="./" \
|
14 |
+
--per_device_train_batch_size="64" \
|
15 |
+
--per_device_eval_batch_size="32" \
|
16 |
+
--logging_steps="50" \
|
17 |
+
--logging_first_step \
|
18 |
+
--learning_rate="3e-5" \
|
19 |
+
--learning_rate_end="1e-5" \
|
20 |
+
--warmup_steps="0" \
|
21 |
+
--evaluation_strategy="steps" \
|
22 |
+
--eval_steps="1000" \
|
23 |
+
--save_strategy="steps" \
|
24 |
+
--save_steps="1000" \
|
25 |
+
--gradient_checkpointing \
|
26 |
+
--fp16 \
|
27 |
+
\
|
28 |
+
--shuffle_buffer_size="500" \
|
29 |
+
--generation_max_length="225" \
|
30 |
+
--max_duration_in_seconds="30" \
|
31 |
+
--text_column_name="sentence" \
|
32 |
+
--freeze_feature_encoder="False" \
|
33 |
+
--report_to="tensorboard" \
|
34 |
+
--metric_for_best_model="wer" \
|
35 |
+
--greater_is_better="False" \
|
36 |
+
--load_best_model_at_end \
|
37 |
+
\
|
38 |
+
--do_train \
|
39 |
+
--do_eval \
|
40 |
+
--ignore_data_skip \
|
41 |
+
--predict_with_generate \
|
42 |
+
--do_normalize_eval \
|
43 |
+
--streaming_train="True" \
|
44 |
+
--streaming_eval="False" \
|
45 |
+
--seed="43" \
|
46 |
+
--use_auth_token \
|
47 |
+
--push_to_hub="False" 2>&1 | tee "logs/train_$(date +"%Y%m%d-%H%M%S").log"
|
run_3/src/bash_runners/run_tiny_debug.sh
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
python src/run_speech_recognition_seq2seq_streaming.py \
|
2 |
+
--model_name_or_path="openai/whisper-tiny" \
|
3 |
+
--dataset_name="mozilla-foundation/common_voice_11_0" \
|
4 |
+
--dataset_config_name="be" \
|
5 |
+
--language="be" \
|
6 |
+
--train_split_name="train" \
|
7 |
+
--eval_split_name="validation" \
|
8 |
+
--model_index_name="Whisper Tiny Belarusian" \
|
9 |
+
\
|
10 |
+
--max_steps="500" \
|
11 |
+
--max_eval_samples="64" \
|
12 |
+
--output_dir="./" \
|
13 |
+
--per_device_train_batch_size="32" \
|
14 |
+
--per_device_eval_batch_size="32" \
|
15 |
+
--logging_steps="10" \
|
16 |
+
--logging_first_step \
|
17 |
+
--learning_rate="1e-4" \
|
18 |
+
--warmup_steps="10" \
|
19 |
+
--evaluation_strategy="steps" \
|
20 |
+
--eval_steps="10" \
|
21 |
+
--save_strategy="steps" \
|
22 |
+
--save_steps="10" \
|
23 |
+
--gradient_checkpointing \
|
24 |
+
--fp16 \
|
25 |
+
\
|
26 |
+
--shuffle_buffer_size="20" \
|
27 |
+
--generation_max_length="225" \
|
28 |
+
--max_duration_in_seconds="30" \
|
29 |
+
--text_column_name="sentence" \
|
30 |
+
--freeze_feature_encoder="False" \
|
31 |
+
--report_to="tensorboard" \
|
32 |
+
--metric_for_best_model="wer" \
|
33 |
+
--greater_is_better="False" \
|
34 |
+
--load_best_model_at_end \
|
35 |
+
\
|
36 |
+
--do_train \
|
37 |
+
--do_eval \
|
38 |
+
--ignore_data_skip \
|
39 |
+
--predict_with_generate \
|
40 |
+
--do_normalize_eval \
|
41 |
+
--streaming \
|
42 |
+
--use_auth_token \
|
43 |
+
--push_to_hub \
|
44 |
+
--hub_model_id="ales/whisper-tiny-be-test"
|
run_3/src/belarusian_text_normalizer.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import regex
|
3 |
+
import unicodedata
|
4 |
+
|
5 |
+
from typing import Iterable
|
6 |
+
|
7 |
+
|
8 |
+
class BelarusianTextNormalizer:
|
9 |
+
"""
|
10 |
+
Based on transformers.models.whisper.english_normalizer.BasicTextNormalizer
|
11 |
+
but with support not to remove certain characters.
|
12 |
+
e.g. apostrophe (') - a symbol from Belarusian alphabet - was removed using BasicTextNormalizer.
|
13 |
+
"""
|
14 |
+
|
15 |
+
def __init__(self, split_letters: bool = False):
|
16 |
+
self.split_letters = split_letters
|
17 |
+
self.allowed_symbols = ("'",)
|
18 |
+
|
19 |
+
@staticmethod
|
20 |
+
def clean(s: str, allowed_symbols: Iterable[str] = None):
|
21 |
+
"""
|
22 |
+
Replace any other markers, symbols, punctuations with a space, keeping diacritics
|
23 |
+
"""
|
24 |
+
if allowed_symbols is None:
|
25 |
+
allowed_symbols = []
|
26 |
+
res = "".join(" " if unicodedata.category(c)[0] in "MSP" and c not in allowed_symbols else c
|
27 |
+
for c in unicodedata.normalize("NFKC", s))
|
28 |
+
return res
|
29 |
+
|
30 |
+
def __call__(self, s: str):
|
31 |
+
s = s.lower()
|
32 |
+
s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
|
33 |
+
s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
|
34 |
+
s = self.clean(s, allowed_symbols=self.allowed_symbols).lower()
|
35 |
+
|
36 |
+
if self.split_letters:
|
37 |
+
s = " ".join(regex.findall(r"\X", s, regex.U))
|
38 |
+
|
39 |
+
s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
|
40 |
+
|
41 |
+
return s
|
run_3/src/custom_trainer.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import sys
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
import transformers
|
7 |
+
from transformers import Seq2SeqTrainer
|
8 |
+
|
9 |
+
|
10 |
+
logger = logging.getLogger('custom_trainer')
|
11 |
+
logger.setLevel(logging.INFO)
|
12 |
+
|
13 |
+
|
14 |
+
class Seq2SeqTrainerCustomLinearScheduler(Seq2SeqTrainer):
|
15 |
+
|
16 |
+
"""
|
17 |
+
Custom trainer to initialize Learning Rate Scheduler
|
18 |
+
and define the learning rate in the end of a training.
|
19 |
+
"""
|
20 |
+
|
21 |
+
@staticmethod
|
22 |
+
def scheduler_n_steps_for_fixed_lr_in_end(lr_max, lr_end, num_train_steps, num_warmup_steps) -> int:
|
23 |
+
assert lr_end < lr_max
|
24 |
+
return num_warmup_steps + (num_train_steps - num_warmup_steps) * lr_max / (lr_max - lr_end)
|
25 |
+
|
26 |
+
def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
|
27 |
+
use_custom_scheduler = False
|
28 |
+
try:
|
29 |
+
# if learning_rate_end was passed as an argument
|
30 |
+
learning_rate_end = self.args.learning_rate_end
|
31 |
+
use_custom_scheduler = True
|
32 |
+
logger.info('TrainerCustomLinearScheduler.create_scheduler(). '
|
33 |
+
f'initializing custom linear scheduler using learning_rate_end={learning_rate_end}')
|
34 |
+
except:
|
35 |
+
logger.info('TrainerCustomLinearScheduler.create_scheduler(). '
|
36 |
+
'learning_rate_end was not set. fallback to a default behavior')
|
37 |
+
|
38 |
+
if use_custom_scheduler is True:
|
39 |
+
scheduler_num_steps = self.scheduler_n_steps_for_fixed_lr_in_end(
|
40 |
+
lr_max=self.args.learning_rate,
|
41 |
+
lr_end=learning_rate_end,
|
42 |
+
num_train_steps=num_training_steps,
|
43 |
+
num_warmup_steps=self.args.warmup_steps
|
44 |
+
)
|
45 |
+
|
46 |
+
self.lr_scheduler = transformers.get_scheduler(
|
47 |
+
'linear', optimizer=optimizer,
|
48 |
+
num_warmup_steps=self.args.warmup_steps,
|
49 |
+
num_training_steps=scheduler_num_steps
|
50 |
+
)
|
51 |
+
return self.lr_scheduler
|
52 |
+
else:
|
53 |
+
return super().create_scheduler(num_training_steps, optimizer)
|
run_3/src/readme.md
ADDED
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Description
|
2 |
+
|
3 |
+
Fine-tuning [OpenAI Whisper](https://github.com/openai/whisper) model for Belarusian language during
|
4 |
+
[Whisper fine-tuning Event](https://github.com/huggingface/community-events/tree/main/whisper-fine-tuning-event)
|
5 |
+
hosted by HuggingFace x Lambda.
|
6 |
+
|
7 |
+
The code in this repository is a modified version of code from
|
8 |
+
[Whisper fine-tuning Event](https://github.com/huggingface/community-events/tree/main/whisper-fine-tuning-event) repo.
|
9 |
+
|
10 |
+
## Tips:
|
11 |
+
* start with a port worwarding to monitor Tensorboard logs on local computer:
|
12 |
+
```
|
13 |
+
ssh <remote-address> -L <local_port>:localhost:<remote_tensorboard_port>
|
14 |
+
```
|
15 |
+
* Train with redirecting output to a file using `tee`:
|
16 |
+
```
|
17 |
+
source src/run.sh 2>&1 | tee train_run_<run_number>.log
|
18 |
+
```
|
19 |
+
|
20 |
+
## Fine-tuning todos:
|
21 |
+
* logs are printed only right before the evalutaion:<br>
|
22 |
+
```
|
23 |
+
--logging_steps="50"
|
24 |
+
--eval_steps="1000"
|
25 |
+
```
|
26 |
+
* Learning rate:
|
27 |
+
* max learning rate is not the same as LR passed as a parameter to training script. it is actually lower.
|
28 |
+
* when resuming training, LR scheduling behaves incorrectly
|
29 |
+
* check exact sizes of train, eval, test sets of CommonVoice 11
|
30 |
+
* fill TODOs in Notes section with answers and discussions from a Discord
|
31 |
+
|
32 |
+
## Resuming training from exising checkpoint
|
33 |
+
When resuming training from existing checkpoint:
|
34 |
+
* when using streaming, epoch will get reset to 0. that means order of items passed to a model would be the same,
|
35 |
+
if the seed does not change. actual train_dataloader seed would be:
|
36 |
+
`train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)`
|
37 |
+
* it's better to save all `checkpoint-\d+` dirs. better not to rely on data saved to `output_dir` because:
|
38 |
+
* not all data is saved to `output_dir`. e.g. following files are not saved to `output_dir`:
|
39 |
+
`optimizer.pt`, `rng_state.pth`, `scaler.pt`, `scheduler.pt`. so can't resume training in a correct way from
|
40 |
+
data saved to `output_dir`
|
41 |
+
* when resuming training from `output_dir` as a checkpoint dir, model saved to `output_dir` can be worse than
|
42 |
+
previously save (need to investifate further. but such happened already)
|
43 |
+
* learning rate gets reset if passing same parameter value to training script as in previour run.<br>
|
44 |
+
need to provide learning rate from the last step of previous run to continue
|
45 |
+
training in a correct way
|
46 |
+
* however even if passing learning rate from the last step, in the new run it has different value than expected
|
47 |
+
* probably because last checkpont was chosen incorrectly
|
48 |
+
* or learning rate is treated as a starting learning rate at step 0 and not on step X (where we resume).<br>
|
49 |
+
need to try to pass same LR that was passes as a starting LR to the very first run
|
50 |
+
* it's unclear whether decision on saving current model
|
51 |
+
is made by comparing current metrics with metrics of the best checkpoint. I guess model with worse performance
|
52 |
+
will not overwrite best model checkpoint already exising in the output dir, but need to double check.
|
53 |
+
* we can set `ignore_data_skip=True` Training argument not to
|
54 |
+
skip data items already passed to a model - that will save time on data loads.
|
55 |
+
* it's unclear whether order of input items in the train set (that is shuffled) will be the same
|
56 |
+
across multiple reruns - i.e. it's unclear whether sampling is the same across reruns.
|
57 |
+
* if the sampling is the same across reruns, `ignore_data_skip=True` will lead to same items been passed to a model
|
58 |
+
in current run. it's OK if previous run ended with large step value on the last epoch.
|
59 |
+
if not, the same elements from the same epoch will be passed to a model again.
|
60 |
+
|
61 |
+
## Questions:
|
62 |
+
* What checkpoint (best, I guess) is saved in the `output_dir`?
|
63 |
+
How is it overwritten when resuming training from existing checkpoint?
|
64 |
+
* why dataset loading crashes when using `num_proc > 0`?
|
65 |
+
* does `ShuffleCallback` work with StreamingDataset? it reshuffles data `on_epoch_begin()`,
|
66 |
+
but does StreamingDataset have any epochs?
|
67 |
+
* does streaming mode support parallel data load and processing?<br>
|
68 |
+
when using non-streaming mode we can use `dataset.map(..., num_proc=<num_proc>)`
|
69 |
+
* I got CUDA out of memory error when tried to launch a second training run for Whisper Small model.
|
70 |
+
training params are almost the same: `--per_device_train_batch_size="64"`
|
71 |
+
the only thing changed is that now evaluation dataset now doesn't use streaming.
|
72 |
+
|
73 |
+
|
74 |
+
## Notes:
|
75 |
+
* Common Voice 11 dataset
|
76 |
+
[uploaded to HuggingFace](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0)
|
77 |
+
has only single voicing of each sentence in each split (train, validation, test).<br>
|
78 |
+
Much more audiofiles should be available on Common Voice so that each sentence is voiced multiple times by different people
|
79 |
+
* using CommonVoice 11 dataset in a streaming way.<br>
|
80 |
+
use `streaming=True` for train & validation & test.<br>
|
81 |
+
as an alternative, we can use `streaming=False` for validation & test sets to save time on data processing.
|
82 |
+
but the size of validation and test sets are unknown (need to check).
|
83 |
+
it's likely they are going to be large - thus pre-download of these sets might not reduce
|
84 |
+
overall fine-tuning time compared to streaming mode.
|
85 |
+
* size of train set is ~370'000 audiofiles. if using `batch_size=64`, then
|
86 |
+
1 epoch will have ~5782 steps. <br>
|
87 |
+
Because of `--eval_steps="1000"` will use `--max_steps="6000"` instead of `--max_steps="5800"`
|
88 |
+
to have evaluation metrics computed in the end of training.
|
89 |
+
* if using Google Colab, need to execute `sudo chmod -R 777 .git` inside hf repo to
|
90 |
+
to set right permissions to be able to push trained models to HuggingFace Hub
|
91 |
+
* Log tracking in Jupyter (not working) and in bash (works as expected with `tee`)
|
92 |
+
* Loggers in `run_speech.....py` do not control `transformers` and `datasets` loggers.
|
93 |
+
can't redirect their outputs using handlers. it's better and easier to redirect output in a bash
|
94 |
+
* to evaluate on `google/fleurs` dataset had to downgrade `numba` from `0.56.4` to `0.56.3`, then install `librosa`
|
95 |
+
(strange, because `librosa` should have been installed when `pip install -r ~/whisper-finetuning-be/requirements.txt`
|
96 |
+
was run) and then upgrade back to `numba==0.56.4` because couldn't `import numba` when it was `0.56.3`
|
97 |
+
* Need to set `use_cache` to False since we're using gradient checkpointing, and the two are incompatible
|
98 |
+
* Default Linear scheduler is used
|
99 |
+
* Default Adam optimizer is used
|
100 |
+
|
101 |
+
### Logs not printed when expected
|
102 |
+
* Train logs are printed only before start of a validation.
|
103 |
+
During training they are not printed to a stdout.
|
104 |
+
All worked fine in a Colab.
|
105 |
+
* No progressbar for validation (at least when using streaming and iterable dataset).
|
106 |
+
possible reason is that when using streaming, the dataset len in unknown.
|
107 |
+
* Evaluation metrics get printed to stdout only before the next validation call.
|
108 |
+
All worked fine in a Colab.
|
109 |
+
* Possible reason: usage of `... | tee file.log`. But it's unlikely
|
110 |
+
|
111 |
+
### Text normalization
|
112 |
+
* Whispers BasicTextNormalizer splits words containing apostrophe:
|
113 |
+
```python
|
114 |
+
> from transformers.models.whisper.english_normalizer import BasicTextNormalizer
|
115 |
+
> normalizer = BasicTextNormalizer()
|
116 |
+
> normalizer("раз'яднаць")
|
117 |
+
'раз яднаць'
|
118 |
+
```
|
119 |
+
* That's why `BelarusianTextNormalizer` (edited version of `BasicTextNormalizer`) was added to training script:
|
120 |
+
```python
|
121 |
+
> from run_speech_recognition_seq2seq_streaming import BelarusianTextNormalizer
|
122 |
+
> normalizer_be = BelarusianTextNormalizer()
|
123 |
+
> normalizer_be("раз'яднаць")
|
124 |
+
"раз'яднаць"
|
125 |
+
```
|
126 |
+
|
127 |
+
### Different batch sizes for train and evaluation:
|
128 |
+
* Theoretically you can use a larger batch size for evaluation vs training!
|
129 |
+
* Training: we do a forward pass, storing all the activations, and then a backwards pass, storing all the gradients
|
130 |
+
* Inference (evaluation): we only do a forward pass, and don't store any activations
|
131 |
+
* So the memory required for evaluation is much lower than it is for training
|
132 |
+
(we're only doing the forward pass and not storing any values)
|
133 |
+
* In my experience, altering the eval batch size has little effect on eval speed ->
|
134 |
+
I set it to a lower value as this tends to give a more responsive progress bar
|
135 |
+
when evaluating in non-streaming mode (the bar updates faster and more frequently)
|
136 |
+
|
137 |
+
### Slow inference. Long evalutaion compared to training:
|
138 |
+
* Slower inference is an inherent limitation of the sequence-to-sequence architecture.
|
139 |
+
The auto-regressive decoding means that you have to do as many decoder forward passes as tokens generated.
|
140 |
+
* This is much slower than CTC, where you do a single encoder forward pass
|
141 |
+
* Note that 1 evaluation step **will take much longer** than 1 training step, even with the same batch sizes.
|
142 |
+
* With training, we do one forward pass of the encoder, one forward pass of the decoder,
|
143 |
+
one backward pass of the decoder and one backward pass of the encoder (=4 passes total):<br>
|
144 |
+
```
|
145 |
+
audio -> encoder -> decoder -> labels
|
146 |
+
encoder <- decoder <- loss
|
147 |
+
```
|
148 |
+
* During evaluation we do one forward pass of the encoder, and then auto-regressively generate tokens in the decoder.
|
149 |
+
Here, we do as many forward passes of the decoder as tokens generated.
|
150 |
+
So in total, we do one forward pass of the encoder, and N forward passes of the decoder,
|
151 |
+
where N is the number of tokens generated (can be up to the max length, which is 448...).
|
152 |
+
You can see that for 4 or more generated tokens, evaluation is going to be slower than training:<br>
|
153 |
+
```
|
154 |
+
audio -> encoder -> decoder -> decoder -> decoder -> ... -> decoder -> end of sentence token
|
155 |
+
```
|
156 |
+
* I've made a bit of a simplification here in saying that one forward pass
|
157 |
+
takes the same amount of time as one backward pass, but for the purpose of illustrating,
|
158 |
+
this demonstrates the point why evaluation is much slower than training
|
159 |
+
* Essentially it doesn't really matter what you set your eval batch size as we're not aggregating any statistics
|
160 |
+
over the eval batch (in contrast during training we evaluate a true gradient value based on a given batch).
|
161 |
+
* Since we just do a forward pass, we could even run eval with a batch size of 1 and get exactly the same results!
|
162 |
+
* Because we don't get much of an improvement with batch sizes beyond around 8, it's set somewhat arbitrarily
|
163 |
+
|
164 |
+
### Ways to decrease evaluation time during fine-tuning:
|
165 |
+
* reduce `generation_max_length` param:
|
166 |
+
* During training, we can limit the generation max length to a lower number to cut-off the generation
|
167 |
+
after fewer tokens (e.g. 40). This will give worse results during training,
|
168 |
+
but we can still infer the evolution of WER performance over training.
|
169 |
+
* For the final eval step, we can bump up the generation max length back up to 448.
|
170 |
+
* WER performance varies monotonically with generation max length
|
171 |
+
(WER can only stay equal or improve by increasing generation max length),
|
172 |
+
so we know that our final eval WER will be less than (improved) or equal to the WER during training
|
173 |
+
* We can evaluate at less frequent eval_steps: this reduces the number of times we have to perform evaluation
|
174 |
+
|
175 |
+
### Decrease inference time more generally
|
176 |
+
* PyTorch 2.0 and compiling the model could get you a decent speed-up
|
177 |
+
(https://pytorch.org/blog/Accelerating-Hugging-Face-and-TIMM-models/#hugging-face-models)
|
178 |
+
* Downcasting to fp16
|
179 |
+
|
180 |
+
### Memory saving and training larger models:
|
181 |
+
To save memory (and increase either model or batch_size) can experiment with:
|
182 |
+
* using Adafactor instead of Adam.
|
183 |
+
Adam requires two optimiser params per one model param, but Adafactor uses only one.
|
184 |
+
> A word of caution: Adafactor is untested for fine-tuning Whisper,
|
185 |
+
so we are unsure sure how Adafactor performance compares to Adam!
|
186 |
+
* using Adam 8bit from `bitsandbytes` module.
|
187 |
+
need to provide `optim="adamw_bnb_8bit"` param to `Seq2SeqTrainingArguments`
|
188 |
+
* use `deepspeed`. scripts are there in
|
189 |
+
[Whisper fine-tuning Event repo](https://github.com/huggingface/community-events/tree/main/whisper-fine-tuning-event)
|
190 |
+
* load the model and processor in 8bit mode:
|
191 |
+
```python
|
192 |
+
from transformers import WhisperForConditionalGeneration, WhisperProcessor
|
193 |
+
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large", device_map="auto", load_in_8bit=True)
|
194 |
+
processor = WhisperProcessor.from_pretrained("openai/whisper-large", load_in_8bit=True)
|
195 |
+
```
|
196 |
+
inference loop:
|
197 |
+
```python
|
198 |
+
for data in dataset:
|
199 |
+
inputs = processor.feature_extractor(data["audio"]["array"], return_tensors="pt", sampling_rate=16_000).input_features.half().to(device)
|
200 |
+
forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
|
201 |
+
predicted_ids = model.generate(inputs, forced_decoder_ids=forced_decoder_ids)
|
202 |
+
text = processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True, normalize=False)[0]
|
203 |
+
print(text)
|
204 |
+
```
|
205 |
+
* 8bit will slower iference compared to full/half-precision
|
206 |
+
* But the memory saving you get is immense (up to 4x vs full-precision).<br>
|
207 |
+
This is the recommended approach when you're limited on VRAM.<br>
|
208 |
+
If you care about inference speed, still to full precision
|
209 |
+
|
210 |
+
### Prepended tokens
|
211 |
+
* Why are there following lines in Data Collator?
|
212 |
+
```python
|
213 |
+
# if bos token is appended in previous tokenization step,
|
214 |
+
# cut bos token here as it's append later anyways
|
215 |
+
if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
|
216 |
+
labels = labels[:, 1:]
|
217 |
+
```
|
218 |
+
* `tokenizer.bos_token_id` vs `model.config.decoder_start_token_id`.<br>
|
219 |
+
which one to pass to Data Collator as `decoder_start_token_id` parameter?
|
220 |
+
* Answer:
|
221 |
+
* In this case, the two are equivalent. You can verify this:
|
222 |
+
```python
|
223 |
+
print(tokenizer.bos_token_id)
|
224 |
+
print(model.config.decoder_start_token_id)
|
225 |
+
```
|
226 |
+
|
227 |
+
* Print Output:
|
228 |
+
```
|
229 |
+
<|startoftranscript|>
|
230 |
+
<|startoftranscript|>
|
231 |
+
```
|
232 |
+
|
233 |
+
* Technically speaking, the decoder_start_token_id is the correct convention here. Before starting generating any tokens, we initialise the generate method with a starting token, which is the decoder_start_token_id.
|
234 |
+
See: https://huggingface.co/blog/how-to-generate. The decoder_start_token_id corresponds to the initial context word sequence, and is the zero'th token generated.
|
235 |
+
|
236 |
+
* We remove this token from the encoded labels in the data collator because we always set the zero'th generated token to the decoder_start_token_id. If we leave the decoder_start_token_id as part of the label sequence, then we'll predict the decoder_start_token_id as the zero'th token, and again as the first token! Because we're always forcing it as the zero'th token, we don't need to predict it as the first token, and so we remove it from the target lables
|
237 |
+
|
238 |
+
* These tokens are not forced in the generation process, and so we don't cut them in the data collator. We need to provide them to the model as target labels so that the model can learn the correct tasks from our data
|
239 |
+
|
240 |
+
* The tokens correspond to the audio language, task (translate or transcribe) and whether to predict timestamps
|
241 |
+
|
242 |
+
* We need to tell the model what language the audio corresponds to and what task it's performing during fine-tuning. This way, it learns what audio corresponds to what language, and the difference between transcribing audio vs translating it
|
243 |
+
|
244 |
+
|
run_3/src/requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch>=1.7
|
2 |
+
torchaudio
|
3 |
+
git+https://github.com/huggingface/transformers
|
4 |
+
git+https://github.com/huggingface/datasets
|
5 |
+
librosa
|
6 |
+
jiwer
|
7 |
+
evaluate>=0.3.0
|
8 |
+
more-itertools
|
9 |
+
tensorboard
|
run_3/src/run_eval_whisper_streaming.py
ADDED
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import sys
|
4 |
+
import datetime
|
5 |
+
import os
|
6 |
+
|
7 |
+
import pandas as pd
|
8 |
+
|
9 |
+
from transformers import pipeline
|
10 |
+
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
|
11 |
+
from datasets import load_dataset, Audio
|
12 |
+
import evaluate
|
13 |
+
|
14 |
+
from belarusian_text_normalizer import BelarusianTextNormalizer
|
15 |
+
|
16 |
+
|
17 |
+
now_str = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
|
18 |
+
|
19 |
+
|
20 |
+
logger = logging.getLogger(__name__)
|
21 |
+
logging.basicConfig(
|
22 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
23 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
24 |
+
handlers=[
|
25 |
+
logging.StreamHandler(sys.stdout),
|
26 |
+
logging.FileHandler(filename=f'eval_{now_str}.log', mode='w')
|
27 |
+
],
|
28 |
+
)
|
29 |
+
logger.setLevel(logging.INFO)
|
30 |
+
|
31 |
+
|
32 |
+
wer_metric = evaluate.load("wer")
|
33 |
+
text_normalizer = BelarusianTextNormalizer()
|
34 |
+
|
35 |
+
|
36 |
+
def is_target_text_in_range(ref):
|
37 |
+
if ref.strip() == "ignore time segment in scoring":
|
38 |
+
return False
|
39 |
+
else:
|
40 |
+
return ref.strip() != ""
|
41 |
+
|
42 |
+
|
43 |
+
def normalise(sample, text_column: str):
|
44 |
+
sample["reference_norm"] = text_normalizer(sample[text_column])
|
45 |
+
return sample
|
46 |
+
|
47 |
+
|
48 |
+
def data(dataset,text_column: str):
|
49 |
+
for i, item in enumerate(dataset):
|
50 |
+
yield {**item["audio"], "reference_norm": item["reference_norm"], 'reference': item[text_column]}
|
51 |
+
|
52 |
+
|
53 |
+
def clean_filename(filename: str):
|
54 |
+
return filename.replace(os.path.sep, '_')
|
55 |
+
|
56 |
+
|
57 |
+
def main(args):
|
58 |
+
logger.info(f'running evaluation script with following parameters: {args}')
|
59 |
+
logger.info(f'using following text normalizer: {text_normalizer}')
|
60 |
+
|
61 |
+
batch_size = args.batch_size
|
62 |
+
whisper_asr = pipeline("automatic-speech-recognition", model=args.model_id, device=args.device)
|
63 |
+
|
64 |
+
whisper_asr.model.config.forced_decoder_ids = (
|
65 |
+
whisper_asr.tokenizer.get_decoder_prompt_ids(
|
66 |
+
language=args.language, task="transcribe"
|
67 |
+
)
|
68 |
+
)
|
69 |
+
|
70 |
+
logger.info('loading dataset')
|
71 |
+
dataset = load_dataset(
|
72 |
+
args.dataset,
|
73 |
+
args.config,
|
74 |
+
split=args.split,
|
75 |
+
streaming=args.streaming,
|
76 |
+
use_auth_token=True,
|
77 |
+
)
|
78 |
+
|
79 |
+
# Only uncomment for debugging
|
80 |
+
dataset = dataset.take(args.max_eval_samples)
|
81 |
+
|
82 |
+
# TODO: probably no need in cast, because pipelien migh handle resampling internally. need to check
|
83 |
+
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
|
84 |
+
dataset = dataset.map(normalise, fn_kwargs=dict(text_column=args.text_column))
|
85 |
+
dataset = dataset.filter(is_target_text_in_range, input_columns=["reference_norm"])
|
86 |
+
|
87 |
+
predictions = []
|
88 |
+
predictions_norm = []
|
89 |
+
references = []
|
90 |
+
references_norm = []
|
91 |
+
audio_paths = []
|
92 |
+
|
93 |
+
logger.info('running inference')
|
94 |
+
for out in whisper_asr(data(dataset, text_column=args.text_column), batch_size=batch_size):
|
95 |
+
predictions.append(out["text"])
|
96 |
+
predictions_norm.append(text_normalizer(out["text"]))
|
97 |
+
references.append(out["reference"][0])
|
98 |
+
references_norm.append(out["reference_norm"][0])
|
99 |
+
audio_paths.append(out['path'][0])
|
100 |
+
|
101 |
+
logger.info('computing metrics')
|
102 |
+
wer = wer_metric.compute(references=references_norm, predictions=predictions_norm)
|
103 |
+
wer = wer * 100
|
104 |
+
|
105 |
+
logger.info('metrics computed')
|
106 |
+
logger.info(f'WER: {wer}')
|
107 |
+
|
108 |
+
if args.save_predictions is True:
|
109 |
+
preds_fp = f'preds_{args.dataset}_{args.config}_{args.split}_{now_str}.tsv'
|
110 |
+
preds_fp = clean_filename(preds_fp)
|
111 |
+
logger.info(f'saving predictions to: "{preds_fp}"')
|
112 |
+
preds_df = pd.DataFrame({
|
113 |
+
'audio_path': audio_paths,
|
114 |
+
'prediction_norm': predictions_norm, 'reference_norm': references_norm,
|
115 |
+
'prediction': predictions, 'reference': references,
|
116 |
+
})
|
117 |
+
preds_df.to_csv(preds_fp, sep='\t', index=False)
|
118 |
+
else:
|
119 |
+
logger.info('save_predictions is False. will not save predictions to a file')
|
120 |
+
|
121 |
+
if args.push_to_hub is True:
|
122 |
+
logger.info(f'updating model card and pushing to HuggingFace Hub')
|
123 |
+
evaluate.push_to_hub(
|
124 |
+
model_id=args.model_id,
|
125 |
+
|
126 |
+
metric_value=wer,
|
127 |
+
metric_type="wer",
|
128 |
+
metric_name="WER",
|
129 |
+
|
130 |
+
dataset_name=args.dataset,
|
131 |
+
dataset_type=args.dataset,
|
132 |
+
dataset_config=args.config,
|
133 |
+
dataset_split=args.split,
|
134 |
+
|
135 |
+
task_type="automatic-speech-recognition",
|
136 |
+
task_name="Automatic Speech Recognition"
|
137 |
+
)
|
138 |
+
else:
|
139 |
+
logger.info('push_to_hub is False. will not update model card and push to HuggingFace Hub')
|
140 |
+
|
141 |
+
|
142 |
+
if __name__ == "__main__":
|
143 |
+
parser = argparse.ArgumentParser()
|
144 |
+
|
145 |
+
parser.add_argument(
|
146 |
+
"--model_id",
|
147 |
+
type=str,
|
148 |
+
required=True,
|
149 |
+
help="Model identifier. Should be loadable with 🤗 Transformers",
|
150 |
+
)
|
151 |
+
parser.add_argument(
|
152 |
+
"--dataset",
|
153 |
+
type=str,
|
154 |
+
default="mozilla-foundation/common_voice_11_0",
|
155 |
+
help="Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets",
|
156 |
+
)
|
157 |
+
parser.add_argument(
|
158 |
+
"--config",
|
159 |
+
type=str,
|
160 |
+
required=True,
|
161 |
+
help="Config of the dataset. *E.g.* `'en'` for the English split of Common Voice",
|
162 |
+
)
|
163 |
+
parser.add_argument(
|
164 |
+
"--split",
|
165 |
+
type=str,
|
166 |
+
default="test",
|
167 |
+
help="Split of the dataset. *E.g.* `'test'`",
|
168 |
+
)
|
169 |
+
parser.add_argument(
|
170 |
+
"--text_column",
|
171 |
+
type=str,
|
172 |
+
required=True,
|
173 |
+
help="Dataset column name containing target transcription of an audiofile"
|
174 |
+
)
|
175 |
+
parser.add_argument(
|
176 |
+
"--device",
|
177 |
+
type=int,
|
178 |
+
default=-1,
|
179 |
+
help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
|
180 |
+
)
|
181 |
+
parser.add_argument(
|
182 |
+
"--batch_size",
|
183 |
+
type=int,
|
184 |
+
default=16,
|
185 |
+
help="Number of samples to go through each streamed batch.",
|
186 |
+
)
|
187 |
+
parser.add_argument(
|
188 |
+
"--max_eval_samples",
|
189 |
+
type=int,
|
190 |
+
default=None,
|
191 |
+
help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
|
192 |
+
)
|
193 |
+
parser.add_argument(
|
194 |
+
"--streaming",
|
195 |
+
type=bool,
|
196 |
+
default=True,
|
197 |
+
help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
|
198 |
+
)
|
199 |
+
parser.add_argument(
|
200 |
+
"--language",
|
201 |
+
type=str,
|
202 |
+
required=True,
|
203 |
+
help="Two letter language code for the transcription language, e.g. use 'en' for English.",
|
204 |
+
)
|
205 |
+
parser.add_argument(
|
206 |
+
'--push_to_hub',
|
207 |
+
type=bool,
|
208 |
+
default=True,
|
209 |
+
help="Whether to update model card and push changes to HuggingFace Hub"
|
210 |
+
)
|
211 |
+
parser.add_argument(
|
212 |
+
'--save_predictions',
|
213 |
+
type=bool,
|
214 |
+
default=True,
|
215 |
+
help="Whether to store predictions and target transcriptions to a file"
|
216 |
+
)
|
217 |
+
args = parser.parse_args()
|
218 |
+
|
219 |
+
main(args)
|
run_3/src/run_speech_recognition_seq2seq_streaming.py
ADDED
@@ -0,0 +1,774 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
Fine-tuning the library models for sequence to sequence speech recognition
|
18 |
+
with 🤗 Datasets' streaming mode.
|
19 |
+
"""
|
20 |
+
# You can also adapt this script for your own sequence to sequence speech
|
21 |
+
# recognition task. Pointers for this are left as comments.
|
22 |
+
|
23 |
+
import logging
|
24 |
+
import os
|
25 |
+
import sys
|
26 |
+
import datetime
|
27 |
+
from dataclasses import dataclass, field
|
28 |
+
from typing import Any, Dict, List, Optional, Union, Iterable
|
29 |
+
|
30 |
+
import datasets
|
31 |
+
import torch
|
32 |
+
from datasets import DatasetDict, IterableDatasetDict, interleave_datasets, load_dataset
|
33 |
+
from torch.utils.data import IterableDataset
|
34 |
+
|
35 |
+
import evaluate
|
36 |
+
import transformers
|
37 |
+
from transformers import (
|
38 |
+
AutoConfig,
|
39 |
+
AutoFeatureExtractor,
|
40 |
+
AutoModelForSpeechSeq2Seq,
|
41 |
+
AutoProcessor,
|
42 |
+
AutoTokenizer,
|
43 |
+
HfArgumentParser,
|
44 |
+
Seq2SeqTrainer,
|
45 |
+
Seq2SeqTrainingArguments,
|
46 |
+
TrainerCallback,
|
47 |
+
set_seed,
|
48 |
+
)
|
49 |
+
from transformers.trainer_pt_utils import IterableDatasetShard
|
50 |
+
from transformers.trainer_utils import get_last_checkpoint, is_main_process
|
51 |
+
from transformers.utils import check_min_version, send_example_telemetry
|
52 |
+
from transformers.utils.versions import require_version
|
53 |
+
|
54 |
+
from custom_trainer import Seq2SeqTrainerCustomLinearScheduler
|
55 |
+
from belarusian_text_normalizer import BelarusianTextNormalizer
|
56 |
+
|
57 |
+
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
58 |
+
check_min_version("4.25.0.dev0")
|
59 |
+
|
60 |
+
require_version("datasets>=1.18.2", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
61 |
+
|
62 |
+
logger = logging.getLogger(__name__)
|
63 |
+
|
64 |
+
|
65 |
+
@dataclass
|
66 |
+
class CustomTrainingArguments:
|
67 |
+
""" Custom trianing arguments """
|
68 |
+
|
69 |
+
learning_rate_end: Optional[float] = field(
|
70 |
+
default=None,
|
71 |
+
metadata={
|
72 |
+
"help": ('Learning rate in the end of a training run. Passed to a Seq2SeqTrainerCustomLinearScheduler.')
|
73 |
+
},
|
74 |
+
)
|
75 |
+
|
76 |
+
|
77 |
+
@dataclass
|
78 |
+
class ModelArguments:
|
79 |
+
"""
|
80 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
81 |
+
"""
|
82 |
+
|
83 |
+
model_name_or_path: str = field(
|
84 |
+
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
85 |
+
)
|
86 |
+
config_name: Optional[str] = field(
|
87 |
+
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
88 |
+
)
|
89 |
+
tokenizer_name: Optional[str] = field(
|
90 |
+
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
91 |
+
)
|
92 |
+
feature_extractor_name: Optional[str] = field(
|
93 |
+
default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
|
94 |
+
)
|
95 |
+
cache_dir: Optional[str] = field(
|
96 |
+
default=None,
|
97 |
+
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
|
98 |
+
)
|
99 |
+
use_fast_tokenizer: bool = field(
|
100 |
+
default=True,
|
101 |
+
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
102 |
+
)
|
103 |
+
model_revision: str = field(
|
104 |
+
default="main",
|
105 |
+
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
106 |
+
)
|
107 |
+
use_auth_token: bool = field(
|
108 |
+
default=False,
|
109 |
+
metadata={
|
110 |
+
"help": (
|
111 |
+
"Will use the token generated when running `huggingface-cli login` (necessary to use this script "
|
112 |
+
"with private models)."
|
113 |
+
)
|
114 |
+
},
|
115 |
+
)
|
116 |
+
freeze_feature_encoder: bool = field(
|
117 |
+
default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
|
118 |
+
)
|
119 |
+
freeze_encoder: bool = field(
|
120 |
+
default=False, metadata={"help": "Whether to freeze the entire encoder of the seq2seq model."}
|
121 |
+
)
|
122 |
+
forced_decoder_ids: List[List[int]] = field(
|
123 |
+
default=None,
|
124 |
+
metadata={
|
125 |
+
"help": (
|
126 |
+
"A list of pairs of integers which indicates a mapping from generation indices to token indices "
|
127 |
+
"that will be forced before sampling. For example, [[0, 123]] means the first generated token "
|
128 |
+
"will always be a token of index 123."
|
129 |
+
)
|
130 |
+
},
|
131 |
+
)
|
132 |
+
suppress_tokens: List[int] = field(
|
133 |
+
default=None, metadata={"help": "A list of tokens that will be suppressed at generation."}
|
134 |
+
)
|
135 |
+
model_index_name: str = field(default=None, metadata={"help": "Pretty name for the model card."})
|
136 |
+
|
137 |
+
|
138 |
+
@dataclass
|
139 |
+
class DataTrainingArguments:
|
140 |
+
"""
|
141 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
142 |
+
"""
|
143 |
+
|
144 |
+
dataset_name: str = field(
|
145 |
+
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
146 |
+
)
|
147 |
+
dataset_config_name: Optional[str] = field(
|
148 |
+
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
149 |
+
)
|
150 |
+
max_train_samples: Optional[int] = field(
|
151 |
+
default=None,
|
152 |
+
metadata={
|
153 |
+
"help": (
|
154 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
155 |
+
"value if set."
|
156 |
+
)
|
157 |
+
},
|
158 |
+
)
|
159 |
+
max_eval_samples: Optional[int] = field(
|
160 |
+
default=None,
|
161 |
+
metadata={
|
162 |
+
"help": (
|
163 |
+
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
164 |
+
"value if set."
|
165 |
+
)
|
166 |
+
},
|
167 |
+
)
|
168 |
+
audio_column_name: str = field(
|
169 |
+
default="audio",
|
170 |
+
metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
|
171 |
+
)
|
172 |
+
text_column_name: str = field(
|
173 |
+
default="text",
|
174 |
+
metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
|
175 |
+
)
|
176 |
+
max_duration_in_seconds: float = field(
|
177 |
+
default=20.0,
|
178 |
+
metadata={
|
179 |
+
"help": (
|
180 |
+
"Truncate audio files that are longer than `max_duration_in_seconds` seconds to"
|
181 |
+
" 'max_duration_in_seconds`"
|
182 |
+
)
|
183 |
+
},
|
184 |
+
)
|
185 |
+
min_duration_in_seconds: float = field(
|
186 |
+
default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
|
187 |
+
)
|
188 |
+
train_split_name: str = field(
|
189 |
+
default="train",
|
190 |
+
metadata={
|
191 |
+
"help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
|
192 |
+
},
|
193 |
+
)
|
194 |
+
eval_split_name: str = field(
|
195 |
+
default="test",
|
196 |
+
metadata={
|
197 |
+
"help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
|
198 |
+
},
|
199 |
+
)
|
200 |
+
do_lower_case: bool = field(
|
201 |
+
default=False,
|
202 |
+
metadata={"help": "Whether the target text should be lower cased."},
|
203 |
+
)
|
204 |
+
do_remove_punctuation: bool = field(
|
205 |
+
default=False,
|
206 |
+
metadata={"help": "Whether the target text should be striped of punctuation."},
|
207 |
+
)
|
208 |
+
do_normalize_eval: bool = field(
|
209 |
+
default=True,
|
210 |
+
metadata={"help": "Whether to normalise the references and predictions in the eval WER calculation."},
|
211 |
+
)
|
212 |
+
language: str = field(
|
213 |
+
default=None,
|
214 |
+
metadata={
|
215 |
+
"help": (
|
216 |
+
"Language for multilingual fine-tuning. This argument should be set for multilingual fine-tuning "
|
217 |
+
"only. For English speech recognition, it should be set to `None`."
|
218 |
+
)
|
219 |
+
},
|
220 |
+
)
|
221 |
+
task: str = field(
|
222 |
+
default="transcribe",
|
223 |
+
metadata={"help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."},
|
224 |
+
)
|
225 |
+
shuffle_buffer_size: Optional[int] = field(
|
226 |
+
default=500,
|
227 |
+
metadata={
|
228 |
+
"help": (
|
229 |
+
"The number of streamed examples to download before shuffling them. The large the buffer, "
|
230 |
+
"the closer it is to real offline shuffling."
|
231 |
+
)
|
232 |
+
},
|
233 |
+
)
|
234 |
+
streaming_train: bool = field(
|
235 |
+
default=True,
|
236 |
+
metadata={"help": "Whether to use streaming mode to load and pre-process the train split."},
|
237 |
+
)
|
238 |
+
streaming_eval: bool = field(
|
239 |
+
default=True,
|
240 |
+
metadata={"help": "Whether to use streaming mode to load and pre-process the evaluation split."},
|
241 |
+
)
|
242 |
+
|
243 |
+
|
244 |
+
|
245 |
+
@dataclass
|
246 |
+
class DataCollatorSpeechSeq2SeqWithPadding:
|
247 |
+
"""
|
248 |
+
Data collator that will dynamically pad the inputs received.
|
249 |
+
Args:
|
250 |
+
processor ([`WhisperProcessor`])
|
251 |
+
The processor used for processing the data.
|
252 |
+
decoder_start_token_id (`int`)
|
253 |
+
The begin-of-sentence of the decoder.
|
254 |
+
"""
|
255 |
+
|
256 |
+
processor: Any
|
257 |
+
decoder_start_token_id: int
|
258 |
+
|
259 |
+
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
260 |
+
# split inputs and labels since they have to be of different lengths and need
|
261 |
+
# different padding methods
|
262 |
+
model_input_name = self.processor.model_input_names[0]
|
263 |
+
input_features = [{model_input_name: feature[model_input_name]} for feature in features]
|
264 |
+
label_features = [{"input_ids": feature["labels"]} for feature in features]
|
265 |
+
|
266 |
+
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
|
267 |
+
|
268 |
+
labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
|
269 |
+
|
270 |
+
# replace padding with -100 to ignore loss correctly
|
271 |
+
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
|
272 |
+
|
273 |
+
# if bos token is appended in previous tokenization step,
|
274 |
+
# cut bos token here as it's append later anyways
|
275 |
+
if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
|
276 |
+
labels = labels[:, 1:]
|
277 |
+
|
278 |
+
batch["labels"] = labels
|
279 |
+
|
280 |
+
return batch
|
281 |
+
|
282 |
+
|
283 |
+
def load_maybe_streaming_dataset(dataset_name, dataset_config_name, split="train", streaming=True, **kwargs):
|
284 |
+
"""
|
285 |
+
Utility function to load a dataset in streaming mode. For datasets with multiple splits,
|
286 |
+
each split is loaded individually and then splits combined by taking alternating examples from
|
287 |
+
each (interleaving).
|
288 |
+
"""
|
289 |
+
if "+" in split:
|
290 |
+
# load multiple splits separated by the `+` symbol with streaming mode
|
291 |
+
dataset_splits = [
|
292 |
+
load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=streaming, **kwargs)
|
293 |
+
for split_name in split.split("+")
|
294 |
+
]
|
295 |
+
# interleave multiple splits to form one dataset
|
296 |
+
interleaved_dataset = interleave_datasets(dataset_splits)
|
297 |
+
return interleaved_dataset
|
298 |
+
else:
|
299 |
+
# load a single split *with* streaming mode
|
300 |
+
dataset = load_dataset(dataset_name, dataset_config_name, split=split, streaming=streaming, **kwargs)
|
301 |
+
return dataset
|
302 |
+
|
303 |
+
|
304 |
+
def main():
|
305 |
+
# 1. Parse input arguments
|
306 |
+
# See all possible arguments in src/transformers/training_args.py
|
307 |
+
# or by passing the --help flag to this script.
|
308 |
+
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
309 |
+
parser = HfArgumentParser((
|
310 |
+
ModelArguments, DataTrainingArguments,
|
311 |
+
Seq2SeqTrainingArguments, CustomTrainingArguments
|
312 |
+
))
|
313 |
+
|
314 |
+
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
315 |
+
# If we pass only one argument to the script and it's the path to a json file,
|
316 |
+
# let's parse it to get our arguments.
|
317 |
+
model_args, data_args, training_args, custom_training_args = parser.parse_json_file(
|
318 |
+
json_file=os.path.abspath(sys.argv[1])
|
319 |
+
)
|
320 |
+
else:
|
321 |
+
model_args, data_args, training_args, custom_training_args = parser.parse_args_into_dataclasses()
|
322 |
+
|
323 |
+
|
324 |
+
# 2. Setup logging
|
325 |
+
now_str = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
|
326 |
+
logging.basicConfig(
|
327 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
328 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
329 |
+
handlers=[logging.StreamHandler(sys.stdout)],
|
330 |
+
)
|
331 |
+
log_level = training_args.get_process_log_level()
|
332 |
+
logger.setLevel(log_level)
|
333 |
+
datasets.utils.logging.set_verbosity(log_level)
|
334 |
+
transformers.utils.logging.set_verbosity(log_level)
|
335 |
+
transformers.utils.logging.enable_default_handler()
|
336 |
+
transformers.utils.logging.enable_explicit_format()
|
337 |
+
|
338 |
+
logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
|
339 |
+
|
340 |
+
# Log on each process the small summary:
|
341 |
+
logger.warning(
|
342 |
+
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
343 |
+
f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
344 |
+
)
|
345 |
+
|
346 |
+
# update training_args if needed
|
347 |
+
if custom_training_args.learning_rate_end is not None:
|
348 |
+
logger.info(f'found learning_rate_end={custom_training_args.learning_rate_end} in passed arguments. '
|
349 |
+
'will pass it to training_args')
|
350 |
+
training_args.learning_rate_end = custom_training_args.learning_rate_end
|
351 |
+
else:
|
352 |
+
logger.info(f'learning_rate_end is None. will not pass it to training_args')
|
353 |
+
|
354 |
+
# log arguments
|
355 |
+
logger.info(f"Training/evaluation parameters {training_args}")
|
356 |
+
logger.info(f"Data parameters: {data_args}")
|
357 |
+
logger.info(f"Model parameters: {model_args}")
|
358 |
+
|
359 |
+
# Set the verbosity to info of the Transformers logger (on main process only):
|
360 |
+
if is_main_process(training_args.local_rank):
|
361 |
+
transformers.utils.logging.set_verbosity_info()
|
362 |
+
|
363 |
+
|
364 |
+
# 3. Detecting last checkpoint and eventually continue from last checkpoint
|
365 |
+
last_checkpoint = None
|
366 |
+
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
|
367 |
+
logger.info(f'output_dir already exists. will try to load last checkpoint.')
|
368 |
+
|
369 |
+
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
370 |
+
if last_checkpoint is not None:
|
371 |
+
if training_args.resume_from_checkpoint is None:
|
372 |
+
logger.info(
|
373 |
+
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
|
374 |
+
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
|
375 |
+
)
|
376 |
+
else:
|
377 |
+
logger.info(f'Last checkpoint found at: {last_checkpoint}. Will ignore it and resume training '
|
378 |
+
f'from passed resume_from_checkpoint param: {training_args.resume_from_checkpoint}')
|
379 |
+
assert os.path.isdir(training_args.resume_from_checkpoint)
|
380 |
+
else:
|
381 |
+
logger.info('last_checkpoint is None. will try to read from training_args.resume_from_checkpoint')
|
382 |
+
|
383 |
+
if training_args.resume_from_checkpoint is not None and os.path.isdir(training_args.resume_from_checkpoint):
|
384 |
+
logger.info(f'Will resume training from passed resume_from_checkpoint param: '
|
385 |
+
f'{training_args.resume_from_checkpoint}')
|
386 |
+
else:
|
387 |
+
logger.info('last_checkpoint is None. resume_from_checkpoint is either None or not existing dir. '
|
388 |
+
'will try to read from the model saved in the root of output_dir.')
|
389 |
+
|
390 |
+
dir_content = os.listdir(training_args.output_dir)
|
391 |
+
if len(dir_content) == 0:
|
392 |
+
logger.info('output_dir is empty. will start training from scratch.')
|
393 |
+
else:
|
394 |
+
model_fn = 'pytorch_model.bin'
|
395 |
+
if model_fn in dir_content:
|
396 |
+
logger.info(f'found {model_fn} inside output_dir. '
|
397 |
+
f'will continue training treating output_dir as a last checkpoint.')
|
398 |
+
last_checkpoint = training_args.output_dir
|
399 |
+
else:
|
400 |
+
allowed_dirs = ['.git', '.gitattributes', 'src']
|
401 |
+
unexpected_content = set(dir_content).difference(allowed_dirs)
|
402 |
+
unexpected_content = [x for x in unexpected_content
|
403 |
+
if not x.endswith('.log') and os.path.isfile(x)]
|
404 |
+
if len(unexpected_content) > 0:
|
405 |
+
raise ValueError(
|
406 |
+
f'Could not find last_checkpoint, resume_from_checkpoint is either None '
|
407 |
+
'or not existing dir, output_dir is non-empty but does not contain a model.'
|
408 |
+
'Use --overwrite_output_dir to overcome. '
|
409 |
+
f'unexpected_content: {unexpected_content}'
|
410 |
+
)
|
411 |
+
else:
|
412 |
+
logger.info(f'dir is not empty, but contains only: {dir_content}. '
|
413 |
+
'it is OK - will start training')
|
414 |
+
|
415 |
+
|
416 |
+
# Set seed before initializing model.
|
417 |
+
set_seed(training_args.seed)
|
418 |
+
|
419 |
+
|
420 |
+
# 4. Load dataset
|
421 |
+
|
422 |
+
# TODO: replace dataset dicts with single key to IterableDataset and to Dataset.
|
423 |
+
# don't know how to do it know - using dict simply because they work.
|
424 |
+
raw_train = IterableDatasetDict() if data_args.streaming_train else DatasetDict()
|
425 |
+
raw_eval = IterableDatasetDict() if data_args.streaming_eval else DatasetDict()
|
426 |
+
|
427 |
+
if training_args.do_train:
|
428 |
+
raw_train['train'] = load_maybe_streaming_dataset(
|
429 |
+
data_args.dataset_name,
|
430 |
+
data_args.dataset_config_name,
|
431 |
+
split=data_args.train_split_name,
|
432 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
433 |
+
streaming=data_args.streaming_train,
|
434 |
+
)
|
435 |
+
|
436 |
+
if training_args.do_eval:
|
437 |
+
raw_eval['eval'] = load_maybe_streaming_dataset(
|
438 |
+
data_args.dataset_name,
|
439 |
+
data_args.dataset_config_name,
|
440 |
+
split=data_args.eval_split_name,
|
441 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
442 |
+
streaming=data_args.streaming_eval,
|
443 |
+
)
|
444 |
+
|
445 |
+
raw_datasets_features = list(next(iter(raw_train.values())).features.keys())
|
446 |
+
|
447 |
+
if data_args.audio_column_name not in raw_datasets_features:
|
448 |
+
raise ValueError(
|
449 |
+
f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
|
450 |
+
"Make sure to set `--audio_column_name` to the correct audio column - one of "
|
451 |
+
f"{', '.join(raw_datasets_features)}."
|
452 |
+
)
|
453 |
+
|
454 |
+
if data_args.text_column_name not in raw_datasets_features:
|
455 |
+
raise ValueError(
|
456 |
+
f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
|
457 |
+
"Make sure to set `--text_column_name` to the correct text column - one of "
|
458 |
+
f"{', '.join(raw_datasets_features)}."
|
459 |
+
)
|
460 |
+
|
461 |
+
|
462 |
+
# 5. Load pretrained model, tokenizer, and feature extractor
|
463 |
+
#
|
464 |
+
# Distributed training:
|
465 |
+
# The .from_pretrained methods guarantee that only one local process can concurrently
|
466 |
+
config = AutoConfig.from_pretrained(
|
467 |
+
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
468 |
+
cache_dir=model_args.cache_dir,
|
469 |
+
revision=model_args.model_revision,
|
470 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
471 |
+
)
|
472 |
+
|
473 |
+
config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens})
|
474 |
+
|
475 |
+
if training_args.gradient_checkpointing:
|
476 |
+
config.update({"use_cache": False})
|
477 |
+
|
478 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
479 |
+
model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
|
480 |
+
cache_dir=model_args.cache_dir,
|
481 |
+
revision=model_args.model_revision,
|
482 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
483 |
+
)
|
484 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
485 |
+
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
486 |
+
cache_dir=model_args.cache_dir,
|
487 |
+
use_fast=model_args.use_fast_tokenizer,
|
488 |
+
revision=model_args.model_revision,
|
489 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
490 |
+
)
|
491 |
+
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
492 |
+
model_args.model_name_or_path,
|
493 |
+
config=config,
|
494 |
+
cache_dir=model_args.cache_dir,
|
495 |
+
revision=model_args.model_revision,
|
496 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
497 |
+
)
|
498 |
+
|
499 |
+
if model.config.decoder_start_token_id is None:
|
500 |
+
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
|
501 |
+
|
502 |
+
if model_args.freeze_feature_encoder:
|
503 |
+
model.freeze_feature_encoder()
|
504 |
+
|
505 |
+
if model_args.freeze_encoder:
|
506 |
+
model.freeze_encoder()
|
507 |
+
|
508 |
+
if data_args.language is not None:
|
509 |
+
# We only need to set the task id when the language is specified (i.e. in a multilingual setting)
|
510 |
+
tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
|
511 |
+
|
512 |
+
|
513 |
+
# 6. Explicitly resample speech dataset
|
514 |
+
raw_train = raw_train.cast_column(
|
515 |
+
data_args.audio_column_name, datasets.features.Audio(
|
516 |
+
sampling_rate=feature_extractor.sampling_rate,
|
517 |
+
mono=True
|
518 |
+
)
|
519 |
+
)
|
520 |
+
raw_eval = raw_eval.cast_column(
|
521 |
+
data_args.audio_column_name, datasets.features.Audio(
|
522 |
+
sampling_rate=feature_extractor.sampling_rate,
|
523 |
+
mono=True
|
524 |
+
)
|
525 |
+
)
|
526 |
+
|
527 |
+
|
528 |
+
# 7. Preprocessing the datasets.
|
529 |
+
# We need to read the audio files as arrays and tokenize the targets.
|
530 |
+
max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
|
531 |
+
min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
|
532 |
+
max_labels_length = 448 # model.config.max_length
|
533 |
+
|
534 |
+
audio_column_name = data_args.audio_column_name
|
535 |
+
text_column_name = data_args.text_column_name
|
536 |
+
model_input_name = feature_extractor.model_input_names[0]
|
537 |
+
do_lower_case = data_args.do_lower_case
|
538 |
+
do_remove_punctuation = data_args.do_remove_punctuation
|
539 |
+
normalizer = BelarusianTextNormalizer() # custom normalizer based on 'official' text normalizer from OpenAI
|
540 |
+
|
541 |
+
if data_args.max_train_samples is not None:
|
542 |
+
raw_train['train'] = (
|
543 |
+
raw_train['train'].take(data_args.max_train_samples)
|
544 |
+
if data_args.streaming_train
|
545 |
+
else raw_train['train'].select(range(data_args.max_train_samples))
|
546 |
+
)
|
547 |
+
|
548 |
+
if data_args.max_eval_samples is not None:
|
549 |
+
raw_eval['eval'] = (
|
550 |
+
raw_eval['eval'].take(data_args.max_eval_samples)
|
551 |
+
if data_args.streaming_eval
|
552 |
+
else raw_eval['eval'].select(range(data_args.max_eval_samples))
|
553 |
+
)
|
554 |
+
|
555 |
+
def prepare_dataset(sample, labels_max_len: int = None):
|
556 |
+
# process audio
|
557 |
+
audio = sample[audio_column_name]
|
558 |
+
inputs = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"])
|
559 |
+
# process audio length
|
560 |
+
sample[model_input_name] = inputs.get(model_input_name)[0]
|
561 |
+
sample["input_length"] = len(audio["array"])
|
562 |
+
|
563 |
+
# process targets
|
564 |
+
input_str = sample[text_column_name].lower() if do_lower_case else sample[text_column_name]
|
565 |
+
if do_remove_punctuation:
|
566 |
+
input_str = normalizer(input_str).strip()
|
567 |
+
sample['labels'] = tokenizer(input_str).input_ids
|
568 |
+
sample['labels_length'] = len(sample['labels']) # include special characters
|
569 |
+
|
570 |
+
sample['labels_truncated'] = 0
|
571 |
+
# need to truncate validation and test labels that are longer that model.config.max_length.
|
572 |
+
# can't drop such examples because this will affect validation and test scores.
|
573 |
+
# thus need to truncate.
|
574 |
+
if labels_max_len is not None:
|
575 |
+
if len(sample['labels']) > labels_max_len:
|
576 |
+
sample['labels'] = sample['labels'][:labels_max_len]
|
577 |
+
sample['labels_truncated'] = 1
|
578 |
+
|
579 |
+
return sample
|
580 |
+
|
581 |
+
with training_args.main_process_first(desc="dataset map pre-processing"):
|
582 |
+
logger.info(f'vectorizing dataset')
|
583 |
+
|
584 |
+
# TODO: replace dataset dicts with single key to IterableDataset and to Dataset.
|
585 |
+
# don't know how to do it know - using dict simply because they work.
|
586 |
+
vectorized_train = IterableDatasetDict() if data_args.streaming_train else DatasetDict()
|
587 |
+
vectorized_eval = IterableDatasetDict() if data_args.streaming_eval else DatasetDict()
|
588 |
+
|
589 |
+
num_proc = None
|
590 |
+
if data_args.streaming_train or data_args.streaming_eval:
|
591 |
+
logger.info(f'will preprocess data using {num_proc} processes.')
|
592 |
+
|
593 |
+
if data_args.streaming_train:
|
594 |
+
vectorized_train['train'] = raw_train['train'].map(
|
595 |
+
prepare_dataset, remove_columns=raw_datasets_features,
|
596 |
+
fn_kwargs=dict(labels_max_len=None),
|
597 |
+
).with_format("torch")
|
598 |
+
else:
|
599 |
+
vectorized_train['train'] = raw_train['train'].map(
|
600 |
+
prepare_dataset, remove_columns=raw_datasets_features,
|
601 |
+
num_proc=num_proc,
|
602 |
+
fn_kwargs=dict(labels_max_len=None),
|
603 |
+
).with_format("torch")
|
604 |
+
|
605 |
+
if data_args.streaming_eval:
|
606 |
+
vectorized_eval['eval'] = raw_eval['eval'].map(
|
607 |
+
prepare_dataset, remove_columns=raw_datasets_features,
|
608 |
+
fn_kwargs=dict(labels_max_len=max_labels_length),
|
609 |
+
).with_format("torch")
|
610 |
+
else:
|
611 |
+
vectorized_eval['eval'] = raw_eval['eval'].map(
|
612 |
+
prepare_dataset, remove_columns=raw_datasets_features,
|
613 |
+
num_proc=num_proc,
|
614 |
+
fn_kwargs=dict(labels_max_len=max_labels_length),
|
615 |
+
).with_format("torch")
|
616 |
+
|
617 |
+
if training_args.do_train and data_args.streaming_train:
|
618 |
+
# manually shuffle if streaming (done by the trainer for non-streaming)
|
619 |
+
vectorized_train['train'] = vectorized_train['train'].shuffle(
|
620 |
+
buffer_size=data_args.shuffle_buffer_size,
|
621 |
+
seed=training_args.seed,
|
622 |
+
)
|
623 |
+
|
624 |
+
# Filter training data that is shorter than min_input_length or longer than max_input_length.
|
625 |
+
# Drop items with labels longer that max model length.
|
626 |
+
# Drop such items from the train set only. Should keep them in eval set not to affect eval metrics.
|
627 |
+
def is_audio_in_length_range(length):
|
628 |
+
return min_input_length < length < max_input_length
|
629 |
+
|
630 |
+
def are_labels_in_length_range(labels_length):
|
631 |
+
return labels_length <= max_labels_length
|
632 |
+
|
633 |
+
if training_args.do_train:
|
634 |
+
# Filter items from train set only.
|
635 |
+
# Should keep them in eval set not to affect eval metrics.
|
636 |
+
vectorized_train['train'] = vectorized_train['train'].filter(
|
637 |
+
is_audio_in_length_range,
|
638 |
+
input_columns=["input_length"],
|
639 |
+
)
|
640 |
+
vectorized_train['train'] = vectorized_train['train'].filter(
|
641 |
+
are_labels_in_length_range,
|
642 |
+
input_columns=["labels_length"],
|
643 |
+
)
|
644 |
+
|
645 |
+
|
646 |
+
# 8. Load Metric
|
647 |
+
metric = evaluate.load("wer")
|
648 |
+
do_normalize_eval = data_args.do_normalize_eval
|
649 |
+
|
650 |
+
def compute_metrics(pred):
|
651 |
+
pred_ids = pred.predictions
|
652 |
+
|
653 |
+
pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id
|
654 |
+
|
655 |
+
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
|
656 |
+
# we do not want to group tokens when computing the metrics
|
657 |
+
label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
|
658 |
+
|
659 |
+
if do_normalize_eval:
|
660 |
+
pred_str = [normalizer(pred) for pred in pred_str]
|
661 |
+
label_str = [normalizer(label) for label in label_str]
|
662 |
+
# filtering step to only evaluate the samples that correspond to non-zero references:
|
663 |
+
pred_str = [pred_str[i] for i in range(len(pred_str)) if len(label_str[i]) > 0]
|
664 |
+
label_str = [label_str[i] for i in range(len(label_str)) if len(label_str[i]) > 0]
|
665 |
+
|
666 |
+
wer = 100 * metric.compute(predictions=pred_str, references=label_str)
|
667 |
+
|
668 |
+
return {"wer": wer}
|
669 |
+
|
670 |
+
|
671 |
+
# 9. Create a single speech processor
|
672 |
+
if is_main_process(training_args.local_rank):
|
673 |
+
# save feature extractor, tokenizer and config
|
674 |
+
feature_extractor.save_pretrained(training_args.output_dir)
|
675 |
+
tokenizer.save_pretrained(training_args.output_dir)
|
676 |
+
config.save_pretrained(training_args.output_dir)
|
677 |
+
|
678 |
+
processor = AutoProcessor.from_pretrained(training_args.output_dir)
|
679 |
+
|
680 |
+
|
681 |
+
# 10. Define data collator
|
682 |
+
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
|
683 |
+
processor=processor,
|
684 |
+
decoder_start_token_id=model.config.decoder_start_token_id,
|
685 |
+
)
|
686 |
+
|
687 |
+
|
688 |
+
# 11. Configure Trainer
|
689 |
+
# Trainer callback to reinitialise and reshuffle the streamable datasets at the beginning of each epoch
|
690 |
+
# Only required for streaming: Trainer automatically shuffles non-streaming datasets
|
691 |
+
class ShuffleCallback(TrainerCallback):
|
692 |
+
def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
|
693 |
+
if isinstance(train_dataloader.dataset, IterableDatasetShard):
|
694 |
+
pass # set_epoch() is handled by the Trainer
|
695 |
+
elif isinstance(train_dataloader.dataset, IterableDataset):
|
696 |
+
logger.info(f'ShuffleCallback. shuffling train dataset. '
|
697 |
+
f'seed: {training_args.seed}. dataset epoch: {train_dataloader.dataset._epoch}')
|
698 |
+
train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)
|
699 |
+
|
700 |
+
# Initialize Trainer
|
701 |
+
trainer = Seq2SeqTrainerCustomLinearScheduler(
|
702 |
+
model=model,
|
703 |
+
args=training_args,
|
704 |
+
train_dataset=vectorized_train['train'] if training_args.do_train else None,
|
705 |
+
eval_dataset=vectorized_eval['eval'] if training_args.do_eval else None,
|
706 |
+
tokenizer=processor,
|
707 |
+
data_collator=data_collator,
|
708 |
+
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
|
709 |
+
callbacks=[ShuffleCallback()] if data_args.streaming_train else None,
|
710 |
+
)
|
711 |
+
|
712 |
+
|
713 |
+
# 12. Training
|
714 |
+
if training_args.do_train:
|
715 |
+
checkpoint = None
|
716 |
+
if training_args.resume_from_checkpoint is not None:
|
717 |
+
checkpoint = training_args.resume_from_checkpoint
|
718 |
+
elif last_checkpoint is not None:
|
719 |
+
checkpoint = last_checkpoint
|
720 |
+
logger.info(f'will launch training and pass resume_from_checkpoint={checkpoint}')
|
721 |
+
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
722 |
+
trainer.save_model() # Saves the feature extractor too for easy upload
|
723 |
+
|
724 |
+
metrics = train_result.metrics
|
725 |
+
if data_args.max_train_samples:
|
726 |
+
metrics["train_samples"] = data_args.max_train_samples
|
727 |
+
trainer.log_metrics("train", metrics)
|
728 |
+
trainer.save_metrics("train", metrics)
|
729 |
+
trainer.save_state()
|
730 |
+
|
731 |
+
|
732 |
+
# 13. Evaluation
|
733 |
+
results = {}
|
734 |
+
if training_args.do_eval:
|
735 |
+
logger.info("*** Evaluate ***")
|
736 |
+
metrics = trainer.evaluate(
|
737 |
+
metric_key_prefix="eval",
|
738 |
+
max_length=training_args.generation_max_length,
|
739 |
+
num_beams=training_args.generation_num_beams,
|
740 |
+
)
|
741 |
+
if data_args.max_eval_samples:
|
742 |
+
metrics["eval_samples"] = data_args.max_eval_samples
|
743 |
+
|
744 |
+
trainer.log_metrics("eval", metrics)
|
745 |
+
trainer.save_metrics("eval", metrics)
|
746 |
+
|
747 |
+
|
748 |
+
# 14. Write Training Stats
|
749 |
+
kwargs = {
|
750 |
+
"finetuned_from": model_args.model_name_or_path,
|
751 |
+
"tasks": "automatic-speech-recognition",
|
752 |
+
"tags": "whisper-event",
|
753 |
+
}
|
754 |
+
if data_args.dataset_name is not None:
|
755 |
+
kwargs["dataset_tags"] = data_args.dataset_name
|
756 |
+
if data_args.dataset_config_name is not None:
|
757 |
+
kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
|
758 |
+
else:
|
759 |
+
kwargs["dataset"] = data_args.dataset_name
|
760 |
+
if "common_voice" in data_args.dataset_name:
|
761 |
+
kwargs["language"] = data_args.dataset_config_name[:2]
|
762 |
+
if model_args.model_index_name is not None:
|
763 |
+
kwargs["model_name"] = model_args.model_index_name
|
764 |
+
|
765 |
+
if training_args.push_to_hub:
|
766 |
+
trainer.push_to_hub(**kwargs)
|
767 |
+
else:
|
768 |
+
trainer.create_model_card(**kwargs)
|
769 |
+
|
770 |
+
return results
|
771 |
+
|
772 |
+
|
773 |
+
if __name__ == "__main__":
|
774 |
+
main()
|
run_3/src/setup_env.sh
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
sudo add-apt-repository -y ppa:jonathonf/ffmpeg-4
|
2 |
+
sudo apt update
|
3 |
+
sudo apt install -y ffmpeg
|
4 |
+
|
5 |
+
sudo apt-get install git-lfs
|
6 |
+
|
7 |
+
sudo apt-get install tmux
|
8 |
+
|
9 |
+
cd ~
|
10 |
+
echo "executing env setup from $(pwd)"
|
11 |
+
|
12 |
+
python3 -m venv ~/python_venvs/hf_env
|
13 |
+
source ~/python_venvs/hf_env/bin/activate
|
14 |
+
echo "source ~/python_venvs/hf_env/bin/activate" >> ~/.bashrc
|
15 |
+
|
16 |
+
git clone https://github.com/yks72p/whisper-finetuning-be
|
17 |
+
pip install -r ~/whisper-finetuning-be/requirements.txt
|
18 |
+
|
19 |
+
git config --global credential.helper store
|
20 |
+
huggingface-cli login
|
21 |
+
|
22 |
+
echo "env setup"
|
23 |
+
echo "! PLEASE LOGIN INTO GIT TO BE ABLE TO PUSH TO HF HUB !"
|
24 |
+
echo "> git config --globase user.name <user_name>"
|
25 |
+
echo "> git config --globase user.email <user_email>"
|
run_3/tensorboard_logs/Dec21_18-04-41_129-146-110-116/1671647715.531084/events.out.tfevents.1671647715.129-146-110-116.757634.1
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d68bdb4e04b2431e922781a2bb93f62871d2bbaada836f73585a306a8f294c5f
|
3 |
+
size 5865
|
run_3/tensorboard_logs/Dec21_18-04-41_129-146-110-116/events.out.tfevents.1671647715.129-146-110-116.757634.0
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:29b51173f15425dc32fa8c0c89cca94183f227ba4eafe1dc989708a40685508f
|
3 |
+
size 25519
|
run_3/tensorboard_logs/Dec21_18-04-41_129-146-110-116/events.out.tfevents.1671730045.129-146-110-116.757634.2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ad93c9b86db2c6e598f074d486ea716bafcdd685f5da86c3d7b6fa25ba9253d2
|
3 |
+
size 358
|
run_3/trainer_state.json
ADDED
@@ -0,0 +1,805 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"best_metric": 6.719567920299668,
|
3 |
+
"best_model_checkpoint": "./checkpoint-6000",
|
4 |
+
"epoch": 1.0958333333333334,
|
5 |
+
"global_step": 6000,
|
6 |
+
"is_hyper_param_search": false,
|
7 |
+
"is_local_process_zero": true,
|
8 |
+
"is_world_process_zero": true,
|
9 |
+
"log_history": [
|
10 |
+
{
|
11 |
+
"epoch": 0.0,
|
12 |
+
"learning_rate": 2.999666666666667e-05,
|
13 |
+
"loss": 0.0175,
|
14 |
+
"step": 1
|
15 |
+
},
|
16 |
+
{
|
17 |
+
"epoch": 0.01,
|
18 |
+
"learning_rate": 2.9833333333333335e-05,
|
19 |
+
"loss": 0.0176,
|
20 |
+
"step": 50
|
21 |
+
},
|
22 |
+
{
|
23 |
+
"epoch": 0.02,
|
24 |
+
"learning_rate": 2.966666666666667e-05,
|
25 |
+
"loss": 0.0195,
|
26 |
+
"step": 100
|
27 |
+
},
|
28 |
+
{
|
29 |
+
"epoch": 0.03,
|
30 |
+
"learning_rate": 2.95e-05,
|
31 |
+
"loss": 0.0191,
|
32 |
+
"step": 150
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"epoch": 0.03,
|
36 |
+
"learning_rate": 2.9333333333333333e-05,
|
37 |
+
"loss": 0.0191,
|
38 |
+
"step": 200
|
39 |
+
},
|
40 |
+
{
|
41 |
+
"epoch": 0.04,
|
42 |
+
"learning_rate": 2.9166666666666666e-05,
|
43 |
+
"loss": 0.0204,
|
44 |
+
"step": 250
|
45 |
+
},
|
46 |
+
{
|
47 |
+
"epoch": 0.05,
|
48 |
+
"learning_rate": 2.9e-05,
|
49 |
+
"loss": 0.0191,
|
50 |
+
"step": 300
|
51 |
+
},
|
52 |
+
{
|
53 |
+
"epoch": 0.06,
|
54 |
+
"learning_rate": 2.8833333333333334e-05,
|
55 |
+
"loss": 0.0166,
|
56 |
+
"step": 350
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"epoch": 0.07,
|
60 |
+
"learning_rate": 2.8666666666666668e-05,
|
61 |
+
"loss": 0.0171,
|
62 |
+
"step": 400
|
63 |
+
},
|
64 |
+
{
|
65 |
+
"epoch": 0.07,
|
66 |
+
"learning_rate": 2.8499999999999998e-05,
|
67 |
+
"loss": 0.0178,
|
68 |
+
"step": 450
|
69 |
+
},
|
70 |
+
{
|
71 |
+
"epoch": 0.08,
|
72 |
+
"learning_rate": 2.8333333333333332e-05,
|
73 |
+
"loss": 0.0246,
|
74 |
+
"step": 500
|
75 |
+
},
|
76 |
+
{
|
77 |
+
"epoch": 0.09,
|
78 |
+
"learning_rate": 2.8166666666666666e-05,
|
79 |
+
"loss": 0.0171,
|
80 |
+
"step": 550
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"epoch": 0.1,
|
84 |
+
"learning_rate": 2.8e-05,
|
85 |
+
"loss": 0.0187,
|
86 |
+
"step": 600
|
87 |
+
},
|
88 |
+
{
|
89 |
+
"epoch": 0.11,
|
90 |
+
"learning_rate": 2.7833333333333337e-05,
|
91 |
+
"loss": 0.0267,
|
92 |
+
"step": 650
|
93 |
+
},
|
94 |
+
{
|
95 |
+
"epoch": 0.12,
|
96 |
+
"learning_rate": 2.766666666666667e-05,
|
97 |
+
"loss": 0.0336,
|
98 |
+
"step": 700
|
99 |
+
},
|
100 |
+
{
|
101 |
+
"epoch": 0.12,
|
102 |
+
"learning_rate": 2.75e-05,
|
103 |
+
"loss": 0.0325,
|
104 |
+
"step": 750
|
105 |
+
},
|
106 |
+
{
|
107 |
+
"epoch": 0.13,
|
108 |
+
"learning_rate": 2.7333333333333335e-05,
|
109 |
+
"loss": 0.0315,
|
110 |
+
"step": 800
|
111 |
+
},
|
112 |
+
{
|
113 |
+
"epoch": 0.14,
|
114 |
+
"learning_rate": 2.716666666666667e-05,
|
115 |
+
"loss": 0.0318,
|
116 |
+
"step": 850
|
117 |
+
},
|
118 |
+
{
|
119 |
+
"epoch": 0.15,
|
120 |
+
"learning_rate": 2.7000000000000002e-05,
|
121 |
+
"loss": 0.0311,
|
122 |
+
"step": 900
|
123 |
+
},
|
124 |
+
{
|
125 |
+
"epoch": 0.16,
|
126 |
+
"learning_rate": 2.6833333333333336e-05,
|
127 |
+
"loss": 0.0304,
|
128 |
+
"step": 950
|
129 |
+
},
|
130 |
+
{
|
131 |
+
"epoch": 0.17,
|
132 |
+
"learning_rate": 2.6666666666666667e-05,
|
133 |
+
"loss": 0.0275,
|
134 |
+
"step": 1000
|
135 |
+
},
|
136 |
+
{
|
137 |
+
"epoch": 0.17,
|
138 |
+
"eval_loss": 0.08357452601194382,
|
139 |
+
"eval_runtime": 5071.0402,
|
140 |
+
"eval_samples_per_second": 3.13,
|
141 |
+
"eval_steps_per_second": 0.098,
|
142 |
+
"eval_wer": 7.786304277240581,
|
143 |
+
"step": 1000
|
144 |
+
},
|
145 |
+
{
|
146 |
+
"epoch": 0.17,
|
147 |
+
"learning_rate": 2.65e-05,
|
148 |
+
"loss": 0.03,
|
149 |
+
"step": 1050
|
150 |
+
},
|
151 |
+
{
|
152 |
+
"epoch": 0.18,
|
153 |
+
"learning_rate": 2.6333333333333334e-05,
|
154 |
+
"loss": 0.0275,
|
155 |
+
"step": 1100
|
156 |
+
},
|
157 |
+
{
|
158 |
+
"epoch": 0.19,
|
159 |
+
"learning_rate": 2.6166666666666668e-05,
|
160 |
+
"loss": 0.0266,
|
161 |
+
"step": 1150
|
162 |
+
},
|
163 |
+
{
|
164 |
+
"epoch": 0.2,
|
165 |
+
"learning_rate": 2.6000000000000002e-05,
|
166 |
+
"loss": 0.027,
|
167 |
+
"step": 1200
|
168 |
+
},
|
169 |
+
{
|
170 |
+
"epoch": 0.21,
|
171 |
+
"learning_rate": 2.5833333333333336e-05,
|
172 |
+
"loss": 0.0274,
|
173 |
+
"step": 1250
|
174 |
+
},
|
175 |
+
{
|
176 |
+
"epoch": 0.22,
|
177 |
+
"learning_rate": 2.5666666666666666e-05,
|
178 |
+
"loss": 0.0264,
|
179 |
+
"step": 1300
|
180 |
+
},
|
181 |
+
{
|
182 |
+
"epoch": 0.23,
|
183 |
+
"learning_rate": 2.55e-05,
|
184 |
+
"loss": 0.0248,
|
185 |
+
"step": 1350
|
186 |
+
},
|
187 |
+
{
|
188 |
+
"epoch": 0.23,
|
189 |
+
"learning_rate": 2.5333333333333334e-05,
|
190 |
+
"loss": 0.0261,
|
191 |
+
"step": 1400
|
192 |
+
},
|
193 |
+
{
|
194 |
+
"epoch": 0.24,
|
195 |
+
"learning_rate": 2.5166666666666667e-05,
|
196 |
+
"loss": 0.0225,
|
197 |
+
"step": 1450
|
198 |
+
},
|
199 |
+
{
|
200 |
+
"epoch": 0.25,
|
201 |
+
"learning_rate": 2.5e-05,
|
202 |
+
"loss": 0.0255,
|
203 |
+
"step": 1500
|
204 |
+
},
|
205 |
+
{
|
206 |
+
"epoch": 0.26,
|
207 |
+
"learning_rate": 2.483333333333333e-05,
|
208 |
+
"loss": 0.025,
|
209 |
+
"step": 1550
|
210 |
+
},
|
211 |
+
{
|
212 |
+
"epoch": 0.27,
|
213 |
+
"learning_rate": 2.4666666666666665e-05,
|
214 |
+
"loss": 0.0211,
|
215 |
+
"step": 1600
|
216 |
+
},
|
217 |
+
{
|
218 |
+
"epoch": 0.28,
|
219 |
+
"learning_rate": 2.45e-05,
|
220 |
+
"loss": 0.0242,
|
221 |
+
"step": 1650
|
222 |
+
},
|
223 |
+
{
|
224 |
+
"epoch": 0.28,
|
225 |
+
"learning_rate": 2.4333333333333333e-05,
|
226 |
+
"loss": 0.023,
|
227 |
+
"step": 1700
|
228 |
+
},
|
229 |
+
{
|
230 |
+
"epoch": 0.29,
|
231 |
+
"learning_rate": 2.4166666666666667e-05,
|
232 |
+
"loss": 0.0177,
|
233 |
+
"step": 1750
|
234 |
+
},
|
235 |
+
{
|
236 |
+
"epoch": 0.3,
|
237 |
+
"learning_rate": 2.4e-05,
|
238 |
+
"loss": 0.0195,
|
239 |
+
"step": 1800
|
240 |
+
},
|
241 |
+
{
|
242 |
+
"epoch": 0.31,
|
243 |
+
"learning_rate": 2.383333333333333e-05,
|
244 |
+
"loss": 0.0195,
|
245 |
+
"step": 1850
|
246 |
+
},
|
247 |
+
{
|
248 |
+
"epoch": 0.32,
|
249 |
+
"learning_rate": 2.3666666666666665e-05,
|
250 |
+
"loss": 0.0193,
|
251 |
+
"step": 1900
|
252 |
+
},
|
253 |
+
{
|
254 |
+
"epoch": 0.33,
|
255 |
+
"learning_rate": 2.3500000000000002e-05,
|
256 |
+
"loss": 0.019,
|
257 |
+
"step": 1950
|
258 |
+
},
|
259 |
+
{
|
260 |
+
"epoch": 0.33,
|
261 |
+
"learning_rate": 2.3333333333333336e-05,
|
262 |
+
"loss": 0.016,
|
263 |
+
"step": 2000
|
264 |
+
},
|
265 |
+
{
|
266 |
+
"epoch": 0.33,
|
267 |
+
"eval_loss": 0.08282392472028732,
|
268 |
+
"eval_runtime": 5050.4193,
|
269 |
+
"eval_samples_per_second": 3.143,
|
270 |
+
"eval_steps_per_second": 0.098,
|
271 |
+
"eval_wer": 7.288969138295598,
|
272 |
+
"step": 2000
|
273 |
+
},
|
274 |
+
{
|
275 |
+
"epoch": 0.34,
|
276 |
+
"learning_rate": 2.316666666666667e-05,
|
277 |
+
"loss": 0.0194,
|
278 |
+
"step": 2050
|
279 |
+
},
|
280 |
+
{
|
281 |
+
"epoch": 0.35,
|
282 |
+
"learning_rate": 2.3000000000000003e-05,
|
283 |
+
"loss": 0.0163,
|
284 |
+
"step": 2100
|
285 |
+
},
|
286 |
+
{
|
287 |
+
"epoch": 0.36,
|
288 |
+
"learning_rate": 2.2833333333333334e-05,
|
289 |
+
"loss": 0.0151,
|
290 |
+
"step": 2150
|
291 |
+
},
|
292 |
+
{
|
293 |
+
"epoch": 0.37,
|
294 |
+
"learning_rate": 2.2666666666666668e-05,
|
295 |
+
"loss": 0.0175,
|
296 |
+
"step": 2200
|
297 |
+
},
|
298 |
+
{
|
299 |
+
"epoch": 0.38,
|
300 |
+
"learning_rate": 2.25e-05,
|
301 |
+
"loss": 0.0174,
|
302 |
+
"step": 2250
|
303 |
+
},
|
304 |
+
{
|
305 |
+
"epoch": 0.38,
|
306 |
+
"learning_rate": 2.2333333333333335e-05,
|
307 |
+
"loss": 0.015,
|
308 |
+
"step": 2300
|
309 |
+
},
|
310 |
+
{
|
311 |
+
"epoch": 0.39,
|
312 |
+
"learning_rate": 2.216666666666667e-05,
|
313 |
+
"loss": 0.0178,
|
314 |
+
"step": 2350
|
315 |
+
},
|
316 |
+
{
|
317 |
+
"epoch": 0.4,
|
318 |
+
"learning_rate": 2.2e-05,
|
319 |
+
"loss": 0.0201,
|
320 |
+
"step": 2400
|
321 |
+
},
|
322 |
+
{
|
323 |
+
"epoch": 0.41,
|
324 |
+
"learning_rate": 2.1833333333333333e-05,
|
325 |
+
"loss": 0.0173,
|
326 |
+
"step": 2450
|
327 |
+
},
|
328 |
+
{
|
329 |
+
"epoch": 0.42,
|
330 |
+
"learning_rate": 2.1666666666666667e-05,
|
331 |
+
"loss": 0.0171,
|
332 |
+
"step": 2500
|
333 |
+
},
|
334 |
+
{
|
335 |
+
"epoch": 0.42,
|
336 |
+
"learning_rate": 2.15e-05,
|
337 |
+
"loss": 0.024,
|
338 |
+
"step": 2550
|
339 |
+
},
|
340 |
+
{
|
341 |
+
"epoch": 0.43,
|
342 |
+
"learning_rate": 2.1333333333333335e-05,
|
343 |
+
"loss": 0.0165,
|
344 |
+
"step": 2600
|
345 |
+
},
|
346 |
+
{
|
347 |
+
"epoch": 0.44,
|
348 |
+
"learning_rate": 2.116666666666667e-05,
|
349 |
+
"loss": 0.0174,
|
350 |
+
"step": 2650
|
351 |
+
},
|
352 |
+
{
|
353 |
+
"epoch": 0.45,
|
354 |
+
"learning_rate": 2.1e-05,
|
355 |
+
"loss": 0.0175,
|
356 |
+
"step": 2700
|
357 |
+
},
|
358 |
+
{
|
359 |
+
"epoch": 0.46,
|
360 |
+
"learning_rate": 2.0833333333333333e-05,
|
361 |
+
"loss": 0.0234,
|
362 |
+
"step": 2750
|
363 |
+
},
|
364 |
+
{
|
365 |
+
"epoch": 0.47,
|
366 |
+
"learning_rate": 2.0666666666666666e-05,
|
367 |
+
"loss": 0.0201,
|
368 |
+
"step": 2800
|
369 |
+
},
|
370 |
+
{
|
371 |
+
"epoch": 0.47,
|
372 |
+
"learning_rate": 2.05e-05,
|
373 |
+
"loss": 0.0183,
|
374 |
+
"step": 2850
|
375 |
+
},
|
376 |
+
{
|
377 |
+
"epoch": 0.48,
|
378 |
+
"learning_rate": 2.0333333333333334e-05,
|
379 |
+
"loss": 0.0194,
|
380 |
+
"step": 2900
|
381 |
+
},
|
382 |
+
{
|
383 |
+
"epoch": 0.49,
|
384 |
+
"learning_rate": 2.0166666666666668e-05,
|
385 |
+
"loss": 0.0194,
|
386 |
+
"step": 2950
|
387 |
+
},
|
388 |
+
{
|
389 |
+
"epoch": 0.5,
|
390 |
+
"learning_rate": 1.9999999999999998e-05,
|
391 |
+
"loss": 0.0164,
|
392 |
+
"step": 3000
|
393 |
+
},
|
394 |
+
{
|
395 |
+
"epoch": 0.5,
|
396 |
+
"eval_loss": 0.0816347748041153,
|
397 |
+
"eval_runtime": 5044.7426,
|
398 |
+
"eval_samples_per_second": 3.146,
|
399 |
+
"eval_steps_per_second": 0.098,
|
400 |
+
"eval_wer": 7.182849857055744,
|
401 |
+
"step": 3000
|
402 |
+
},
|
403 |
+
{
|
404 |
+
"epoch": 0.51,
|
405 |
+
"learning_rate": 1.9833333333333332e-05,
|
406 |
+
"loss": 0.0204,
|
407 |
+
"step": 3050
|
408 |
+
},
|
409 |
+
{
|
410 |
+
"epoch": 0.52,
|
411 |
+
"learning_rate": 1.9666666666666666e-05,
|
412 |
+
"loss": 0.0203,
|
413 |
+
"step": 3100
|
414 |
+
},
|
415 |
+
{
|
416 |
+
"epoch": 0.53,
|
417 |
+
"learning_rate": 1.9503333333333334e-05,
|
418 |
+
"loss": 0.0198,
|
419 |
+
"step": 3150
|
420 |
+
},
|
421 |
+
{
|
422 |
+
"epoch": 0.53,
|
423 |
+
"learning_rate": 1.9336666666666667e-05,
|
424 |
+
"loss": 0.0224,
|
425 |
+
"step": 3200
|
426 |
+
},
|
427 |
+
{
|
428 |
+
"epoch": 0.54,
|
429 |
+
"learning_rate": 1.917e-05,
|
430 |
+
"loss": 0.0251,
|
431 |
+
"step": 3250
|
432 |
+
},
|
433 |
+
{
|
434 |
+
"epoch": 0.55,
|
435 |
+
"learning_rate": 1.9003333333333335e-05,
|
436 |
+
"loss": 0.0201,
|
437 |
+
"step": 3300
|
438 |
+
},
|
439 |
+
{
|
440 |
+
"epoch": 0.56,
|
441 |
+
"learning_rate": 1.883666666666667e-05,
|
442 |
+
"loss": 0.019,
|
443 |
+
"step": 3350
|
444 |
+
},
|
445 |
+
{
|
446 |
+
"epoch": 0.57,
|
447 |
+
"learning_rate": 1.867e-05,
|
448 |
+
"loss": 0.0185,
|
449 |
+
"step": 3400
|
450 |
+
},
|
451 |
+
{
|
452 |
+
"epoch": 0.57,
|
453 |
+
"learning_rate": 1.8503333333333333e-05,
|
454 |
+
"loss": 0.0216,
|
455 |
+
"step": 3450
|
456 |
+
},
|
457 |
+
{
|
458 |
+
"epoch": 0.58,
|
459 |
+
"learning_rate": 1.8336666666666667e-05,
|
460 |
+
"loss": 0.0203,
|
461 |
+
"step": 3500
|
462 |
+
},
|
463 |
+
{
|
464 |
+
"epoch": 0.59,
|
465 |
+
"learning_rate": 1.817e-05,
|
466 |
+
"loss": 0.0166,
|
467 |
+
"step": 3550
|
468 |
+
},
|
469 |
+
{
|
470 |
+
"epoch": 0.6,
|
471 |
+
"learning_rate": 1.8003333333333334e-05,
|
472 |
+
"loss": 0.0183,
|
473 |
+
"step": 3600
|
474 |
+
},
|
475 |
+
{
|
476 |
+
"epoch": 0.61,
|
477 |
+
"learning_rate": 1.7836666666666665e-05,
|
478 |
+
"loss": 0.0194,
|
479 |
+
"step": 3650
|
480 |
+
},
|
481 |
+
{
|
482 |
+
"epoch": 0.62,
|
483 |
+
"learning_rate": 1.767e-05,
|
484 |
+
"loss": 0.0187,
|
485 |
+
"step": 3700
|
486 |
+
},
|
487 |
+
{
|
488 |
+
"epoch": 0.62,
|
489 |
+
"learning_rate": 1.7503333333333332e-05,
|
490 |
+
"loss": 0.0184,
|
491 |
+
"step": 3750
|
492 |
+
},
|
493 |
+
{
|
494 |
+
"epoch": 0.63,
|
495 |
+
"learning_rate": 1.7336666666666666e-05,
|
496 |
+
"loss": 0.0261,
|
497 |
+
"step": 3800
|
498 |
+
},
|
499 |
+
{
|
500 |
+
"epoch": 0.64,
|
501 |
+
"learning_rate": 1.717e-05,
|
502 |
+
"loss": 0.026,
|
503 |
+
"step": 3850
|
504 |
+
},
|
505 |
+
{
|
506 |
+
"epoch": 0.65,
|
507 |
+
"learning_rate": 1.7003333333333334e-05,
|
508 |
+
"loss": 0.0163,
|
509 |
+
"step": 3900
|
510 |
+
},
|
511 |
+
{
|
512 |
+
"epoch": 0.66,
|
513 |
+
"learning_rate": 1.6836666666666664e-05,
|
514 |
+
"loss": 0.0148,
|
515 |
+
"step": 3950
|
516 |
+
},
|
517 |
+
{
|
518 |
+
"epoch": 0.67,
|
519 |
+
"learning_rate": 1.667e-05,
|
520 |
+
"loss": 0.0191,
|
521 |
+
"step": 4000
|
522 |
+
},
|
523 |
+
{
|
524 |
+
"epoch": 0.67,
|
525 |
+
"eval_loss": 0.08142250031232834,
|
526 |
+
"eval_runtime": 5067.9975,
|
527 |
+
"eval_samples_per_second": 3.132,
|
528 |
+
"eval_steps_per_second": 0.098,
|
529 |
+
"eval_wer": 7.272338504668456,
|
530 |
+
"step": 4000
|
531 |
+
},
|
532 |
+
{
|
533 |
+
"epoch": 0.68,
|
534 |
+
"learning_rate": 1.6503333333333335e-05,
|
535 |
+
"loss": 0.0172,
|
536 |
+
"step": 4050
|
537 |
+
},
|
538 |
+
{
|
539 |
+
"epoch": 0.68,
|
540 |
+
"learning_rate": 1.633666666666667e-05,
|
541 |
+
"loss": 0.0167,
|
542 |
+
"step": 4100
|
543 |
+
},
|
544 |
+
{
|
545 |
+
"epoch": 0.69,
|
546 |
+
"learning_rate": 1.6170000000000003e-05,
|
547 |
+
"loss": 0.0189,
|
548 |
+
"step": 4150
|
549 |
+
},
|
550 |
+
{
|
551 |
+
"epoch": 0.7,
|
552 |
+
"learning_rate": 1.6003333333333337e-05,
|
553 |
+
"loss": 0.029,
|
554 |
+
"step": 4200
|
555 |
+
},
|
556 |
+
{
|
557 |
+
"epoch": 0.71,
|
558 |
+
"learning_rate": 1.5836666666666667e-05,
|
559 |
+
"loss": 0.0293,
|
560 |
+
"step": 4250
|
561 |
+
},
|
562 |
+
{
|
563 |
+
"epoch": 0.72,
|
564 |
+
"learning_rate": 1.567e-05,
|
565 |
+
"loss": 0.0206,
|
566 |
+
"step": 4300
|
567 |
+
},
|
568 |
+
{
|
569 |
+
"epoch": 0.72,
|
570 |
+
"learning_rate": 1.5503333333333335e-05,
|
571 |
+
"loss": 0.0204,
|
572 |
+
"step": 4350
|
573 |
+
},
|
574 |
+
{
|
575 |
+
"epoch": 0.73,
|
576 |
+
"learning_rate": 1.533666666666667e-05,
|
577 |
+
"loss": 0.0235,
|
578 |
+
"step": 4400
|
579 |
+
},
|
580 |
+
{
|
581 |
+
"epoch": 0.74,
|
582 |
+
"learning_rate": 1.5170000000000002e-05,
|
583 |
+
"loss": 0.0304,
|
584 |
+
"step": 4450
|
585 |
+
},
|
586 |
+
{
|
587 |
+
"epoch": 0.75,
|
588 |
+
"learning_rate": 1.5003333333333333e-05,
|
589 |
+
"loss": 0.0181,
|
590 |
+
"step": 4500
|
591 |
+
},
|
592 |
+
{
|
593 |
+
"epoch": 0.76,
|
594 |
+
"learning_rate": 1.4836666666666668e-05,
|
595 |
+
"loss": 0.0262,
|
596 |
+
"step": 4550
|
597 |
+
},
|
598 |
+
{
|
599 |
+
"epoch": 0.77,
|
600 |
+
"learning_rate": 1.467e-05,
|
601 |
+
"loss": 0.0184,
|
602 |
+
"step": 4600
|
603 |
+
},
|
604 |
+
{
|
605 |
+
"epoch": 0.78,
|
606 |
+
"learning_rate": 1.4503333333333334e-05,
|
607 |
+
"loss": 0.018,
|
608 |
+
"step": 4650
|
609 |
+
},
|
610 |
+
{
|
611 |
+
"epoch": 0.78,
|
612 |
+
"learning_rate": 1.4336666666666666e-05,
|
613 |
+
"loss": 0.0173,
|
614 |
+
"step": 4700
|
615 |
+
},
|
616 |
+
{
|
617 |
+
"epoch": 0.79,
|
618 |
+
"learning_rate": 1.417e-05,
|
619 |
+
"loss": 0.0172,
|
620 |
+
"step": 4750
|
621 |
+
},
|
622 |
+
{
|
623 |
+
"epoch": 0.8,
|
624 |
+
"learning_rate": 1.4003333333333334e-05,
|
625 |
+
"loss": 0.0165,
|
626 |
+
"step": 4800
|
627 |
+
},
|
628 |
+
{
|
629 |
+
"epoch": 0.81,
|
630 |
+
"learning_rate": 1.3836666666666666e-05,
|
631 |
+
"loss": 0.0183,
|
632 |
+
"step": 4850
|
633 |
+
},
|
634 |
+
{
|
635 |
+
"epoch": 0.82,
|
636 |
+
"learning_rate": 1.367e-05,
|
637 |
+
"loss": 0.0174,
|
638 |
+
"step": 4900
|
639 |
+
},
|
640 |
+
{
|
641 |
+
"epoch": 0.82,
|
642 |
+
"learning_rate": 1.3503333333333333e-05,
|
643 |
+
"loss": 0.0172,
|
644 |
+
"step": 4950
|
645 |
+
},
|
646 |
+
{
|
647 |
+
"epoch": 0.83,
|
648 |
+
"learning_rate": 1.3336666666666667e-05,
|
649 |
+
"loss": 0.0142,
|
650 |
+
"step": 5000
|
651 |
+
},
|
652 |
+
{
|
653 |
+
"epoch": 0.83,
|
654 |
+
"eval_loss": 0.07993897795677185,
|
655 |
+
"eval_runtime": 5054.3965,
|
656 |
+
"eval_samples_per_second": 3.14,
|
657 |
+
"eval_steps_per_second": 0.098,
|
658 |
+
"eval_wer": 6.939725832125633,
|
659 |
+
"step": 5000
|
660 |
+
},
|
661 |
+
{
|
662 |
+
"epoch": 0.84,
|
663 |
+
"learning_rate": 1.3170000000000001e-05,
|
664 |
+
"loss": 0.0155,
|
665 |
+
"step": 5050
|
666 |
+
},
|
667 |
+
{
|
668 |
+
"epoch": 0.85,
|
669 |
+
"learning_rate": 1.3003333333333335e-05,
|
670 |
+
"loss": 0.0164,
|
671 |
+
"step": 5100
|
672 |
+
},
|
673 |
+
{
|
674 |
+
"epoch": 0.86,
|
675 |
+
"learning_rate": 1.2836666666666667e-05,
|
676 |
+
"loss": 0.0157,
|
677 |
+
"step": 5150
|
678 |
+
},
|
679 |
+
{
|
680 |
+
"epoch": 0.87,
|
681 |
+
"learning_rate": 1.267e-05,
|
682 |
+
"loss": 0.0151,
|
683 |
+
"step": 5200
|
684 |
+
},
|
685 |
+
{
|
686 |
+
"epoch": 0.88,
|
687 |
+
"learning_rate": 1.2503333333333334e-05,
|
688 |
+
"loss": 0.0148,
|
689 |
+
"step": 5250
|
690 |
+
},
|
691 |
+
{
|
692 |
+
"epoch": 0.88,
|
693 |
+
"learning_rate": 1.2336666666666667e-05,
|
694 |
+
"loss": 0.017,
|
695 |
+
"step": 5300
|
696 |
+
},
|
697 |
+
{
|
698 |
+
"epoch": 0.89,
|
699 |
+
"learning_rate": 1.217e-05,
|
700 |
+
"loss": 0.0147,
|
701 |
+
"step": 5350
|
702 |
+
},
|
703 |
+
{
|
704 |
+
"epoch": 0.9,
|
705 |
+
"learning_rate": 1.2003333333333332e-05,
|
706 |
+
"loss": 0.0196,
|
707 |
+
"step": 5400
|
708 |
+
},
|
709 |
+
{
|
710 |
+
"epoch": 1.0,
|
711 |
+
"learning_rate": 1.1836666666666666e-05,
|
712 |
+
"loss": 0.0165,
|
713 |
+
"step": 5450
|
714 |
+
},
|
715 |
+
{
|
716 |
+
"epoch": 1.01,
|
717 |
+
"learning_rate": 1.167e-05,
|
718 |
+
"loss": 0.0096,
|
719 |
+
"step": 5500
|
720 |
+
},
|
721 |
+
{
|
722 |
+
"epoch": 1.02,
|
723 |
+
"learning_rate": 1.1503333333333332e-05,
|
724 |
+
"loss": 0.0075,
|
725 |
+
"step": 5550
|
726 |
+
},
|
727 |
+
{
|
728 |
+
"epoch": 1.03,
|
729 |
+
"learning_rate": 1.1336666666666668e-05,
|
730 |
+
"loss": 0.0071,
|
731 |
+
"step": 5600
|
732 |
+
},
|
733 |
+
{
|
734 |
+
"epoch": 1.04,
|
735 |
+
"learning_rate": 1.1170000000000001e-05,
|
736 |
+
"loss": 0.0068,
|
737 |
+
"step": 5650
|
738 |
+
},
|
739 |
+
{
|
740 |
+
"epoch": 1.05,
|
741 |
+
"learning_rate": 1.1003333333333334e-05,
|
742 |
+
"loss": 0.0068,
|
743 |
+
"step": 5700
|
744 |
+
},
|
745 |
+
{
|
746 |
+
"epoch": 1.05,
|
747 |
+
"learning_rate": 1.0836666666666667e-05,
|
748 |
+
"loss": 0.0052,
|
749 |
+
"step": 5750
|
750 |
+
},
|
751 |
+
{
|
752 |
+
"epoch": 1.06,
|
753 |
+
"learning_rate": 1.0670000000000001e-05,
|
754 |
+
"loss": 0.0067,
|
755 |
+
"step": 5800
|
756 |
+
},
|
757 |
+
{
|
758 |
+
"epoch": 1.07,
|
759 |
+
"learning_rate": 1.0503333333333333e-05,
|
760 |
+
"loss": 0.0087,
|
761 |
+
"step": 5850
|
762 |
+
},
|
763 |
+
{
|
764 |
+
"epoch": 1.08,
|
765 |
+
"learning_rate": 1.0336666666666667e-05,
|
766 |
+
"loss": 0.0078,
|
767 |
+
"step": 5900
|
768 |
+
},
|
769 |
+
{
|
770 |
+
"epoch": 1.09,
|
771 |
+
"learning_rate": 1.0170000000000001e-05,
|
772 |
+
"loss": 0.0075,
|
773 |
+
"step": 5950
|
774 |
+
},
|
775 |
+
{
|
776 |
+
"epoch": 1.1,
|
777 |
+
"learning_rate": 1.0003333333333333e-05,
|
778 |
+
"loss": 0.0076,
|
779 |
+
"step": 6000
|
780 |
+
},
|
781 |
+
{
|
782 |
+
"epoch": 1.1,
|
783 |
+
"eval_loss": 0.08352651447057724,
|
784 |
+
"eval_runtime": 5093.2556,
|
785 |
+
"eval_samples_per_second": 3.116,
|
786 |
+
"eval_steps_per_second": 0.097,
|
787 |
+
"eval_wer": 6.719567920299668,
|
788 |
+
"step": 6000
|
789 |
+
},
|
790 |
+
{
|
791 |
+
"epoch": 1.1,
|
792 |
+
"step": 6000,
|
793 |
+
"total_flos": 1.1080467313606656e+20,
|
794 |
+
"train_loss": 0.0194823435023427,
|
795 |
+
"train_runtime": 77295.3787,
|
796 |
+
"train_samples_per_second": 4.968,
|
797 |
+
"train_steps_per_second": 0.078
|
798 |
+
}
|
799 |
+
],
|
800 |
+
"max_steps": 6000,
|
801 |
+
"num_train_epochs": 9223372036854775807,
|
802 |
+
"total_flos": 1.1080467313606656e+20,
|
803 |
+
"trial_name": null,
|
804 |
+
"trial_params": null
|
805 |
+
}
|