ales commited on
Commit
c4adc54
•
1 Parent(s): b519008
src/{run_base.sh → bash_runners/run_base.sh} RENAMED
File without changes
src/bash_runners/run_eval_cv11.sh ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
# Evaluate the Whisper checkpoint in the current directory ("--model_id=.")
# on the Common Voice 11 Belarusian test split (streamed) and report WER.
# Fix: the evaluation script is `run_eval_whisper_streaming.py` — the `.py`
# extension was missing, so `python` could not find the file.
python src/run_eval_whisper_streaming.py \
    --model_id="." \
    --language="be" \
    --dataset="mozilla-foundation/common_voice_11_0" \
    --config="be" \
    --split="test" \
    --device="0" \
    --batch_size="32" \
    --streaming="True"
src/{run_small.sh → bash_runners/run_small.sh} RENAMED
@@ -1,5 +1,5 @@
1
  python src/run_speech_recognition_seq2seq_streaming.py \
2
- --model_name_or_path="openai/whisper-small" \
3
  --dataset_name="mozilla-foundation/common_voice_11_0" \
4
  --dataset_config_name="be" \
5
  --language="be" \
@@ -7,14 +7,14 @@ python src/run_speech_recognition_seq2seq_streaming.py \
7
  --eval_split_name="validation" \
8
  --model_index_name="Whisper Small Belarusian" \
9
  \
10
- --max_steps="12000" \
11
  --output_dir="./" \
12
  --per_device_train_batch_size="64" \
13
- --per_device_eval_batch_size="64" \
14
  --logging_steps="50" \
15
  --logging_first_step \
16
- --learning_rate="1e-4" \
17
- --warmup_steps="500" \
18
  --evaluation_strategy="steps" \
19
  --eval_steps="1000" \
20
  --save_strategy="steps" \
@@ -39,6 +39,5 @@ python src/run_speech_recognition_seq2seq_streaming.py \
39
  --do_normalize_eval \
40
  --streaming_train="True" \
41
  --streaming_eval="False" \
42
- --use_auth_token \
43
- --push_to_hub \
44
- --hub_model_id="ales/whisper-small-belarusian"
 
1
  python src/run_speech_recognition_seq2seq_streaming.py \
2
+ --model_name_or_path="ales/whisper-small-belarusian" \
3
  --dataset_name="mozilla-foundation/common_voice_11_0" \
4
  --dataset_config_name="be" \
5
  --language="be" \
 
7
  --eval_split_name="validation" \
8
  --model_index_name="Whisper Small Belarusian" \
9
  \
10
+ --max_steps="6000" \
11
  --output_dir="./" \
12
  --per_device_train_batch_size="64" \
13
+ --per_device_eval_batch_size="32" \
14
  --logging_steps="50" \
15
  --logging_first_step \
16
+ --learning_rate="3.5e-5" \
17
+ --warmup_steps="0" \
18
  --evaluation_strategy="steps" \
19
  --eval_steps="1000" \
20
  --save_strategy="steps" \
 
39
  --do_normalize_eval \
40
  --streaming_train="True" \
41
  --streaming_eval="False" \
42
+ --seed="43" \
43
+ --use_auth_token
 
src/{run_tiny_debug.sh → bash_runners/run_tiny_debug.sh} RENAMED
File without changes
src/belarusian_text_normalizer.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import regex
3
+ import unicodedata
4
+
5
+ from typing import Iterable
6
+
7
+
8
+ class BelarusianTextNormalizer:
9
+ """
10
+ Based on transformers.models.whisper.english_normalizer.BasicTextNormalizer
11
+ but with support not to remove certain characters.
12
+ e.g. apostrophe (') - a symbol from Belarusian alphabet - was removed using BasicTextNormalizer.
13
+ """
14
+
15
+ def __init__(self, split_letters: bool = False):
16
+ self.split_letters = split_letters
17
+ self.allowed_symbols = ("'",)
18
+
19
+ @staticmethod
20
+ def clean(s: str, allowed_symbols: Iterable[str] = None):
21
+ """
22
+ Replace any other markers, symbols, punctuations with a space, keeping diacritics
23
+ """
24
+ if allowed_symbols is None:
25
+ allowed_symbols = []
26
+ res = "".join(" " if unicodedata.category(c)[0] in "MSP" and c not in allowed_symbols else c
27
+ for c in unicodedata.normalize("NFKC", s))
28
+ return res
29
+
30
+ def __call__(self, s: str):
31
+ s = s.lower()
32
+ s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
33
+ s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
34
+ s = self.clean(s, allowed_symbols=self.allowed_symbols).lower()
35
+
36
+ if self.split_letters:
37
+ s = " ".join(regex.findall(r"\X", s, regex.U))
38
+
39
+ s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
40
+
41
+ return s
src/readme.md CHANGED
@@ -39,6 +39,9 @@ The code in this repository is a modified version of code from
39
 
