pranay-j committed on
Commit 1b6a37a
1 Parent(s): b117b71

update model card README.md

Files changed (2)
  1. README.md +81 -0
  2. fine-tune-whisper-streaming.ipynb +165 -72
README.md ADDED
@@ -0,0 +1,81 @@
+ ---
+ language:
+ - ar
+ license: apache-2.0
+ tags:
+ - whisper-event
+ - generated_from_trainer
+ datasets:
+ - mozilla-foundation/common_voice_11_0
+ metrics:
+ - wer
+ model-index:
+ - name: Whisper Small ar
+ results:
+ - task:
+ name: Automatic Speech Recognition
+ type: automatic-speech-recognition
+ dataset:
+ name: common_voice_11_0
+ type: mozilla-foundation/common_voice_11_0
+ config: ar
+ split: test
+ args: ar
+ metrics:
+ - name: Wer
+ type: wer
+ value: 73.65866666666666
+ ---
+
+ <!-- This model card has been generated automatically according to the information the Trainer had access to. You
+ should probably proofread and complete it, then remove this comment. -->
+
+ # Whisper Small ar
+
+ This model is a fine-tuned version of [openai/whisper-small](https://huggingface.co/openai/whisper-small) on the common_voice_11_0 dataset.
+ It achieves the following results on the evaluation set:
+ - Loss: 0.3855
+ - Wer: 73.6587
+
+ ## Model description
+
+ More information needed
+
+ ## Intended uses & limitations
+
+ More information needed
+
+ ## Training and evaluation data
+
+ More information needed
+
+ ## Training procedure
+
+ ### Training hyperparameters
+
+ The following hyperparameters were used during training:
+ - learning_rate: 1e-05
+ - train_batch_size: 16
+ - eval_batch_size: 16
+ - seed: 42
+ - gradient_accumulation_steps: 2
+ - total_train_batch_size: 32
+ - optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
+ - lr_scheduler_type: linear
+ - lr_scheduler_warmup_steps: 500
+ - training_steps: 2000
+ - mixed_precision_training: Native AMP
+
+ ### Training results
+
+ | Training Loss | Epoch | Step | Validation Loss | Wer |
+ |:-------------:|:-----:|:----:|:---------------:|:-------:|
+ | 0.1151 | 2.12 | 2000 | 0.3855 | 73.6587 |
+
+
+ ### Framework versions
+
+ - Transformers 4.26.0.dev0
+ - Pytorch 1.13.0+cu117
+ - Datasets 2.8.1.dev0
+ - Tokenizers 0.12.1
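As a quick sanity check of the uploaded weights, the checkpoint can be loaded through the `transformers` ASR pipeline. This is a minimal sketch: the model path and audio file name below are placeholders, not values taken from this repository.

```python
from transformers import pipeline

# Placeholder path: substitute the Hub repo id of this model or the local output_dir.
asr = pipeline(
    "automatic-speech-recognition",
    model="./whisper-small-ar",
    chunk_length_s=30,  # Whisper works on 30-second windows
)

print(asr("sample.wav")["text"])  # "sample.wav" is a hypothetical input file
```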
fine-tune-whisper-streaming.ipynb CHANGED
@@ -2,7 +2,7 @@
2
  "cells": [
3
  {
4
  "cell_type": "markdown",
5
- "id": "5eb0df35",
6
  "metadata": {},
7
  "source": [
8
  "# Fine-Tune Whisper With 🤗 Transformers and Streaming Mode"
@@ -10,7 +10,7 @@
10
  },
11
  {
12
  "cell_type": "markdown",
13
- "id": "df8cc52a",
14
  "metadata": {},
15
  "source": [
16
  "In this Colab, we present a step-by-step guide on fine-tuning Whisper with Hugging Face 🤗 Transformers on 400 hours of speech data! Using streaming mode, we'll show how you can train a speech recognition model on any dataset, irrespective of size. With streaming mode, storage requirements are no longer a consideration: you can train a model on whatever dataset you want, even if its download size exceeds your device's disk space. How can this be possible? It simply seems too good to be true! Well, rest assured it's not 😉 Carry on reading to find out more."
@@ -18,7 +18,7 @@
18
  },
19
  {
20
  "cell_type": "markdown",
21
- "id": "143f0620",
22
  "metadata": {},
23
  "source": [
24
  "## Introduction"
@@ -26,7 +26,7 @@
26
  },
27
  {
28
  "cell_type": "markdown",
29
- "id": "37745350",
30
  "metadata": {},
31
  "source": [
32
  "Speech recognition datasets are large. A typical speech dataset consists of approximately 100 hours of audio-transcription data, requiring upwards of 130GB of storage space for download and preparation. For most ASR researchers, this is already at the upper limit of what is feasible for disk space. So what happens when we want to train on a larger dataset? The full [LibriSpeech](https://huggingface.co/datasets/librispeech_asr) dataset consists of 960 hours of audio data. Kensho's [SPGISpeech](https://huggingface.co/datasets/kensho/spgispeech) contains 5,000 hours of audio data. ML Commons [People's Speech](https://huggingface.co/datasets/MLCommons/peoples_speech) contains **30,000+** hours of audio data! Do we need to bite the bullet and buy additional storage? Or is there a way we can train on all of these datasets with no disk drive requirements?\n",
@@ -42,7 +42,7 @@
42
  },
43
  {
44
  "cell_type": "markdown",
45
- "id": "e4c2618e",
46
  "metadata": {},
47
  "source": [
48
  "<figure>\n",
@@ -53,7 +53,7 @@
53
  },
54
  {
55
  "cell_type": "markdown",
56
- "id": "2daf6b53",
57
  "metadata": {},
58
  "source": [
59
  "This notebook provides a guide to fine-tuning on the task of _speech recognition_, which involves learning a\n",
@@ -92,7 +92,7 @@
92
  },
93
  {
94
  "cell_type": "markdown",
95
- "id": "409c4267",
96
  "metadata": {},
97
  "source": [
98
  "## Load Dataset with Streaming"
@@ -100,7 +100,7 @@
100
  },
101
  {
102
  "cell_type": "markdown",
103
- "id": "a0ce3342",
104
  "metadata": {},
105
  "source": [
106
  "This is where the magic happens! We'll first write a wrapper function around 🤗 Datasets `load_dataset` method. This function downloads the required splits using streaming mode by forcing `streaming=True` in the `load_dataset` method. Multiple splits can be combined (interleaved) by concatenating them with the \"+\" symbol when specifying the split name, e.g. `split=train+validation` will return a single split with the training and validation splits interleaved together. The function has the same arguments and key-word arguments as 🤗 Datasets `load_dataset` method, so we can use it in exactly the same way!"
@@ -109,7 +109,7 @@
109
  {
110
  "cell_type": "code",
111
  "execution_count": 1,
112
- "id": "e49e547c",
113
  "metadata": {},
114
  "outputs": [],
115
  "source": [
@@ -130,7 +130,7 @@
130
  },
131
  {
132
  "cell_type": "markdown",
133
- "id": "2b7da79d",
134
  "metadata": {},
135
  "source": [
136
  "We'll train our system on the Spanish split of [Common Voice 11](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0). We can see how much training data we have by viewing the [language page](https://commonvoice.mozilla.org/en/datasets) on the Common Voice website. The Spanish split has over 400 hours of labelled training data - that's enormous! More than we could ever fit on a Google Colab or a standard workstation. But with streaming mode, we'll only download data as and when we need it, making training on this dataset possible!\n",
@@ -143,7 +143,7 @@
143
  {
144
  "cell_type": "code",
145
  "execution_count": 2,
146
- "id": "211fdf8c",
147
  "metadata": {},
148
  "outputs": [],
149
  "source": [
@@ -157,7 +157,7 @@
157
  },
158
  {
159
  "cell_type": "markdown",
160
- "id": "c483c427",
161
  "metadata": {},
162
  "source": [
163
  "## Prepare Processor and Pre-Process Data"
@@ -165,7 +165,7 @@
165
  },
166
  {
167
  "cell_type": "markdown",
168
- "id": "0a811a5e",
169
  "metadata": {},
170
  "source": [
171
  "The ASR pipeline can be decomposed into three stages: \n",
@@ -186,7 +186,7 @@
186
  {
187
  "cell_type": "code",
188
  "execution_count": 3,
189
- "id": "57b1770f",
190
  "metadata": {},
191
  "outputs": [],
192
  "source": [
@@ -197,7 +197,7 @@
197
  },
198
  {
199
  "cell_type": "markdown",
200
- "id": "cec6a161",
201
  "metadata": {},
202
  "source": [
203
  "### Pre-Process Data"
@@ -205,7 +205,7 @@
205
  },
206
  {
207
  "cell_type": "markdown",
208
- "id": "2166ce4e",
209
  "metadata": {},
210
  "source": [
211
  "Let's have a look at the dataset features. Pay particular attention to the `\"audio\"` column - this details the sampling rate of our audio inputs:"
@@ -214,7 +214,7 @@
214
  {
215
  "cell_type": "code",
216
  "execution_count": 4,
217
- "id": "c8031a36",
218
  "metadata": {},
219
  "outputs": [
220
  {
@@ -244,7 +244,7 @@
244
  },
245
  {
246
  "cell_type": "markdown",
247
- "id": "66ea302f",
248
  "metadata": {},
249
  "source": [
250
  "Since our input audio is sampled at 48kHz, we need to _downsample_ it to\n",
@@ -260,7 +260,7 @@
260
  {
261
  "cell_type": "code",
262
  "execution_count": 5,
263
- "id": "a7374f6b",
264
  "metadata": {},
265
  "outputs": [],
266
  "source": [
@@ -271,7 +271,7 @@
271
  },
272
  {
273
  "cell_type": "markdown",
274
- "id": "3d61f48f",
275
  "metadata": {},
276
  "source": [
277
  "We'll define our pre-processing strategy. We advise that you **do not** lower-case the transcriptions or remove punctuation unless mixing different datasets. This will enable you to fine-tune Whisper models that can predict punctuation and casing. Later, you will see how we can evaluate the predictions without punctuation or casing, so that the models benefit from the WER improvement obtained by normalising the transcriptions while still predicting fully formatted transcriptions."
@@ -280,7 +280,7 @@
280
  {
281
  "cell_type": "code",
282
  "execution_count": 6,
283
- "id": "ede17e4f",
284
  "metadata": {},
285
  "outputs": [],
286
  "source": [
@@ -294,7 +294,7 @@
294
  },
295
  {
296
  "cell_type": "markdown",
297
- "id": "f09fbe96",
298
  "metadata": {},
299
  "source": [
300
  "Now we can write a function to prepare our data ready for the model:\n",
@@ -307,7 +307,7 @@
307
  {
308
  "cell_type": "code",
309
  "execution_count": 7,
310
- "id": "e76b6638",
311
  "metadata": {},
312
  "outputs": [],
313
  "source": [
@@ -334,7 +334,7 @@
334
  },
335
  {
336
  "cell_type": "markdown",
337
- "id": "8255126a",
338
  "metadata": {},
339
  "source": [
340
  "We can apply the data preparation function to all of our training examples using 🤗 Datasets' `.map` method. We'll remove all of the columns from the raw training data, leaving just the `input_features` and `labels` defined in the `prepare_dataset` function:"
@@ -343,7 +343,7 @@
343
  {
344
  "cell_type": "code",
345
  "execution_count": 8,
346
- "id": "062d63f9",
347
  "metadata": {},
348
  "outputs": [],
349
  "source": [
@@ -352,7 +352,7 @@
352
  },
353
  {
354
  "cell_type": "markdown",
355
- "id": "eb19118e",
356
  "metadata": {},
357
  "source": [
358
  "We can now define how we shuffle the data in the train split. The size of the subset we load is set by the variable `buffer_size`. You can increase or decrease this depending on your memory constraints. In this example, the `buffer_size` is set to 500, meaning 500 samples are loaded before shuffling across the subset. The larger we set this value, the closer we get to true offline shuffling. The `seed` is set for reproducibility:"
@@ -361,7 +361,7 @@
361
  {
362
  "cell_type": "code",
363
  "execution_count": 9,
364
- "id": "25342d4d",
365
  "metadata": {},
366
  "outputs": [],
367
  "source": [
@@ -373,7 +373,7 @@
373
  },
374
  {
375
  "cell_type": "markdown",
376
- "id": "e76d637f",
377
  "metadata": {},
378
  "source": [
379
  "Finally, we filter any training data with audio samples longer than 30s. These samples would otherwise be truncated by the Whisper feature-extractor which could affect the stability of training. We define a function that returns `True` for samples that are less than 30s, and `False` for those that are longer:"
@@ -382,7 +382,7 @@
382
  {
383
  "cell_type": "code",
384
  "execution_count": 10,
385
- "id": "0bdf50a7",
386
  "metadata": {},
387
  "outputs": [],
388
  "source": [
@@ -394,7 +394,7 @@
394
  },
395
  {
396
  "cell_type": "markdown",
397
- "id": "1e98b230",
398
  "metadata": {},
399
  "source": [
400
  "We apply our filter function to all samples of our training dataset through 🤗 Datasets' `.filter` method:"
@@ -403,7 +403,7 @@
403
  {
404
  "cell_type": "code",
405
  "execution_count": 11,
406
- "id": "4496f42a",
407
  "metadata": {},
408
  "outputs": [],
409
  "source": [
@@ -415,7 +415,7 @@
415
  },
416
  {
417
  "cell_type": "markdown",
418
- "id": "cf987c26",
419
  "metadata": {},
420
  "source": [
421
  "## Training and Evaluation"
@@ -423,7 +423,7 @@
423
  },
424
  {
425
  "cell_type": "markdown",
426
- "id": "391bbd26",
427
  "metadata": {},
428
  "source": [
429
  "Now that we've prepared our data, we're ready to dive into the training pipeline. \n",
@@ -441,7 +441,7 @@
441
  },
442
  {
443
  "cell_type": "markdown",
444
- "id": "4226eb22",
445
  "metadata": {},
446
  "source": [
447
  "### Define a Data Collator"
@@ -449,7 +449,7 @@
449
  },
450
  {
451
  "cell_type": "markdown",
452
- "id": "2a8d1ca6",
453
  "metadata": {},
454
  "source": [
455
  "The data collator for a sequence-to-sequence speech model is unique in the sense that it \n",
@@ -473,7 +473,7 @@
473
  {
474
  "cell_type": "code",
475
  "execution_count": 12,
476
- "id": "c96bef0e",
477
  "metadata": {},
478
  "outputs": [],
479
  "source": [
@@ -512,7 +512,7 @@
512
  },
513
  {
514
  "cell_type": "markdown",
515
- "id": "3cb76e95",
516
  "metadata": {},
517
  "source": [
518
  "Let's initialise the data collator we've just defined:"
@@ -521,7 +521,7 @@
521
  {
522
  "cell_type": "code",
523
  "execution_count": 13,
524
- "id": "4979b797",
525
  "metadata": {},
526
  "outputs": [],
527
  "source": [
@@ -530,7 +530,7 @@
530
  },
531
  {
532
  "cell_type": "markdown",
533
- "id": "eb008e83",
534
  "metadata": {},
535
  "source": [
536
  "### Evaluation Metrics"
@@ -538,7 +538,7 @@
538
  },
539
  {
540
  "cell_type": "markdown",
541
- "id": "fd0c691d",
542
  "metadata": {},
543
  "source": [
544
  "We'll use the word error rate (WER) metric, the 'de-facto' metric for assessing \n",
@@ -548,7 +548,7 @@
548
  {
549
  "cell_type": "code",
550
  "execution_count": 14,
551
- "id": "3044e87b",
552
  "metadata": {},
553
  "outputs": [],
554
  "source": [
@@ -559,7 +559,7 @@
559
  },
560
  {
561
  "cell_type": "markdown",
562
- "id": "b851965a",
563
  "metadata": {},
564
  "source": [
565
  "We then simply have to define a function that takes our model \n",
@@ -577,7 +577,7 @@
577
  {
578
  "cell_type": "code",
579
  "execution_count": 15,
580
- "id": "5d367fe0",
581
  "metadata": {},
582
  "outputs": [],
583
  "source": [
@@ -609,7 +609,7 @@
609
  },
610
  {
611
  "cell_type": "markdown",
612
- "id": "3335d64a",
613
  "metadata": {},
614
  "source": [
615
  "### Load a Pre-Trained Checkpoint"
@@ -617,7 +617,7 @@
617
  },
618
  {
619
  "cell_type": "markdown",
620
- "id": "70cffe1a",
621
  "metadata": {},
622
  "source": [
623
  "Now let's load the pre-trained Whisper `small` checkpoint. Again, this \n",
@@ -627,7 +627,7 @@
627
  {
628
  "cell_type": "code",
629
  "execution_count": 16,
630
- "id": "377f7eb4",
631
  "metadata": {},
632
  "outputs": [],
633
  "source": [
@@ -638,7 +638,7 @@
638
  },
639
  {
640
  "cell_type": "markdown",
641
- "id": "b1125075",
642
  "metadata": {},
643
  "source": [
644
  "Override generation arguments - no tokens are forced as decoder outputs (see [`forced_decoder_ids`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.forced_decoder_ids)), no tokens are suppressed during generation (see [`suppress_tokens`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.suppress_tokens)). Set `use_cache` to False since we're using gradient checkpointing, and the two are incompatible:"
@@ -647,7 +647,7 @@
647
  {
648
  "cell_type": "code",
649
  "execution_count": 17,
650
- "id": "fb0955c6",
651
  "metadata": {},
652
  "outputs": [],
653
  "source": [
@@ -658,7 +658,7 @@
658
  },
659
  {
660
  "cell_type": "markdown",
661
- "id": "21eb2e5a",
662
  "metadata": {},
663
  "source": [
664
  "### Define the Training Configuration"
@@ -666,7 +666,7 @@
666
  },
667
  {
668
  "cell_type": "markdown",
669
- "id": "0e97f32d",
670
  "metadata": {},
671
  "source": [
672
  "In the final step, we define all the parameters related to training. Here, you can set the `max_steps` to train for longer. For more detail on the training arguments, refer to the Seq2SeqTrainingArguments [docs](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments)."
@@ -675,7 +675,7 @@
675
  {
676
  "cell_type": "code",
677
  "execution_count": 18,
678
- "id": "1a3a6612",
679
  "metadata": {},
680
  "outputs": [],
681
  "source": [
@@ -707,7 +707,7 @@
707
  },
708
  {
709
  "cell_type": "markdown",
710
- "id": "0bded5c7",
711
  "metadata": {},
712
  "source": [
713
  "**Note**: if one does not want to upload the model checkpoints to the Hub, \n",
@@ -716,7 +716,7 @@
716
  },
717
  {
718
  "cell_type": "markdown",
719
- "id": "6c7feb58",
720
  "metadata": {},
721
  "source": [
722
  "We then define a custom [Callback](https://huggingface.co/docs/transformers/main_classes/callback) that is called by the 🤗 Trainer at the start of each epoch. The Callback reinitialises and reshuffles the streaming dataset at the beginning of each new epoch - this gives different shuffling across our subsets for every epoch."
@@ -725,7 +725,7 @@
725
  {
726
  "cell_type": "code",
727
  "execution_count": 19,
728
- "id": "d17f4ee1",
729
  "metadata": {},
730
  "outputs": [],
731
  "source": [
@@ -744,7 +744,7 @@
744
  },
745
  {
746
  "cell_type": "markdown",
747
- "id": "0b679650",
748
  "metadata": {},
749
  "source": [
750
  "We can forward the training arguments to the 🤗 Trainer along with our model,\n",
@@ -754,7 +754,7 @@
754
  {
755
  "cell_type": "code",
756
  "execution_count": 20,
757
- "id": "274bc709",
758
  "metadata": {},
759
  "outputs": [
760
  {
@@ -786,7 +786,7 @@
786
  },
787
  {
788
  "cell_type": "markdown",
789
- "id": "932e7d2e",
790
  "metadata": {},
791
  "source": [
792
  "We'll save the model and processor to the output directory before training:"
@@ -795,7 +795,7 @@
795
  {
796
  "cell_type": "code",
797
  "execution_count": 21,
798
- "id": "9413aba6",
799
  "metadata": {},
800
  "outputs": [
801
  {
@@ -818,7 +818,7 @@
818
  },
819
  {
820
  "cell_type": "markdown",
821
- "id": "552289a8",
822
  "metadata": {},
823
  "source": [
824
  "### Training"
@@ -826,7 +826,7 @@
826
  },
827
  {
828
  "cell_type": "markdown",
829
- "id": "48ca9e46",
830
  "metadata": {},
831
  "source": [
832
  "Training will take approximately 5-10 hours depending on your GPU. The peak GPU memory for the given training configuration is approximately 36GB. \n",
@@ -840,8 +840,8 @@
840
  },
841
  {
842
  "cell_type": "code",
843
- "execution_count": null,
844
- "id": "8401cbc1",
845
  "metadata": {},
846
  "outputs": [
847
  {
@@ -868,8 +868,8 @@
868
  "\n",
869
  " <div>\n",
870
  " \n",
871
- " <progress value='2001' max='2000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
872
- " [2000/2000 5:52:53, Epoch 2.12/9223372036854775807]\n",
873
  " </div>\n",
874
  " <table border=\"1\" class=\"dataframe\">\n",
875
  " <thead>\n",
@@ -877,9 +877,16 @@
877
  " <th>Step</th>\n",
878
  " <th>Training Loss</th>\n",
879
  " <th>Validation Loss</th>\n",
 
880
  " </tr>\n",
881
  " </thead>\n",
882
  " <tbody>\n",
883
  " </tbody>\n",
884
  "</table><p>"
885
  ],
@@ -10682,8 +10689,50 @@
10682
  " \"transformers_version\": \"4.26.0.dev0\",\n",
10683
  " \"use_cache\": false\n",
10684
  "}\n",
10685
- "\n"
10686
  ]
10687
  }
10688
  ],
10689
  "source": [
@@ -10692,7 +10741,7 @@
10692
  },
10693
  {
10694
  "cell_type": "markdown",
10695
- "id": "128c5b2a",
10696
  "metadata": {
10697
  "pycharm": {
10698
  "name": "#%% md\n"
@@ -10704,7 +10753,7 @@
10704
  },
10705
  {
10706
  "cell_type": "markdown",
10707
- "id": "33c489c4",
10708
  "metadata": {},
10709
  "source": [
10710
  "We can label our checkpoint with the `whisper-event` tag on push by setting the appropriate key-word arguments (kwargs):"
@@ -10712,8 +10761,8 @@
10712
  },
10713
  {
10714
  "cell_type": "code",
10715
- "execution_count": null,
10716
- "id": "f54e6c30",
10717
  "metadata": {},
10718
  "outputs": [],
10719
  "source": [
@@ -10730,7 +10779,7 @@
10730
  },
10731
  {
10732
  "cell_type": "markdown",
10733
- "id": "bef7e4be",
10734
  "metadata": {},
10735
  "source": [
10736
  "The training results can now be uploaded to the Hub. To do so, execute the `push_to_hub` command:"
@@ -10739,9 +10788,53 @@
10739
  {
10740
  "cell_type": "code",
10741
  "execution_count": null,
10742
- "id": "f1ada39d",
10743
  "metadata": {},
10744
- "outputs": [],
10745
  "source": [
10746
  "trainer.push_to_hub(**kwargs)"
10747
  ]
 
2
  "cells": [
3
  {
4
  "cell_type": "markdown",
5
+ "id": "7ee06e76",
6
  "metadata": {},
7
  "source": [
8
  "# Fine-Tune Whisper With 🤗 Transformers and Streaming Mode"
 
10
  },
11
  {
12
  "cell_type": "markdown",
13
+ "id": "e840bc77",
14
  "metadata": {},
15
  "source": [
16
  "In this Colab, we present a step-by-step guide on fine-tuning Whisper with Hugging Face 🤗 Transformers on 400 hours of speech data! Using streaming mode, we'll show how you can train a speech recognition model on any dataset, irrespective of size. With streaming mode, storage requirements are no longer a consideration: you can train a model on whatever dataset you want, even if its download size exceeds your device's disk space. How can this be possible? It simply seems too good to be true! Well, rest assured it's not 😉 Carry on reading to find out more."
 
18
  },
19
  {
20
  "cell_type": "markdown",
21
+ "id": "a2ab74b4",
22
  "metadata": {},
23
  "source": [
24
  "## Introduction"
 
26
  },
27
  {
28
  "cell_type": "markdown",
29
+ "id": "ebedf365",
30
  "metadata": {},
31
  "source": [
32
  "Speech recognition datasets are large. A typical speech dataset consists of approximately 100 hours of audio-transcription data, requiring upwards of 130GB of storage space for download and preparation. For most ASR researchers, this is already at the upper limit of what is feasible for disk space. So what happens when we want to train on a larger dataset? The full [LibriSpeech](https://huggingface.co/datasets/librispeech_asr) dataset consists of 960 hours of audio data. Kensho's [SPGISpeech](https://huggingface.co/datasets/kensho/spgispeech) contains 5,000 hours of audio data. ML Commons [People's Speech](https://huggingface.co/datasets/MLCommons/peoples_speech) contains **30,000+** hours of audio data! Do we need to bite the bullet and buy additional storage? Or is there a way we can train on all of these datasets with no disk drive requirements?\n",
 
42
  },
43
  {
44
  "cell_type": "markdown",
45
+ "id": "76a56c73",
46
  "metadata": {},
47
  "source": [
48
  "<figure>\n",
 
53
  },
54
  {
55
  "cell_type": "markdown",
56
+ "id": "ccdf81af",
57
  "metadata": {},
58
  "source": [
59
  "This notebook provides a guide to fine-tuning on the task of _speech recognition_, which involves learning a\n",
 
92
  },
93
  {
94
  "cell_type": "markdown",
95
+ "id": "76a78106",
96
  "metadata": {},
97
  "source": [
98
  "## Load Dataset with Streaming"
 
100
  },
101
  {
102
  "cell_type": "markdown",
103
+ "id": "24a396ef",
104
  "metadata": {},
105
  "source": [
106
  "This is where the magic happens! We'll first write a wrapper function around 🤗 Datasets `load_dataset` method. This function downloads the required splits using streaming mode by forcing `streaming=True` in the `load_dataset` method. Multiple splits can be combined (interleaved) by concatenating them with the \"+\" symbol when specifying the split name, e.g. `split=train+validation` will return a single split with the training and validation splits interleaved together. The function has the same arguments and key-word arguments as 🤗 Datasets `load_dataset` method, so we can use it in exactly the same way!"
 
109
  {
110
  "cell_type": "code",
111
  "execution_count": 1,
112
+ "id": "5b120832",
113
  "metadata": {},
114
  "outputs": [],
115
  "source": [
 
130
  },
131
  {
132
  "cell_type": "markdown",
133
+ "id": "a4c01c76",
134
  "metadata": {},
135
  "source": [
136
  "We'll train our system on the Spanish split of [Common Voice 11](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0). We can see how much training data we have by viewing the [language page](https://commonvoice.mozilla.org/en/datasets) on the Common Voice website. The Spanish split has over 400 hours of labelled training data - that's enormous! More than we could ever fit on a Google Colab or a standard workstation. But with streaming mode, we'll only download data as and when we need it, making training on this dataset possible!\n",
 
143
  {
144
  "cell_type": "code",
145
  "execution_count": 2,
146
+ "id": "046af31a",
147
  "metadata": {},
148
  "outputs": [],
149
  "source": [
 
157
  },
158
  {
159
  "cell_type": "markdown",
160
+ "id": "10730e4d",
161
  "metadata": {},
162
  "source": [
163
  "## Prepare Processor and Pre-Process Data"
 
165
  },
166
  {
167
  "cell_type": "markdown",
168
+ "id": "7f1a10c5",
169
  "metadata": {},
170
  "source": [
171
  "The ASR pipeline can be decomposed into three stages: \n",
 
186
  {
187
  "cell_type": "code",
188
  "execution_count": 3,
189
+ "id": "2c491742",
190
  "metadata": {},
191
  "outputs": [],
192
  "source": [
 
197
  },
198
  {
199
  "cell_type": "markdown",
200
+ "id": "cce2486f",
201
  "metadata": {},
202
  "source": [
203
  "### Pre-Process Data"
 
205
  },
206
  {
207
  "cell_type": "markdown",
208
+ "id": "e089f77e",
209
  "metadata": {},
210
  "source": [
211
  "Let's have a look at the dataset features. Pay particular attention to the `\"audio\"` column - this details the sampling rate of our audio inputs:"
 
214
  {
215
  "cell_type": "code",
216
  "execution_count": 4,
217
+ "id": "558b546f",
218
  "metadata": {},
219
  "outputs": [
220
  {
 
244
  },
245
  {
246
  "cell_type": "markdown",
247
+ "id": "fdea2685",
248
  "metadata": {},
249
  "source": [
250
  "Since our input audio is sampled at 48kHz, we need to _downsample_ it to\n",
 
260
  {
261
  "cell_type": "code",
262
  "execution_count": 5,
263
+ "id": "fb5ab66e",
264
  "metadata": {},
265
  "outputs": [],
266
  "source": [
 
271
  },
272
  {
273
  "cell_type": "markdown",
274
+ "id": "a0fe8d76",
275
  "metadata": {},
276
  "source": [
277
  "We'll define our pre-processing strategy. We advise that you **do not** lower-case the transcriptions or remove punctuation unless mixing different datasets. This will enable you to fine-tune Whisper models that can predict punctuation and casing. Later, you will see how we can evaluate the predictions without punctuation or casing, so that the models benefit from the WER improvement obtained by normalising the transcriptions while still predicting fully formatted transcriptions."
 
280
  {
281
  "cell_type": "code",
282
  "execution_count": 6,
283
+ "id": "e458e800",
284
  "metadata": {},
285
  "outputs": [],
286
  "source": [
 
294
  },
295
  {
296
  "cell_type": "markdown",
297
+ "id": "1c68a2f7",
298
  "metadata": {},
299
  "source": [
300
  "Now we can write a function to prepare our data ready for the model:\n",
 
307
  {
308
  "cell_type": "code",
309
  "execution_count": 7,
310
+ "id": "fb7ff9e0",
311
  "metadata": {},
312
  "outputs": [],
313
  "source": [
 
334
  },
335
  {
336
  "cell_type": "markdown",
337
+ "id": "f7f8f446",
338
  "metadata": {},
339
  "source": [
340
  "We can apply the data preparation function to all of our training examples using 🤗 Datasets' `.map` method. We'll remove all of the columns from the raw training data, leaving just the `input_features` and `labels` defined in the `prepare_dataset` function:"
 
343
  {
344
  "cell_type": "code",
345
  "execution_count": 8,
346
+ "id": "20ad4d6d",
347
  "metadata": {},
348
  "outputs": [],
349
  "source": [
 
352
  },
353
  {
354
  "cell_type": "markdown",
355
+ "id": "522d4732",
356
  "metadata": {},
357
  "source": [
358
  "We can now define how we shuffle the data in the train split. The size of the subset we load is set by the variable `buffer_size`. You can increase or decrease this depending on your memory constraints. In this example, the `buffer_size` is set to 500, meaning 500 samples are loaded before shuffling across the subset. The larger we set this value, the closer we get to true offline shuffling. The `seed` is set for reproducibility:"
 
361
  {
362
  "cell_type": "code",
363
  "execution_count": 9,
364
+ "id": "149e1f43",
365
  "metadata": {},
366
  "outputs": [],
367
  "source": [
 
373
  },
374
  {
375
  "cell_type": "markdown",
376
+ "id": "3352d9b2",
377
  "metadata": {},
378
  "source": [
379
  "Finally, we filter any training data with audio samples longer than 30s. These samples would otherwise be truncated by the Whisper feature-extractor which could affect the stability of training. We define a function that returns `True` for samples that are less than 30s, and `False` for those that are longer:"
 
382
  {
383
  "cell_type": "code",
384
  "execution_count": 10,
385
+ "id": "81bf2f51",
386
  "metadata": {},
387
  "outputs": [],
388
  "source": [
 
394
  },
395
  {
396
  "cell_type": "markdown",
397
+ "id": "77456cf9",
398
  "metadata": {},
399
  "source": [
400
  "We apply our filter function to all samples of our training dataset through 🤗 Datasets' `.filter` method:"
 
403
  {
404
  "cell_type": "code",
405
  "execution_count": 11,
406
+ "id": "3d13426b",
407
  "metadata": {},
408
  "outputs": [],
409
  "source": [
 
415
  },
416
  {
417
  "cell_type": "markdown",
418
+ "id": "482d7981",
419
  "metadata": {},
420
  "source": [
421
  "## Training and Evaluation"
 
423
  },
424
  {
425
  "cell_type": "markdown",
426
+ "id": "d23e3a80",
427
  "metadata": {},
428
  "source": [
429
  "Now that we've prepared our data, we're ready to dive into the training pipeline. \n",
 
441
  },
442
  {
443
  "cell_type": "markdown",
444
+ "id": "67d93ae7",
445
  "metadata": {},
446
  "source": [
447
  "### Define a Data Collator"
 
449
  },
450
  {
451
  "cell_type": "markdown",
452
+ "id": "8c180b86",
453
  "metadata": {},
454
  "source": [
455
  "The data collator for a sequence-to-sequence speech model is unique in the sense that it \n",
 
473
  {
474
  "cell_type": "code",
475
  "execution_count": 12,
476
+ "id": "aa4f7d60",
477
  "metadata": {},
478
  "outputs": [],
479
  "source": [
 
512
  },
513
  {
514
  "cell_type": "markdown",
515
+ "id": "d9393af2",
516
  "metadata": {},
517
  "source": [
518
  "Let's initialise the data collator we've just defined:"
 
521
  {
522
  "cell_type": "code",
523
  "execution_count": 13,
524
+ "id": "69ea823d",
525
  "metadata": {},
526
  "outputs": [],
527
  "source": [
 
530
  },
531
  {
532
  "cell_type": "markdown",
533
+ "id": "18ec9be0",
534
  "metadata": {},
535
  "source": [
536
  "### Evaluation Metrics"
 
538
  },
539
  {
540
  "cell_type": "markdown",
541
+ "id": "32ab245e",
542
  "metadata": {},
543
  "source": [
544
  "We'll use the word error rate (WER) metric, the 'de-facto' metric for assessing \n",
 
548
  {
549
  "cell_type": "code",
550
  "execution_count": 14,
551
+ "id": "73dee279",
552
  "metadata": {},
553
  "outputs": [],
554
  "source": [
 
559
  },
560
  {
561
  "cell_type": "markdown",
562
+ "id": "f3951d48",
563
  "metadata": {},
564
  "source": [
565
  "We then simply have to define a function that takes our model \n",
 
577
  {
578
  "cell_type": "code",
579
  "execution_count": 15,
580
+ "id": "6974fa39",
581
  "metadata": {},
582
  "outputs": [],
583
  "source": [
 
609
  },
610
  {
611
  "cell_type": "markdown",
612
+ "id": "95d6936b",
613
  "metadata": {},
614
  "source": [
615
  "### Load a Pre-Trained Checkpoint"
 
617
  },
618
  {
619
  "cell_type": "markdown",
620
+ "id": "f5213c1d",
621
  "metadata": {},
622
  "source": [
623
  "Now let's load the pre-trained Whisper `small` checkpoint. Again, this \n",
 
627
  {
628
  "cell_type": "code",
629
  "execution_count": 16,
630
+ "id": "7d22d4d7",
631
  "metadata": {},
632
  "outputs": [],
633
  "source": [
 
638
  },
639
  {
640
  "cell_type": "markdown",
641
+ "id": "8ada21a6",
642
  "metadata": {},
643
  "source": [
644
  "Override generation arguments - no tokens are forced as decoder outputs (see [`forced_decoder_ids`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.forced_decoder_ids)), no tokens are suppressed during generation (see [`suppress_tokens`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.suppress_tokens)). Set `use_cache` to False since we're using gradient checkpointing, and the two are incompatible:"
 
647
  {
648
  "cell_type": "code",
649
  "execution_count": 17,
650
+ "id": "5c1266ac",
651
  "metadata": {},
652
  "outputs": [],
653
  "source": [
 
658
  },
659
  {
660
  "cell_type": "markdown",
661
+ "id": "628c7f4e",
662
  "metadata": {},
663
  "source": [
664
  "### Define the Training Configuration"
 
666
  },
667
  {
668
  "cell_type": "markdown",
669
+ "id": "309768d9",
670
  "metadata": {},
671
  "source": [
672
  "In the final step, we define all the parameters related to training. Here, you can set the `max_steps` to train for longer. For more detail on the training arguments, refer to the Seq2SeqTrainingArguments [docs](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments)."
 
675
  {
676
  "cell_type": "code",
677
  "execution_count": 18,
678
+ "id": "253e7071",
679
  "metadata": {},
680
  "outputs": [],
681
  "source": [
 
707
  },
708
  {
709
  "cell_type": "markdown",
710
+ "id": "e9e4f70c",
711
  "metadata": {},
712
  "source": [
713
  "**Note**: if one does not want to upload the model checkpoints to the Hub, \n",
 
716
  },
717
  {
718
  "cell_type": "markdown",
719
+ "id": "a2d970ab",
720
  "metadata": {},
721
  "source": [
722
  "We then define a custom [Callback](https://huggingface.co/docs/transformers/main_classes/callback) that is called by the 🤗 Trainer at the start of each epoch. The Callback reinitialises and reshuffles the streaming dataset at the beginning of each new epoch - this gives different shuffling across our subsets for every epoch."
 
725
  {
726
  "cell_type": "code",
727
  "execution_count": 19,
728
+ "id": "8c7821b6",
729
  "metadata": {},
730
  "outputs": [],
731
  "source": [
 
744
  },
745
  {
746
  "cell_type": "markdown",
747
+ "id": "1a52a229",
748
  "metadata": {},
749
  "source": [
750
  "We can forward the training arguments to the 🤗 Trainer along with our model,\n",
 
754
  {
755
  "cell_type": "code",
756
  "execution_count": 20,
757
+ "id": "1d86db55",
758
  "metadata": {},
759
  "outputs": [
760
  {
 
786
  },
787
  {
788
  "cell_type": "markdown",
789
+ "id": "e8d4699b",
790
  "metadata": {},
791
  "source": [
792
  "We'll save the model and processor to the output directory before training:"
 
795
  {
796
  "cell_type": "code",
797
  "execution_count": 21,
798
+ "id": "8ec1bd4e",
799
  "metadata": {},
800
  "outputs": [
801
  {
 
818
  },
819
  {
820
  "cell_type": "markdown",
821
+ "id": "cedff338",
822
  "metadata": {},
823
  "source": [
824
  "### Training"
 
826
  },
827
  {
828
  "cell_type": "markdown",
829
+ "id": "d5ac0619",
830
  "metadata": {},
831
  "source": [
832
  "Training will take approximately 5-10 hours depending on your GPU. The peak GPU memory for the given training configuration is approximately 36GB. \n",
 
840
  },
841
  {
842
  "cell_type": "code",
843
+ "execution_count": 22,
844
+ "id": "c66858bf",
845
  "metadata": {},
846
  "outputs": [
847
  {
 
868
  "\n",
869
  " <div>\n",
870
  " \n",
871
+ " <progress value='2000' max='2000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
872
+ " [2000/2000 9:44:45, Epoch 2/9223372036854775807]\n",
873
  " </div>\n",
874
  " <table border=\"1\" class=\"dataframe\">\n",
875
  " <thead>\n",
 
877
  " <th>Step</th>\n",
878
  " <th>Training Loss</th>\n",
879
  " <th>Validation Loss</th>\n",
880
+ " <th>Wer</th>\n",
881
  " </tr>\n",
882
  " </thead>\n",
883
  " <tbody>\n",
884
+ " <tr>\n",
885
+ " <td>2000</td>\n",
886
+ " <td>0.115100</td>\n",
887
+ " <td>0.385472</td>\n",
888
+ " <td>73.658667</td>\n",
889
+ " </tr>\n",
890
  " </tbody>\n",
891
  "</table><p>"
892
  ],
 
10689
  " \"transformers_version\": \"4.26.0.dev0\",\n",
10690
  " \"use_cache\": false\n",
10691
  "}\n",
10692
+ "\n",
10693
+ "Generate config GenerationConfig {\n",
10694
+ " \"begin_suppress_tokens\": [\n",
10695
+ " 220,\n",
10696
+ " 50257\n",
10697
+ " ],\n",
10698
+ " \"bos_token_id\": 50257,\n",
10699
+ " \"decoder_start_token_id\": 50258,\n",
10700
+ " \"eos_token_id\": 50257,\n",
10701
+ " \"max_length\": 448,\n",
10702
+ " \"pad_token_id\": 50257,\n",
10703
+ " \"suppress_tokens\": [],\n",
10704
+ " \"transformers_version\": \"4.26.0.dev0\",\n",
10705
+ " \"use_cache\": false\n",
10706
+ "}\n",
10707
+ "\n",
10708
+ "Saving model checkpoint to ./checkpoint-2000\n",
10709
+ "Configuration saved in ./checkpoint-2000/config.json\n",
10710
+ "Model weights saved in ./checkpoint-2000/pytorch_model.bin\n",
10711
+ "Feature extractor saved in ./checkpoint-2000/preprocessor_config.json\n",
10712
+ "tokenizer config file saved in ./checkpoint-2000/tokenizer_config.json\n",
10713
+ "Special tokens file saved in ./checkpoint-2000/special_tokens_map.json\n",
10714
+ "added tokens file saved in ./checkpoint-2000/added_tokens.json\n",
10715
+ "Feature extractor saved in ./preprocessor_config.json\n",
10716
+ "tokenizer config file saved in ./tokenizer_config.json\n",
10717
+ "Special tokens file saved in ./special_tokens_map.json\n",
10718
+ "added tokens file saved in ./added_tokens.json\n",
10719
+ "\n",
10720
+ "\n",
10721
+ "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
10722
+ "\n",
10723
+ "\n",
10724
+ "Loading best model from ./checkpoint-2000 (score: 73.65866666666666).\n"
10725
  ]
10726
+ },
10727
+ {
10728
+ "data": {
10729
+ "text/plain": [
10730
+ "TrainOutput(global_step=2000, training_loss=0.3622547821998596, metrics={'train_runtime': 35158.6606, 'train_samples_per_second': 1.82, 'train_steps_per_second': 0.057, 'total_flos': 1.847581449928704e+19, 'train_loss': 0.3622547821998596, 'epoch': 2.12})"
10731
+ ]
10732
+ },
10733
+ "execution_count": 22,
10734
+ "metadata": {},
10735
+ "output_type": "execute_result"
10736
  }
10737
  ],
10738
  "source": [
 
10741
  },
10742
  {
10743
  "cell_type": "markdown",
10744
+ "id": "41f44439",
10745
  "metadata": {
10746
  "pycharm": {
10747
  "name": "#%% md\n"
 
10753
  },
10754
  {
10755
  "cell_type": "markdown",
10756
+ "id": "2c05ac2d",
10757
  "metadata": {},
10758
  "source": [
10759
  "We can label our checkpoint with the `whisper-event` tag on push by setting the appropriate key-word arguments (kwargs):"
 
10761
  },
10762
  {
10763
  "cell_type": "code",
10764
+ "execution_count": 23,
10765
+ "id": "005fbee5",
10766
  "metadata": {},
10767
  "outputs": [],
10768
  "source": [
 
10779
  },
10780
  {
10781
  "cell_type": "markdown",
10782
+ "id": "a92005f6",
10783
  "metadata": {},
10784
  "source": [
10785
  "The training results can now be uploaded to the Hub. To do so, execute the `push_to_hub` command:"
 
10788
  {
10789
  "cell_type": "code",
10790
  "execution_count": null,
10791
+ "id": "9aa840fa",
10792
  "metadata": {},
10793
+ "outputs": [
10794
+ {
10795
+ "name": "stderr",
10796
+ "output_type": "stream",
10797
+ "text": [
10798
+ "Saving model checkpoint to ./\n",
10799
+ "Configuration saved in ./config.json\n",
10800
+ "Model weights saved in ./pytorch_model.bin\n",
10801
+ "Feature extractor saved in ./preprocessor_config.json\n",
10802
+ "tokenizer config file saved in ./tokenizer_config.json\n",
10803
+ "Special tokens file saved in ./special_tokens_map.json\n",
10804
+ "added tokens file saved in ./added_tokens.json\n",
10805
+ "Several commits (2) will be pushed upstream.\n",
10806
+ "The progress bars may be unreliable.\n"
10807
+ ]
10808
+ },
10809
+ {
10810
+ "data": {
10811
+ "application/vnd.jupyter.widget-view+json": {
10812
+ "model_id": "8e232938a5444cb2810ce5010da7951a",
10813
+ "version_major": 2,
10814
+ "version_minor": 0
10815
+ },
10816
+ "text/plain": [
10817
+ "Upload file pytorch_model.bin: 0%| | 32.0k/922M [00:00<?, ?B/s]"
10818
+ ]
10819
+ },
10820
+ "metadata": {},
10821
+ "output_type": "display_data"
10822
+ },
10823
+ {
10824
+ "data": {
10825
+ "application/vnd.jupyter.widget-view+json": {
10826
+ "model_id": "47d78028cf7445cab5aca33b58f58b24",
10827
+ "version_major": 2,
10828
+ "version_minor": 0
10829
+ },
10830
+ "text/plain": [
10831
+ "Upload file runs/Dec21_00-50-57_c4dc565b4234/events.out.tfevents.1671583868.c4dc565b4234.3974847.0: 100%|#####…"
10832
+ ]
10833
+ },
10834
+ "metadata": {},
10835
+ "output_type": "display_data"
10836
+ }
10837
+ ],
10838
  "source": [
10839
  "trainer.push_to_hub(**kwargs)"
10840
  ]