40
 ## Resuming training from existing checkpoint
41
  When resuming training from existing checkpoint:
 
 
 
42
  * it's better to save all `checkpoint-\d+` dirs. better not to rely on data saved to `output_dir` because:
43
  * not all data is saved to `output_dir`. e.g. following files are not saved to `output_dir`:
44
  `optimizer.pt`, `rng_state.pth`, `scaler.pt`, `scheduler.pt`. so can't resume training in a correct way from
@@ -70,9 +73,16 @@ When resuming training from existing checkpoint:
70
  but does StreamingDataset have any epochs?
71
  * does streaming mode support parallel data load and processing?<br>
72
  when using non-streaming mode we can use `dataset.map(..., num_proc=<num_proc>)`
 
 
 
73
 
74
 
75
  ## Notes:
 
 
 
 
76
  * using CommonVoice 11 dataset in a streaming way.<br>
77
  use `streaming=True` for train & validation & test.<br>
78
  as an alternative, we can use `streaming=False` for validation & test sets to save time on data processing.
 
39
 
40
 ## Resuming training from existing checkpoint
41
  When resuming training from existing checkpoint:
42
+ * when using streaming, epoch will get reset to 0. that means order of items passed to a model would be the same,
43
+ if the seed does not change. actual train_dataloader seed would be:
44
+ `train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)`
45
  * it's better to save all `checkpoint-\d+` dirs. better not to rely on data saved to `output_dir` because:
46
  * not all data is saved to `output_dir`. e.g. following files are not saved to `output_dir`:
47
  `optimizer.pt`, `rng_state.pth`, `scaler.pt`, `scheduler.pt`. so can't resume training in a correct way from
 
73
  but does StreamingDataset have any epochs?
74
  * does streaming mode support parallel data load and processing?<br>
75
  when using non-streaming mode we can use `dataset.map(..., num_proc=<num_proc>)`
76
+ * I got a CUDA out of memory error when I tried to launch a second training run for the Whisper Small model.
77
+ training params are almost the same: `--per_device_train_batch_size="64"`
78
+ the only thing that changed is that the evaluation dataset no longer uses streaming.
79
 
80
 
81
  ## Notes:
82
+ * Common Voice 11 dataset
83
+ [uploaded to HuggingFace](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0)
84
+ has only single voicing of each sentence in each split (train, validation, test).<br>
85
+ Many more audio files should be available on Common Voice, so that each sentence is voiced multiple times by different people
86
  * using CommonVoice 11 dataset in a streaming way.<br>
87
  use `streaming=True` for train & validation & test.<br>
88
  as an alternative, we can use `streaming=False` for validation & test sets to save time on data processing.
src/run_eval_whisper_streaming.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ from transformers import pipeline
4
+ from transformers.models.whisper.english_normalizer import BasicTextNormalizer
5
+ from datasets import load_dataset, Audio
6
+ import evaluate
7
+
8
+ from belarusian_text_normalizer import BelarusianTextNormalizer
9
+
10
+
11
# Word Error Rate metric from 🤗 Evaluate; loaded once at import time and reused in main().
wer_metric = evaluate.load("wer")
12
+
13
+
14
def is_target_text_in_range(ref):
    """Return True when `ref` contains real transcript text worth scoring.

    Empty/whitespace-only references and the special "ignore time segment
    in scoring" marker are excluded.
    """
    stripped = ref.strip()
    if stripped == "ignore time segment in scoring":
        return False
    return stripped != ""
19
+
20
+
21
def get_text(sample):
    """Return the transcript string from `sample`, whichever known column it uses.

    Columns are checked in priority order: 'text', 'sentence',
    'normalized_text', 'transcript', 'transcription'.

    Raises:
        ValueError: if none of the known transcript columns is present.
    """
    for column in ("text", "sentence", "normalized_text", "transcript", "transcription"):
        if column in sample:
            return sample[column]
    # The original message was a broken literal (".join{sample.keys()}") that
    # never showed the available keys; use a real f-string instead.
    raise ValueError(
        f"Expected transcript column of either 'text', 'sentence', 'normalized_text', "
        f"'transcript' or 'transcription'. Got sample of {', '.join(sample.keys())}. "
        "Ensure a text column name is present in the dataset."
    )
37
+
38
+
39
# Shared normalizer instance; keeps the Belarusian apostrophe during cleaning.
whisper_norm = BelarusianTextNormalizer()
40
+
41
+
42
def normalise(batch):
    """Attach a `norm_text` column holding the normalized transcript of `batch`."""
    raw_text = get_text(batch)
    batch["norm_text"] = whisper_norm(raw_text)
    return batch
45
+
46
+
47
def data(dataset):
    """Yield pipeline-ready dicts: the sample's audio fields merged with the
    normalized reference transcript under the 'reference' key.

    The original used `enumerate` but never consumed the index; iterate directly.
    """
    for item in dataset:
        yield {**item["audio"], "reference": item["norm_text"]}
50
+
51
+
52
def main(args):
    """Evaluate a Whisper checkpoint on an ASR dataset and report/push WER.

    Loads the model into an ASR pipeline, streams the dataset (optionally),
    normalizes references and predictions with the Belarusian normalizer,
    computes WER and pushes the metric to the Hub.
    """
    batch_size = args.batch_size
    whisper_asr = pipeline(
        "automatic-speech-recognition", model=args.model_id, device=args.device
    )

    # Force transcription in the target language instead of auto-detection.
    whisper_asr.model.config.forced_decoder_ids = (
        whisper_asr.tokenizer.get_decoder_prompt_ids(
            language=args.language, task="transcribe"
        )
    )

    dataset = load_dataset(
        args.dataset,
        args.config,
        split=args.split,
        streaming=args.streaming,
        use_auth_token=True,
    )

    # Limit the number of evaluated samples only when explicitly requested
    # (e.g. for quick debugging). The original called
    # `dataset.take(args.max_eval_samples)` unconditionally, which breaks
    # with the default `max_eval_samples=None`.
    if args.max_eval_samples is not None:
        dataset = dataset.take(args.max_eval_samples)

    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
    dataset = dataset.map(normalise)
    dataset = dataset.filter(is_target_text_in_range, input_columns=["norm_text"])

    predictions = []
    references = []

    # run streamed inference
    for out in whisper_asr(data(dataset), batch_size=batch_size):
        predictions.append(whisper_norm(out["text"]))
        references.append(out["reference"][0])

    wer = wer_metric.compute(references=references, predictions=predictions)
    wer = round(100 * wer, 2)

    print("WER:", wer)
    evaluate.push_to_hub(
        model_id=args.model_id,
        metric_value=wer,
        metric_type="wer",
        metric_name="WER",
        dataset_name=args.dataset,
        dataset_type=args.dataset,
        dataset_split=args.split,
        dataset_config=args.config,
        task_type="automatic-speech-recognition",
        task_name="Automatic Speech Recognition"
    )
103
+
104
+
105
+ if __name__ == "__main__":
106
+ parser = argparse.ArgumentParser()
107
+
108
+ parser.add_argument(
109
+ "--model_id",
110
+ type=str,
111
+ required=True,
112
+ help="Model identifier. Should be loadable with 🤗 Transformers",
113
+ )
114
+ parser.add_argument(
115
+ "--dataset",
116
+ type=str,
117
+ default="mozilla-foundation/common_voice_11_0",
118
+ help="Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets",
119
+ )
120
+ parser.add_argument(
121
+ "--config",
122
+ type=str,
123
+ required=True,
124
+ help="Config of the dataset. *E.g.* `'en'` for the English split of Common Voice",
125
+ )
126
+ parser.add_argument(
127
+ "--split",
128
+ type=str,
129
+ default="test",
130
+ help="Split of the dataset. *E.g.* `'test'`",
131
+ )
132
+
133
+ parser.add_argument(
134
+ "--device",
135
+ type=int,
136
+ default=-1,
137
+ help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
138
+ )
139
+ parser.add_argument(
140
+ "--batch_size",
141
+ type=int,
142
+ default=16,
143
+ help="Number of samples to go through each streamed batch.",
144
+ )
145
+ parser.add_argument(
146
+ "--max_eval_samples",
147
+ type=int,
148
+ default=None,
149
+ help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
150
+ )
151
+ parser.add_argument(
152
+ "--streaming",
153
+ type=bool,
154
+ default=True,
155
+ help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
156
+ )
157
+ parser.add_argument(
158
+ "--language",
159
+ type=str,
160
+ required=True,
161
+ help="Two letter language code for the transcription language, e.g. use 'en' for English.",
162
+ )
163
+ args = parser.parse_args()
164
+
165
+ main(args)
src/run_speech_recognition_seq2seq_streaming.py CHANGED
@@ -24,9 +24,6 @@ import logging
24
  import os
25
  import sys
26
  import datetime
27
- import re
28
- import regex
29
- import unicodedata
30
  from dataclasses import dataclass, field
31
  from typing import Any, Dict, List, Optional, Union, Iterable
32
 
@@ -54,6 +51,7 @@ from transformers.trainer_utils import get_last_checkpoint, is_main_process
54
  from transformers.utils import check_min_version, send_example_telemetry
55
  from transformers.utils.versions import require_version
56
 
 
57
 
58
  # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
59
  check_min_version("4.25.0.dev0")
@@ -230,41 +228,6 @@ class DataTrainingArguments:
230
  )
231
 
232
 
233
- class BelarusianTextNormalizer:
234
- """
235
- Based on transformers.models.whisper.english_normalizer.BasicTextNormalizer
236
- but with support not to remove certain characters.
237
- e.g. apostrophe (') - a symbol from Belarusian alphabet - was removed using BasicTextNormalizer.
238
- """
239
-
240
- def __init__(self, split_letters: bool = False):
241
- self.split_letters = split_letters
242
- self.allowed_symbols = ("'",)
243
-
244
- @staticmethod
245
- def clean(s: str, allowed_symbols: Iterable[str] = None):
246
- """
247
- Replace any other markers, symbols, punctuations with a space, keeping diacritics
248
- """
249
- if allowed_symbols is None:
250
- allowed_symbols = []
251
- res = "".join(" " if unicodedata.category(c)[0] in "MSP" and c not in allowed_symbols else c
252
- for c in unicodedata.normalize("NFKC", s))
253
- return res
254
-
255
- def __call__(self, s: str):
256
- s = s.lower()
257
- s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
258
- s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
259
- s = self.clean(s, allowed_symbols=self.allowed_symbols).lower()
260
-
261
- if self.split_letters:
262
- s = " ".join(regex.findall(r"\X", s, regex.U))
263
-
264
- s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
265
-
266
- return s
267
-
268
 
269
  @dataclass
270
  class DataCollatorSpeechSeq2SeqWithPadding:
 
24
  import os
25
  import sys
26
  import datetime
 
 
 
27
  from dataclasses import dataclass, field
28
  from typing import Any, Dict, List, Optional, Union, Iterable
29
 
 
51
  from transformers.utils import check_min_version, send_example_telemetry
52
  from transformers.utils.versions import require_version
53
 
54
+ from belarusian_text_normalizer import BelarusianTextNormalizer
55
 
56
  # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
57
  check_min_version("4.25.0.dev0")
 
228
  )
229
 
230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
  @dataclass
233
  class DataCollatorSpeechSeq2SeqWithPadding: