sanchit-gandhi committed
Commit 4ee7109
1 Parent(s): 62bd796

Push to Hub

This view is limited to 50 files because it contains too many changes. See the raw diff for the complete change set.
Files changed (50)
  1. .gitattributes +1 -0
  2. __pycache__/process_asr_text_tokenizer.cpython-39.pyc +0 -0
  3. all_results.json +21 -0
  4. check_bnb_install.py +19 -0
  5. checkpoint-100000/optimizer.pt +3 -0
  6. checkpoint-100000/rng_state.pth +3 -0
  7. checkpoint-100000/scheduler.pt +3 -0
  8. checkpoint-100000/stt_en_conformer_transducer_xlarge.nemo +3 -0
  9. checkpoint-100000/trainer_state.json +0 -0
  10. checkpoint-100000/training_args.bin +3 -0
  11. checkpoint-20000/optimizer.pt +3 -0
  12. checkpoint-20000/rng_state.pth +3 -0
  13. checkpoint-20000/scheduler.pt +3 -0
  14. checkpoint-20000/stt_en_conformer_transducer_xlarge.nemo +3 -0
  15. checkpoint-20000/trainer_state.json +2425 -0
  16. checkpoint-20000/training_args.bin +3 -0
  17. checkpoint-40000/optimizer.pt +3 -0
  18. checkpoint-40000/rng_state.pth +3 -0
  19. checkpoint-40000/scheduler.pt +3 -0
  20. checkpoint-40000/stt_en_conformer_transducer_xlarge.nemo +3 -0
  21. checkpoint-40000/trainer_state.json +0 -0
  22. checkpoint-40000/training_args.bin +3 -0
  23. checkpoint-60000/optimizer.pt +3 -0
  24. checkpoint-60000/rng_state.pth +3 -0
  25. checkpoint-60000/scheduler.pt +3 -0
  26. checkpoint-60000/stt_en_conformer_transducer_xlarge.nemo +3 -0
  27. checkpoint-60000/trainer_state.json +0 -0
  28. checkpoint-60000/training_args.bin +3 -0
  29. checkpoint-80000/optimizer.pt +3 -0
  30. checkpoint-80000/rng_state.pth +3 -0
  31. checkpoint-80000/scheduler.pt +3 -0
  32. checkpoint-80000/stt_en_conformer_transducer_xlarge.nemo +3 -0
  33. checkpoint-80000/trainer_state.json +0 -0
  34. checkpoint-80000/training_args.bin +3 -0
  35. conf/conformer_transducer_bpe_dummy.yaml +192 -0
  36. conf/conformer_transducer_bpe_large.yaml +212 -0
  37. conf/conformer_transducer_bpe_xlarge.yaml +196 -0
  38. conf/contextnet_rnnt.yaml +472 -0
  39. conf/contextnet_rnnt_dummy.yaml +197 -0
  40. eval_results.json +9 -0
  41. models/__init__.py +1 -0
  42. models/__pycache__/__init__.cpython-39.pyc +0 -0
  43. models/__pycache__/modeling_rnnt.cpython-39.pyc +0 -0
  44. models/modeling_rnnt.py +115 -0
  45. process_asr_text_tokenizer.py +221 -0
  46. requirements.txt +7 -0
  47. run_ami.sh +38 -0
  48. run_speech_recognition_rnnt.py +935 -0
  49. scripts/run_batch_size_sweep.yaml +61 -0
  50. scripts/run_common_voice_9.sh +38 -0
.gitattributes CHANGED
@@ -30,3 +30,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.nemo filter=lfs diff=lfs merge=lfs -text
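The added attribute routes *.nemo checkpoint archives through Git LFS, which is why the .nemo entries further down appear as pointer stubs rather than raw binaries. Such a rule is usually created with git-lfs rather than edited by hand; a minimal sketch from Python, assuming git and git-lfs are installed and the current directory is the repository root (the exact invocation is an assumption, not something recorded in this commit):

import subprocess

# Register *.nemo with Git LFS; this appends the filter/diff/merge line
# seen above to .gitattributes, then stages the file.
def track_nemo_files(repo_dir: str = ".") -> None:
    subprocess.run(["git", "lfs", "install"], cwd=repo_dir, check=True)
    subprocess.run(["git", "lfs", "track", "*.nemo"], cwd=repo_dir, check=True)
    subprocess.run(["git", "add", ".gitattributes"], cwd=repo_dir, check=True)

if __name__ == "__main__":
    track_nemo_files()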
__pycache__/process_asr_text_tokenizer.cpython-39.pyc ADDED
Binary file (3.95 kB).
 
all_results.json ADDED
@@ -0,0 +1,21 @@
+ {
+   "epoch": 7.38,
+   "eval_loss": 8.706663131713867,
+   "eval_runtime": 970.2156,
+   "eval_samples": 13098,
+   "eval_samples_per_second": 13.5,
+   "eval_steps_per_second": 3.376,
+   "eval_wer": 0.20430683297635546,
+   "test_cer": 0.08093431359873023,
+   "test_loss": 5.917323112487793,
+   "test_runtime": 946.7263,
+   "test_samples": 12643,
+   "test_samples_per_second": 13.354,
+   "test_steps_per_second": 3.339,
+   "test_wer": 0.17709850666607363,
+   "train_loss": 10.025987887954182,
+   "train_runtime": 56856.134,
+   "train_samples": 108449,
+   "train_samples_per_second": 14.077,
+   "train_steps_per_second": 1.76
+ }
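all_results.json aggregates the final numbers for this run: training throughput, evaluation loss, and error rates on the eval and test splits (eval_wer, test_wer, test_cer). As a reminder of what those metrics measure, here is a small, self-contained sketch that computes word and character error rate as normalised Levenshtein distances; it is purely illustrative and is not the evaluation code used by run_speech_recognition_rnnt.py.

def edit_distance(ref, hyp):
    # Single-row dynamic-programming Levenshtein distance between two sequences.
    dp = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        prev, dp[0] = dp[0], i
        for j, h in enumerate(hyp, start=1):
            cur = dp[j]
            dp[j] = min(dp[j] + 1,         # deletion
                        dp[j - 1] + 1,     # insertion
                        prev + (r != h))   # substitution
            prev = cur
    return dp[-1]

def wer(reference: str, hypothesis: str) -> float:
    ref, hyp = reference.split(), hypothesis.split()
    return edit_distance(ref, hyp) / max(len(ref), 1)

def cer(reference: str, hypothesis: str) -> float:
    return edit_distance(list(reference), list(hypothesis)) / max(len(reference), 1)

print(wer("the cat sat on the mat", "the cat sat on mat"))  # ~0.167
print(cer("the cat sat on the mat", "the cat sat on mat"))  # ~0.182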
check_bnb_install.py ADDED
@@ -0,0 +1,19 @@
+ import bitsandbytes as bnb
+ import torch
+
+ p = torch.nn.Parameter(torch.rand(10, 10).cuda())
+ a = torch.rand(10, 10).cuda()
+
+ p1 = p.data.sum().item()
+
+ adam = bnb.optim.Adam([p])
+
+ out = a * p
+ loss = out.sum()
+ loss.backward()
+ adam.step()
+
+ p2 = p.data.sum().item()
+
+ assert p1 != p2
+ print('bnb: installed successfully!')
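The check is deliberately simple: it sums the parameter tensor before and after a single bitsandbytes Adam step on a toy problem, and the assertion only passes if the optimiser actually updated the weights, which requires a working CUDA device and a correctly loaded bitsandbytes native library. On a GPU machine it can be run directly with python check_bnb_install.py; swapping bnb.optim.Adam for an explicitly 8-bit variant such as bnb.optim.Adam8bit would exercise the quantised optimiser path, though that substitution is not part of this script.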
checkpoint-100000/optimizer.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:75ca838bfd7e8d7e8ebc431190243148d186a5f1ed5cd674b751f6079710ab95
+ size 5154565443
checkpoint-100000/rng_state.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0714299d2503f04c887174fcb2c5995d31c2a8dd3d887f5907696d7a91cbcb1a
+ size 14503
checkpoint-100000/scheduler.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:591846b441d543caac3afc7202fecfc43bf20ba0c611a291457e9c81cc395399
+ size 623
checkpoint-100000/stt_en_conformer_transducer_xlarge.nemo ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d4ddd41c1adabfce64125bbf639cadda2f044651386a1060440b2e49caea9f52
+ size 2577971200
checkpoint-100000/trainer_state.json ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-100000/training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3b64c669f66dd7a2e54d3001ce7e31c26cc60dd58136e8ce90e6055bd0ae15eb
+ size 3503
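Every checkpoint artefact in this commit is stored as a Git LFS pointer: a three-line text stub giving the LFS spec version, the SHA-256 of the real blob, and its size in bytes, while the binary itself lives in LFS storage. A minimal sketch of parsing such a pointer in Python, assuming the repository was cloned without resolving LFS objects (for example with GIT_LFS_SKIP_SMUDGE=1) so the path still holds the stub rather than the multi-gigabyte file:

def read_lfs_pointer(path):
    # Parse a Git LFS pointer file into a dict such as
    # {"version": "https://git-lfs.github.com/spec/v1",
    #  "oid": "sha256:75ca...", "size": 5154565443}.
    fields = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            if key:
                fields[key] = value
    if "size" in fields:
        fields["size"] = int(fields["size"])
    return fields

print(read_lfs_pointer("checkpoint-100000/optimizer.pt"))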
checkpoint-20000/optimizer.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b3f20cc328e6cf018f92f3b71e11bf4a9364f5a247ee5d99d4a62354ede6a516
+ size 5154563651
checkpoint-20000/rng_state.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9fb3410dde03074fae133541463bfebd7d0708693d5ffa17edc4fe4974c0f7eb
+ size 14503
checkpoint-20000/scheduler.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:caeda3b27b783dbb84d9e4d82bc20bd764fb8fbed5023345d4c45d753ffa45b0
+ size 623
checkpoint-20000/stt_en_conformer_transducer_xlarge.nemo ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:06c6f31b89b77d8eaf30394215a6001e812460139f4276d335e97c10cc0b632e
+ size 2577971200
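Each stt_en_conformer_transducer_xlarge.nemo file is a complete NeMo checkpoint archive for the Conformer-Transducer XLarge model saved at that training step. A hedged sketch of loading one with the NeMo toolkit, assuming the LFS object has been pulled so the path points at the real archive (the class and method names come from the public NeMo ASR API, not from code in this repository, and the audio paths are placeholders):

import nemo.collections.asr as nemo_asr

# Restore the transducer model from the pushed .nemo archive.
model = nemo_asr.models.EncDecRNNTBPEModel.restore_from(
    restore_path="checkpoint-20000/stt_en_conformer_transducer_xlarge.nemo"
)

# Transcribe a couple of audio files.
transcripts = model.transcribe(["sample1.wav", "sample2.wav"])
print(transcripts)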
checkpoint-20000/trainer_state.json ADDED
@@ -0,0 +1,2425 @@
1
+ {
2
+ "best_metric": null,
3
+ "best_model_checkpoint": null,
4
+ "epoch": 1.4752526370140886,
5
+ "global_step": 20000,
6
+ "is_hyper_param_search": false,
7
+ "is_local_process_zero": true,
8
+ "is_world_process_zero": true,
9
+ "log_history": [
10
+ {
11
+ "epoch": 0.0,
12
+ "learning_rate": 1e-05,
13
+ "loss": 178.9465,
14
+ "step": 50
15
+ },
16
+ {
17
+ "epoch": 0.01,
18
+ "learning_rate": 2e-05,
19
+ "loss": 164.9707,
20
+ "step": 100
21
+ },
22
+ {
23
+ "epoch": 0.01,
24
+ "learning_rate": 3e-05,
25
+ "loss": 142.2782,
26
+ "step": 150
27
+ },
28
+ {
29
+ "epoch": 0.01,
30
+ "learning_rate": 4e-05,
31
+ "loss": 121.5122,
32
+ "step": 200
33
+ },
34
+ {
35
+ "epoch": 0.02,
36
+ "learning_rate": 5e-05,
37
+ "loss": 91.8622,
38
+ "step": 250
39
+ },
40
+ {
41
+ "epoch": 0.02,
42
+ "learning_rate": 6e-05,
43
+ "loss": 82.2062,
44
+ "step": 300
45
+ },
46
+ {
47
+ "epoch": 0.03,
48
+ "learning_rate": 7e-05,
49
+ "loss": 72.6893,
50
+ "step": 350
51
+ },
52
+ {
53
+ "epoch": 0.03,
54
+ "learning_rate": 8e-05,
55
+ "loss": 71.8709,
56
+ "step": 400
57
+ },
58
+ {
59
+ "epoch": 0.03,
60
+ "learning_rate": 9e-05,
61
+ "loss": 69.9995,
62
+ "step": 450
63
+ },
64
+ {
65
+ "epoch": 0.04,
66
+ "learning_rate": 0.0001,
67
+ "loss": 70.6458,
68
+ "step": 500
69
+ },
70
+ {
71
+ "epoch": 0.04,
72
+ "learning_rate": 9.994977448744865e-05,
73
+ "loss": 73.9929,
74
+ "step": 550
75
+ },
76
+ {
77
+ "epoch": 0.04,
78
+ "learning_rate": 9.989954897489729e-05,
79
+ "loss": 66.52,
80
+ "step": 600
81
+ },
82
+ {
83
+ "epoch": 0.05,
84
+ "learning_rate": 9.984932346234594e-05,
85
+ "loss": 65.8947,
86
+ "step": 650
87
+ },
88
+ {
89
+ "epoch": 0.05,
90
+ "learning_rate": 9.979909794979458e-05,
91
+ "loss": 62.5809,
92
+ "step": 700
93
+ },
94
+ {
95
+ "epoch": 0.06,
96
+ "learning_rate": 9.974887243724323e-05,
97
+ "loss": 61.212,
98
+ "step": 750
99
+ },
100
+ {
101
+ "epoch": 0.06,
102
+ "learning_rate": 9.969864692469187e-05,
103
+ "loss": 68.2408,
104
+ "step": 800
105
+ },
106
+ {
107
+ "epoch": 0.06,
108
+ "learning_rate": 9.964842141214051e-05,
109
+ "loss": 61.5308,
110
+ "step": 850
111
+ },
112
+ {
113
+ "epoch": 0.07,
114
+ "learning_rate": 9.959819589958916e-05,
115
+ "loss": 58.9116,
116
+ "step": 900
117
+ },
118
+ {
119
+ "epoch": 0.07,
120
+ "learning_rate": 9.95479703870378e-05,
121
+ "loss": 60.0702,
122
+ "step": 950
123
+ },
124
+ {
125
+ "epoch": 0.07,
126
+ "learning_rate": 9.949774487448646e-05,
127
+ "loss": 57.6135,
128
+ "step": 1000
129
+ },
130
+ {
131
+ "epoch": 0.08,
132
+ "learning_rate": 9.944751936193509e-05,
133
+ "loss": 50.9231,
134
+ "step": 1050
135
+ },
136
+ {
137
+ "epoch": 0.08,
138
+ "learning_rate": 9.939729384938373e-05,
139
+ "loss": 51.187,
140
+ "step": 1100
141
+ },
142
+ {
143
+ "epoch": 0.08,
144
+ "learning_rate": 9.934706833683238e-05,
145
+ "loss": 52.1127,
146
+ "step": 1150
147
+ },
148
+ {
149
+ "epoch": 0.09,
150
+ "learning_rate": 9.929684282428102e-05,
151
+ "loss": 47.4608,
152
+ "step": 1200
153
+ },
154
+ {
155
+ "epoch": 0.09,
156
+ "learning_rate": 9.924661731172968e-05,
157
+ "loss": 51.6108,
158
+ "step": 1250
159
+ },
160
+ {
161
+ "epoch": 0.1,
162
+ "learning_rate": 9.919639179917831e-05,
163
+ "loss": 46.5874,
164
+ "step": 1300
165
+ },
166
+ {
167
+ "epoch": 0.1,
168
+ "learning_rate": 9.914616628662697e-05,
169
+ "loss": 41.4706,
170
+ "step": 1350
171
+ },
172
+ {
173
+ "epoch": 0.1,
174
+ "learning_rate": 9.90959407740756e-05,
175
+ "loss": 43.7544,
176
+ "step": 1400
177
+ },
178
+ {
179
+ "epoch": 0.11,
180
+ "learning_rate": 9.904571526152426e-05,
181
+ "loss": 44.6039,
182
+ "step": 1450
183
+ },
184
+ {
185
+ "epoch": 0.11,
186
+ "learning_rate": 9.899548974897289e-05,
187
+ "loss": 41.4384,
188
+ "step": 1500
189
+ },
190
+ {
191
+ "epoch": 0.11,
192
+ "learning_rate": 9.894526423642154e-05,
193
+ "loss": 42.8289,
194
+ "step": 1550
195
+ },
196
+ {
197
+ "epoch": 0.12,
198
+ "learning_rate": 9.889503872387019e-05,
199
+ "loss": 39.9726,
200
+ "step": 1600
201
+ },
202
+ {
203
+ "epoch": 0.12,
204
+ "learning_rate": 9.884481321131882e-05,
205
+ "loss": 43.9533,
206
+ "step": 1650
207
+ },
208
+ {
209
+ "epoch": 0.13,
210
+ "learning_rate": 9.879458769876748e-05,
211
+ "loss": 38.7605,
212
+ "step": 1700
213
+ },
214
+ {
215
+ "epoch": 0.13,
216
+ "learning_rate": 9.87443621862161e-05,
217
+ "loss": 39.5425,
218
+ "step": 1750
219
+ },
220
+ {
221
+ "epoch": 0.13,
222
+ "learning_rate": 9.869413667366476e-05,
223
+ "loss": 37.588,
224
+ "step": 1800
225
+ },
226
+ {
227
+ "epoch": 0.14,
228
+ "learning_rate": 9.86439111611134e-05,
229
+ "loss": 39.7744,
230
+ "step": 1850
231
+ },
232
+ {
233
+ "epoch": 0.14,
234
+ "learning_rate": 9.859368564856205e-05,
235
+ "loss": 38.2154,
236
+ "step": 1900
237
+ },
238
+ {
239
+ "epoch": 0.14,
240
+ "learning_rate": 9.85434601360107e-05,
241
+ "loss": 35.0806,
242
+ "step": 1950
243
+ },
244
+ {
245
+ "epoch": 0.15,
246
+ "learning_rate": 9.849323462345934e-05,
247
+ "loss": 39.061,
248
+ "step": 2000
249
+ },
250
+ {
251
+ "epoch": 0.15,
252
+ "learning_rate": 9.844300911090798e-05,
253
+ "loss": 35.1544,
254
+ "step": 2050
255
+ },
256
+ {
257
+ "epoch": 0.15,
258
+ "learning_rate": 9.839278359835663e-05,
259
+ "loss": 38.123,
260
+ "step": 2100
261
+ },
262
+ {
263
+ "epoch": 0.16,
264
+ "learning_rate": 9.834255808580527e-05,
265
+ "loss": 33.1144,
266
+ "step": 2150
267
+ },
268
+ {
269
+ "epoch": 0.16,
270
+ "learning_rate": 9.829233257325392e-05,
271
+ "loss": 34.3476,
272
+ "step": 2200
273
+ },
274
+ {
275
+ "epoch": 0.17,
276
+ "learning_rate": 9.824210706070256e-05,
277
+ "loss": 29.5665,
278
+ "step": 2250
279
+ },
280
+ {
281
+ "epoch": 0.17,
282
+ "learning_rate": 9.81918815481512e-05,
283
+ "loss": 35.8756,
284
+ "step": 2300
285
+ },
286
+ {
287
+ "epoch": 0.17,
288
+ "learning_rate": 9.814165603559985e-05,
289
+ "loss": 37.2579,
290
+ "step": 2350
291
+ },
292
+ {
293
+ "epoch": 0.18,
294
+ "learning_rate": 9.809143052304849e-05,
295
+ "loss": 33.6245,
296
+ "step": 2400
297
+ },
298
+ {
299
+ "epoch": 0.18,
300
+ "learning_rate": 9.804120501049714e-05,
301
+ "loss": 35.6543,
302
+ "step": 2450
303
+ },
304
+ {
305
+ "epoch": 0.18,
306
+ "learning_rate": 9.799097949794578e-05,
307
+ "loss": 36.7847,
308
+ "step": 2500
309
+ },
310
+ {
311
+ "epoch": 0.19,
312
+ "learning_rate": 9.794075398539442e-05,
313
+ "loss": 33.463,
314
+ "step": 2550
315
+ },
316
+ {
317
+ "epoch": 0.19,
318
+ "learning_rate": 9.789052847284307e-05,
319
+ "loss": 32.2215,
320
+ "step": 2600
321
+ },
322
+ {
323
+ "epoch": 0.2,
324
+ "learning_rate": 9.784030296029171e-05,
325
+ "loss": 33.4301,
326
+ "step": 2650
327
+ },
328
+ {
329
+ "epoch": 0.2,
330
+ "learning_rate": 9.779007744774036e-05,
331
+ "loss": 29.9579,
332
+ "step": 2700
333
+ },
334
+ {
335
+ "epoch": 0.2,
336
+ "learning_rate": 9.773985193518901e-05,
337
+ "loss": 31.9141,
338
+ "step": 2750
339
+ },
340
+ {
341
+ "epoch": 0.21,
342
+ "learning_rate": 9.768962642263764e-05,
343
+ "loss": 33.2049,
344
+ "step": 2800
345
+ },
346
+ {
347
+ "epoch": 0.21,
348
+ "learning_rate": 9.763940091008629e-05,
349
+ "loss": 32.8774,
350
+ "step": 2850
351
+ },
352
+ {
353
+ "epoch": 0.21,
354
+ "learning_rate": 9.758917539753493e-05,
355
+ "loss": 29.0858,
356
+ "step": 2900
357
+ },
358
+ {
359
+ "epoch": 0.22,
360
+ "learning_rate": 9.753894988498358e-05,
361
+ "loss": 30.1145,
362
+ "step": 2950
363
+ },
364
+ {
365
+ "epoch": 0.22,
366
+ "learning_rate": 9.748872437243222e-05,
367
+ "loss": 27.6986,
368
+ "step": 3000
369
+ },
370
+ {
371
+ "epoch": 0.22,
372
+ "learning_rate": 9.743849885988087e-05,
373
+ "loss": 31.7807,
374
+ "step": 3050
375
+ },
376
+ {
377
+ "epoch": 0.23,
378
+ "learning_rate": 9.738827334732952e-05,
379
+ "loss": 30.5108,
380
+ "step": 3100
381
+ },
382
+ {
383
+ "epoch": 0.23,
384
+ "learning_rate": 9.733804783477815e-05,
385
+ "loss": 31.0909,
386
+ "step": 3150
387
+ },
388
+ {
389
+ "epoch": 0.24,
390
+ "learning_rate": 9.728782232222681e-05,
391
+ "loss": 27.9057,
392
+ "step": 3200
393
+ },
394
+ {
395
+ "epoch": 0.24,
396
+ "learning_rate": 9.723759680967544e-05,
397
+ "loss": 29.7323,
398
+ "step": 3250
399
+ },
400
+ {
401
+ "epoch": 0.24,
402
+ "learning_rate": 9.71873712971241e-05,
403
+ "loss": 29.7527,
404
+ "step": 3300
405
+ },
406
+ {
407
+ "epoch": 0.25,
408
+ "learning_rate": 9.713714578457273e-05,
409
+ "loss": 29.1442,
410
+ "step": 3350
411
+ },
412
+ {
413
+ "epoch": 0.25,
414
+ "learning_rate": 9.708692027202137e-05,
415
+ "loss": 30.8906,
416
+ "step": 3400
417
+ },
418
+ {
419
+ "epoch": 0.25,
420
+ "learning_rate": 9.703669475947003e-05,
421
+ "loss": 26.8419,
422
+ "step": 3450
423
+ },
424
+ {
425
+ "epoch": 0.26,
426
+ "learning_rate": 9.698646924691866e-05,
427
+ "loss": 29.2181,
428
+ "step": 3500
429
+ },
430
+ {
431
+ "epoch": 0.26,
432
+ "learning_rate": 9.693624373436732e-05,
433
+ "loss": 27.6549,
434
+ "step": 3550
435
+ },
436
+ {
437
+ "epoch": 0.27,
438
+ "learning_rate": 9.688601822181595e-05,
439
+ "loss": 34.0701,
440
+ "step": 3600
441
+ },
442
+ {
443
+ "epoch": 0.27,
444
+ "learning_rate": 9.683579270926461e-05,
445
+ "loss": 24.7487,
446
+ "step": 3650
447
+ },
448
+ {
449
+ "epoch": 0.27,
450
+ "learning_rate": 9.678556719671325e-05,
451
+ "loss": 30.0266,
452
+ "step": 3700
453
+ },
454
+ {
455
+ "epoch": 0.28,
456
+ "learning_rate": 9.67353416841619e-05,
457
+ "loss": 25.5011,
458
+ "step": 3750
459
+ },
460
+ {
461
+ "epoch": 0.28,
462
+ "learning_rate": 9.668511617161054e-05,
463
+ "loss": 26.1437,
464
+ "step": 3800
465
+ },
466
+ {
467
+ "epoch": 0.28,
468
+ "learning_rate": 9.663489065905918e-05,
469
+ "loss": 23.2303,
470
+ "step": 3850
471
+ },
472
+ {
473
+ "epoch": 0.29,
474
+ "learning_rate": 9.658466514650783e-05,
475
+ "loss": 26.357,
476
+ "step": 3900
477
+ },
478
+ {
479
+ "epoch": 0.29,
480
+ "learning_rate": 9.653443963395646e-05,
481
+ "loss": 27.2201,
482
+ "step": 3950
483
+ },
484
+ {
485
+ "epoch": 0.3,
486
+ "learning_rate": 9.648421412140512e-05,
487
+ "loss": 25.5695,
488
+ "step": 4000
489
+ },
490
+ {
491
+ "epoch": 0.3,
492
+ "learning_rate": 9.643398860885376e-05,
493
+ "loss": 24.8346,
494
+ "step": 4050
495
+ },
496
+ {
497
+ "epoch": 0.3,
498
+ "learning_rate": 9.63837630963024e-05,
499
+ "loss": 22.3957,
500
+ "step": 4100
501
+ },
502
+ {
503
+ "epoch": 0.31,
504
+ "learning_rate": 9.633353758375105e-05,
505
+ "loss": 24.9532,
506
+ "step": 4150
507
+ },
508
+ {
509
+ "epoch": 0.31,
510
+ "learning_rate": 9.628331207119969e-05,
511
+ "loss": 23.1574,
512
+ "step": 4200
513
+ },
514
+ {
515
+ "epoch": 0.31,
516
+ "learning_rate": 9.623308655864834e-05,
517
+ "loss": 23.7018,
518
+ "step": 4250
519
+ },
520
+ {
521
+ "epoch": 0.32,
522
+ "learning_rate": 9.618286104609698e-05,
523
+ "loss": 25.1433,
524
+ "step": 4300
525
+ },
526
+ {
527
+ "epoch": 0.32,
528
+ "learning_rate": 9.613263553354562e-05,
529
+ "loss": 25.0571,
530
+ "step": 4350
531
+ },
532
+ {
533
+ "epoch": 0.32,
534
+ "learning_rate": 9.608241002099427e-05,
535
+ "loss": 24.2231,
536
+ "step": 4400
537
+ },
538
+ {
539
+ "epoch": 0.33,
540
+ "learning_rate": 9.603218450844291e-05,
541
+ "loss": 23.0983,
542
+ "step": 4450
543
+ },
544
+ {
545
+ "epoch": 0.33,
546
+ "learning_rate": 9.598195899589156e-05,
547
+ "loss": 25.0078,
548
+ "step": 4500
549
+ },
550
+ {
551
+ "epoch": 0.34,
552
+ "learning_rate": 9.59317334833402e-05,
553
+ "loss": 20.6933,
554
+ "step": 4550
555
+ },
556
+ {
557
+ "epoch": 0.34,
558
+ "learning_rate": 9.588150797078884e-05,
559
+ "loss": 23.6196,
560
+ "step": 4600
561
+ },
562
+ {
563
+ "epoch": 0.34,
564
+ "learning_rate": 9.583128245823749e-05,
565
+ "loss": 25.2331,
566
+ "step": 4650
567
+ },
568
+ {
569
+ "epoch": 0.35,
570
+ "learning_rate": 9.578105694568613e-05,
571
+ "loss": 24.7932,
572
+ "step": 4700
573
+ },
574
+ {
575
+ "epoch": 0.35,
576
+ "learning_rate": 9.573083143313478e-05,
577
+ "loss": 24.3586,
578
+ "step": 4750
579
+ },
580
+ {
581
+ "epoch": 0.35,
582
+ "learning_rate": 9.568060592058342e-05,
583
+ "loss": 22.7161,
584
+ "step": 4800
585
+ },
586
+ {
587
+ "epoch": 0.36,
588
+ "learning_rate": 9.563038040803208e-05,
589
+ "loss": 22.4188,
590
+ "step": 4850
591
+ },
592
+ {
593
+ "epoch": 0.36,
594
+ "learning_rate": 9.558015489548071e-05,
595
+ "loss": 21.6516,
596
+ "step": 4900
597
+ },
598
+ {
599
+ "epoch": 0.37,
600
+ "learning_rate": 9.552992938292937e-05,
601
+ "loss": 21.78,
602
+ "step": 4950
603
+ },
604
+ {
605
+ "epoch": 0.37,
606
+ "learning_rate": 9.5479703870378e-05,
607
+ "loss": 21.0172,
608
+ "step": 5000
609
+ },
610
+ {
611
+ "epoch": 0.37,
612
+ "learning_rate": 9.542947835782665e-05,
613
+ "loss": 22.4624,
614
+ "step": 5050
615
+ },
616
+ {
617
+ "epoch": 0.38,
618
+ "learning_rate": 9.537925284527528e-05,
619
+ "loss": 23.6615,
620
+ "step": 5100
621
+ },
622
+ {
623
+ "epoch": 0.38,
624
+ "learning_rate": 9.532902733272393e-05,
625
+ "loss": 21.8091,
626
+ "step": 5150
627
+ },
628
+ {
629
+ "epoch": 0.38,
630
+ "learning_rate": 9.527880182017259e-05,
631
+ "loss": 21.4173,
632
+ "step": 5200
633
+ },
634
+ {
635
+ "epoch": 0.39,
636
+ "learning_rate": 9.522857630762122e-05,
637
+ "loss": 20.5415,
638
+ "step": 5250
639
+ },
640
+ {
641
+ "epoch": 0.39,
642
+ "learning_rate": 9.517835079506987e-05,
643
+ "loss": 21.0639,
644
+ "step": 5300
645
+ },
646
+ {
647
+ "epoch": 0.39,
648
+ "learning_rate": 9.51281252825185e-05,
649
+ "loss": 21.6078,
650
+ "step": 5350
651
+ },
652
+ {
653
+ "epoch": 0.4,
654
+ "learning_rate": 9.507789976996716e-05,
655
+ "loss": 19.4142,
656
+ "step": 5400
657
+ },
658
+ {
659
+ "epoch": 0.4,
660
+ "learning_rate": 9.50276742574158e-05,
661
+ "loss": 20.2504,
662
+ "step": 5450
663
+ },
664
+ {
665
+ "epoch": 0.41,
666
+ "learning_rate": 9.497744874486445e-05,
667
+ "loss": 23.8683,
668
+ "step": 5500
669
+ },
670
+ {
671
+ "epoch": 0.41,
672
+ "learning_rate": 9.49272232323131e-05,
673
+ "loss": 19.7559,
674
+ "step": 5550
675
+ },
676
+ {
677
+ "epoch": 0.41,
678
+ "learning_rate": 9.487699771976174e-05,
679
+ "loss": 21.1743,
680
+ "step": 5600
681
+ },
682
+ {
683
+ "epoch": 0.42,
684
+ "learning_rate": 9.482677220721038e-05,
685
+ "loss": 21.1908,
686
+ "step": 5650
687
+ },
688
+ {
689
+ "epoch": 0.42,
690
+ "learning_rate": 9.477654669465901e-05,
691
+ "loss": 20.9591,
692
+ "step": 5700
693
+ },
694
+ {
695
+ "epoch": 0.42,
696
+ "learning_rate": 9.472632118210767e-05,
697
+ "loss": 20.9036,
698
+ "step": 5750
699
+ },
700
+ {
701
+ "epoch": 0.43,
702
+ "learning_rate": 9.46760956695563e-05,
703
+ "loss": 22.249,
704
+ "step": 5800
705
+ },
706
+ {
707
+ "epoch": 0.43,
708
+ "learning_rate": 9.462587015700496e-05,
709
+ "loss": 19.1093,
710
+ "step": 5850
711
+ },
712
+ {
713
+ "epoch": 0.44,
714
+ "learning_rate": 9.45756446444536e-05,
715
+ "loss": 21.2714,
716
+ "step": 5900
717
+ },
718
+ {
719
+ "epoch": 0.44,
720
+ "learning_rate": 9.452541913190225e-05,
721
+ "loss": 21.3794,
722
+ "step": 5950
723
+ },
724
+ {
725
+ "epoch": 0.44,
726
+ "learning_rate": 9.447519361935089e-05,
727
+ "loss": 20.0326,
728
+ "step": 6000
729
+ },
730
+ {
731
+ "epoch": 0.45,
732
+ "learning_rate": 9.442496810679954e-05,
733
+ "loss": 19.8004,
734
+ "step": 6050
735
+ },
736
+ {
737
+ "epoch": 0.45,
738
+ "learning_rate": 9.437474259424818e-05,
739
+ "loss": 19.0229,
740
+ "step": 6100
741
+ },
742
+ {
743
+ "epoch": 0.45,
744
+ "learning_rate": 9.432451708169682e-05,
745
+ "loss": 17.6587,
746
+ "step": 6150
747
+ },
748
+ {
749
+ "epoch": 0.46,
750
+ "learning_rate": 9.427429156914547e-05,
751
+ "loss": 21.9247,
752
+ "step": 6200
753
+ },
754
+ {
755
+ "epoch": 0.46,
756
+ "learning_rate": 9.422406605659411e-05,
757
+ "loss": 19.743,
758
+ "step": 6250
759
+ },
760
+ {
761
+ "epoch": 0.46,
762
+ "learning_rate": 9.417384054404276e-05,
763
+ "loss": 22.9746,
764
+ "step": 6300
765
+ },
766
+ {
767
+ "epoch": 0.47,
768
+ "learning_rate": 9.41236150314914e-05,
769
+ "loss": 19.6693,
770
+ "step": 6350
771
+ },
772
+ {
773
+ "epoch": 0.47,
774
+ "learning_rate": 9.407338951894004e-05,
775
+ "loss": 19.1141,
776
+ "step": 6400
777
+ },
778
+ {
779
+ "epoch": 0.48,
780
+ "learning_rate": 9.402316400638869e-05,
781
+ "loss": 18.3847,
782
+ "step": 6450
783
+ },
784
+ {
785
+ "epoch": 0.48,
786
+ "learning_rate": 9.397293849383733e-05,
787
+ "loss": 18.9357,
788
+ "step": 6500
789
+ },
790
+ {
791
+ "epoch": 0.48,
792
+ "learning_rate": 9.392271298128598e-05,
793
+ "loss": 18.9316,
794
+ "step": 6550
795
+ },
796
+ {
797
+ "epoch": 0.49,
798
+ "learning_rate": 9.387248746873462e-05,
799
+ "loss": 20.9141,
800
+ "step": 6600
801
+ },
802
+ {
803
+ "epoch": 0.49,
804
+ "learning_rate": 9.382226195618326e-05,
805
+ "loss": 18.7472,
806
+ "step": 6650
807
+ },
808
+ {
809
+ "epoch": 0.49,
810
+ "learning_rate": 9.377203644363192e-05,
811
+ "loss": 18.8577,
812
+ "step": 6700
813
+ },
814
+ {
815
+ "epoch": 0.5,
816
+ "learning_rate": 9.372181093108055e-05,
817
+ "loss": 17.8061,
818
+ "step": 6750
819
+ },
820
+ {
821
+ "epoch": 0.5,
822
+ "learning_rate": 9.36715854185292e-05,
823
+ "loss": 19.4687,
824
+ "step": 6800
825
+ },
826
+ {
827
+ "epoch": 0.51,
828
+ "learning_rate": 9.362135990597784e-05,
829
+ "loss": 19.5103,
830
+ "step": 6850
831
+ },
832
+ {
833
+ "epoch": 0.51,
834
+ "learning_rate": 9.357113439342648e-05,
835
+ "loss": 18.5319,
836
+ "step": 6900
837
+ },
838
+ {
839
+ "epoch": 0.51,
840
+ "learning_rate": 9.352090888087514e-05,
841
+ "loss": 20.16,
842
+ "step": 6950
843
+ },
844
+ {
845
+ "epoch": 0.52,
846
+ "learning_rate": 9.347068336832377e-05,
847
+ "loss": 18.1913,
848
+ "step": 7000
849
+ },
850
+ {
851
+ "epoch": 0.52,
852
+ "learning_rate": 9.342045785577243e-05,
853
+ "loss": 21.341,
854
+ "step": 7050
855
+ },
856
+ {
857
+ "epoch": 0.52,
858
+ "learning_rate": 9.337023234322106e-05,
859
+ "loss": 16.7701,
860
+ "step": 7100
861
+ },
862
+ {
863
+ "epoch": 0.53,
864
+ "learning_rate": 9.332000683066972e-05,
865
+ "loss": 18.045,
866
+ "step": 7150
867
+ },
868
+ {
869
+ "epoch": 0.53,
870
+ "learning_rate": 9.326978131811835e-05,
871
+ "loss": 16.0393,
872
+ "step": 7200
873
+ },
874
+ {
875
+ "epoch": 0.53,
876
+ "learning_rate": 9.3219555805567e-05,
877
+ "loss": 17.4833,
878
+ "step": 7250
879
+ },
880
+ {
881
+ "epoch": 0.54,
882
+ "learning_rate": 9.316933029301565e-05,
883
+ "loss": 17.3978,
884
+ "step": 7300
885
+ },
886
+ {
887
+ "epoch": 0.54,
888
+ "learning_rate": 9.31191047804643e-05,
889
+ "loss": 18.2649,
890
+ "step": 7350
891
+ },
892
+ {
893
+ "epoch": 0.55,
894
+ "learning_rate": 9.306887926791294e-05,
895
+ "loss": 16.3891,
896
+ "step": 7400
897
+ },
898
+ {
899
+ "epoch": 0.55,
900
+ "learning_rate": 9.301865375536157e-05,
901
+ "loss": 21.4399,
902
+ "step": 7450
903
+ },
904
+ {
905
+ "epoch": 0.55,
906
+ "learning_rate": 9.296842824281023e-05,
907
+ "loss": 16.3082,
908
+ "step": 7500
909
+ },
910
+ {
911
+ "epoch": 0.56,
912
+ "learning_rate": 9.291820273025886e-05,
913
+ "loss": 14.8713,
914
+ "step": 7550
915
+ },
916
+ {
917
+ "epoch": 0.56,
918
+ "learning_rate": 9.286797721770751e-05,
919
+ "loss": 16.3099,
920
+ "step": 7600
921
+ },
922
+ {
923
+ "epoch": 0.56,
924
+ "learning_rate": 9.281775170515616e-05,
925
+ "loss": 17.8771,
926
+ "step": 7650
927
+ },
928
+ {
929
+ "epoch": 0.57,
930
+ "learning_rate": 9.27675261926048e-05,
931
+ "loss": 17.1421,
932
+ "step": 7700
933
+ },
934
+ {
935
+ "epoch": 0.57,
936
+ "learning_rate": 9.271730068005345e-05,
937
+ "loss": 16.6478,
938
+ "step": 7750
939
+ },
940
+ {
941
+ "epoch": 0.58,
942
+ "learning_rate": 9.266707516750209e-05,
943
+ "loss": 15.3247,
944
+ "step": 7800
945
+ },
946
+ {
947
+ "epoch": 0.58,
948
+ "learning_rate": 9.261684965495073e-05,
949
+ "loss": 17.6577,
950
+ "step": 7850
951
+ },
952
+ {
953
+ "epoch": 0.58,
954
+ "learning_rate": 9.256662414239938e-05,
955
+ "loss": 18.8549,
956
+ "step": 7900
957
+ },
958
+ {
959
+ "epoch": 0.59,
960
+ "learning_rate": 9.251639862984802e-05,
961
+ "loss": 17.4187,
962
+ "step": 7950
963
+ },
964
+ {
965
+ "epoch": 0.59,
966
+ "learning_rate": 9.246617311729667e-05,
967
+ "loss": 15.6643,
968
+ "step": 8000
969
+ },
970
+ {
971
+ "epoch": 0.59,
972
+ "learning_rate": 9.241594760474531e-05,
973
+ "loss": 17.1987,
974
+ "step": 8050
975
+ },
976
+ {
977
+ "epoch": 0.6,
978
+ "learning_rate": 9.236572209219396e-05,
979
+ "loss": 18.1712,
980
+ "step": 8100
981
+ },
982
+ {
983
+ "epoch": 0.6,
984
+ "learning_rate": 9.23154965796426e-05,
985
+ "loss": 15.8015,
986
+ "step": 8150
987
+ },
988
+ {
989
+ "epoch": 0.6,
990
+ "learning_rate": 9.226527106709124e-05,
991
+ "loss": 19.064,
992
+ "step": 8200
993
+ },
994
+ {
995
+ "epoch": 0.61,
996
+ "learning_rate": 9.221504555453989e-05,
997
+ "loss": 18.2748,
998
+ "step": 8250
999
+ },
1000
+ {
1001
+ "epoch": 0.61,
1002
+ "learning_rate": 9.216482004198853e-05,
1003
+ "loss": 15.0679,
1004
+ "step": 8300
1005
+ },
1006
+ {
1007
+ "epoch": 0.62,
1008
+ "learning_rate": 9.211459452943718e-05,
1009
+ "loss": 17.995,
1010
+ "step": 8350
1011
+ },
1012
+ {
1013
+ "epoch": 0.62,
1014
+ "learning_rate": 9.206436901688582e-05,
1015
+ "loss": 17.467,
1016
+ "step": 8400
1017
+ },
1018
+ {
1019
+ "epoch": 0.62,
1020
+ "learning_rate": 9.201414350433448e-05,
1021
+ "loss": 18.6665,
1022
+ "step": 8450
1023
+ },
1024
+ {
1025
+ "epoch": 0.63,
1026
+ "learning_rate": 9.196391799178311e-05,
1027
+ "loss": 17.2848,
1028
+ "step": 8500
1029
+ },
1030
+ {
1031
+ "epoch": 0.63,
1032
+ "learning_rate": 9.191369247923175e-05,
1033
+ "loss": 14.4767,
1034
+ "step": 8550
1035
+ },
1036
+ {
1037
+ "epoch": 0.63,
1038
+ "learning_rate": 9.18634669666804e-05,
1039
+ "loss": 17.5444,
1040
+ "step": 8600
1041
+ },
1042
+ {
1043
+ "epoch": 0.64,
1044
+ "learning_rate": 9.181324145412904e-05,
1045
+ "loss": 14.4661,
1046
+ "step": 8650
1047
+ },
1048
+ {
1049
+ "epoch": 0.64,
1050
+ "learning_rate": 9.176301594157768e-05,
1051
+ "loss": 16.3339,
1052
+ "step": 8700
1053
+ },
1054
+ {
1055
+ "epoch": 0.65,
1056
+ "learning_rate": 9.171279042902633e-05,
1057
+ "loss": 17.5122,
1058
+ "step": 8750
1059
+ },
1060
+ {
1061
+ "epoch": 0.65,
1062
+ "learning_rate": 9.166256491647499e-05,
1063
+ "loss": 16.7631,
1064
+ "step": 8800
1065
+ },
1066
+ {
1067
+ "epoch": 0.65,
1068
+ "learning_rate": 9.161233940392362e-05,
1069
+ "loss": 16.5193,
1070
+ "step": 8850
1071
+ },
1072
+ {
1073
+ "epoch": 0.66,
1074
+ "learning_rate": 9.156211389137227e-05,
1075
+ "loss": 17.8364,
1076
+ "step": 8900
1077
+ },
1078
+ {
1079
+ "epoch": 0.66,
1080
+ "learning_rate": 9.15118883788209e-05,
1081
+ "loss": 16.2916,
1082
+ "step": 8950
1083
+ },
1084
+ {
1085
+ "epoch": 0.66,
1086
+ "learning_rate": 9.146166286626956e-05,
1087
+ "loss": 14.1719,
1088
+ "step": 9000
1089
+ },
1090
+ {
1091
+ "epoch": 0.67,
1092
+ "learning_rate": 9.141143735371819e-05,
1093
+ "loss": 18.2987,
1094
+ "step": 9050
1095
+ },
1096
+ {
1097
+ "epoch": 0.67,
1098
+ "learning_rate": 9.136121184116684e-05,
1099
+ "loss": 17.4248,
1100
+ "step": 9100
1101
+ },
1102
+ {
1103
+ "epoch": 0.67,
1104
+ "learning_rate": 9.13109863286155e-05,
1105
+ "loss": 16.1862,
1106
+ "step": 9150
1107
+ },
1108
+ {
1109
+ "epoch": 0.68,
1110
+ "learning_rate": 9.126076081606412e-05,
1111
+ "loss": 16.3134,
1112
+ "step": 9200
1113
+ },
1114
+ {
1115
+ "epoch": 0.68,
1116
+ "learning_rate": 9.121053530351278e-05,
1117
+ "loss": 14.9158,
1118
+ "step": 9250
1119
+ },
1120
+ {
1121
+ "epoch": 0.69,
1122
+ "learning_rate": 9.116030979096141e-05,
1123
+ "loss": 15.2504,
1124
+ "step": 9300
1125
+ },
1126
+ {
1127
+ "epoch": 0.69,
1128
+ "learning_rate": 9.111008427841007e-05,
1129
+ "loss": 14.1967,
1130
+ "step": 9350
1131
+ },
1132
+ {
1133
+ "epoch": 0.69,
1134
+ "learning_rate": 9.105985876585871e-05,
1135
+ "loss": 17.3165,
1136
+ "step": 9400
1137
+ },
1138
+ {
1139
+ "epoch": 0.7,
1140
+ "learning_rate": 9.100963325330736e-05,
1141
+ "loss": 14.5912,
1142
+ "step": 9450
1143
+ },
1144
+ {
1145
+ "epoch": 0.7,
1146
+ "learning_rate": 9.0959407740756e-05,
1147
+ "loss": 17.5593,
1148
+ "step": 9500
1149
+ },
1150
+ {
1151
+ "epoch": 0.7,
1152
+ "learning_rate": 9.090918222820465e-05,
1153
+ "loss": 16.3421,
1154
+ "step": 9550
1155
+ },
1156
+ {
1157
+ "epoch": 0.71,
1158
+ "learning_rate": 9.085895671565329e-05,
1159
+ "loss": 16.2821,
1160
+ "step": 9600
1161
+ },
1162
+ {
1163
+ "epoch": 0.71,
1164
+ "learning_rate": 9.080873120310192e-05,
1165
+ "loss": 16.4985,
1166
+ "step": 9650
1167
+ },
1168
+ {
1169
+ "epoch": 0.72,
1170
+ "learning_rate": 9.075850569055058e-05,
1171
+ "loss": 16.1138,
1172
+ "step": 9700
1173
+ },
1174
+ {
1175
+ "epoch": 0.72,
1176
+ "learning_rate": 9.070828017799922e-05,
1177
+ "loss": 16.3997,
1178
+ "step": 9750
1179
+ },
1180
+ {
1181
+ "epoch": 0.72,
1182
+ "learning_rate": 9.065805466544787e-05,
1183
+ "loss": 15.518,
1184
+ "step": 9800
1185
+ },
1186
+ {
1187
+ "epoch": 0.73,
1188
+ "learning_rate": 9.060782915289651e-05,
1189
+ "loss": 13.8424,
1190
+ "step": 9850
1191
+ },
1192
+ {
1193
+ "epoch": 0.73,
1194
+ "learning_rate": 9.055760364034515e-05,
1195
+ "loss": 15.0784,
1196
+ "step": 9900
1197
+ },
1198
+ {
1199
+ "epoch": 0.73,
1200
+ "learning_rate": 9.05073781277938e-05,
1201
+ "loss": 14.0163,
1202
+ "step": 9950
1203
+ },
1204
+ {
1205
+ "epoch": 0.74,
1206
+ "learning_rate": 9.045715261524244e-05,
1207
+ "loss": 16.7863,
1208
+ "step": 10000
1209
+ },
1210
+ {
1211
+ "epoch": 0.74,
1212
+ "learning_rate": 9.040692710269109e-05,
1213
+ "loss": 13.6715,
1214
+ "step": 10050
1215
+ },
1216
+ {
1217
+ "epoch": 0.75,
1218
+ "learning_rate": 9.035670159013973e-05,
1219
+ "loss": 15.1071,
1220
+ "step": 10100
1221
+ },
1222
+ {
1223
+ "epoch": 0.75,
1224
+ "learning_rate": 9.030647607758837e-05,
1225
+ "loss": 14.2658,
1226
+ "step": 10150
1227
+ },
1228
+ {
1229
+ "epoch": 0.75,
1230
+ "learning_rate": 9.025625056503703e-05,
1231
+ "loss": 15.1115,
1232
+ "step": 10200
1233
+ },
1234
+ {
1235
+ "epoch": 0.76,
1236
+ "learning_rate": 9.020602505248566e-05,
1237
+ "loss": 14.028,
1238
+ "step": 10250
1239
+ },
1240
+ {
1241
+ "epoch": 0.76,
1242
+ "learning_rate": 9.015579953993431e-05,
1243
+ "loss": 13.3066,
1244
+ "step": 10300
1245
+ },
1246
+ {
1247
+ "epoch": 0.76,
1248
+ "learning_rate": 9.010557402738295e-05,
1249
+ "loss": 14.1185,
1250
+ "step": 10350
1251
+ },
1252
+ {
1253
+ "epoch": 0.77,
1254
+ "learning_rate": 9.00553485148316e-05,
1255
+ "loss": 14.061,
1256
+ "step": 10400
1257
+ },
1258
+ {
1259
+ "epoch": 0.77,
1260
+ "learning_rate": 9.000512300228024e-05,
1261
+ "loss": 15.2439,
1262
+ "step": 10450
1263
+ },
1264
+ {
1265
+ "epoch": 0.77,
1266
+ "learning_rate": 8.995489748972888e-05,
1267
+ "loss": 13.3617,
1268
+ "step": 10500
1269
+ },
1270
+ {
1271
+ "epoch": 0.78,
1272
+ "learning_rate": 8.990467197717754e-05,
1273
+ "loss": 14.5514,
1274
+ "step": 10550
1275
+ },
1276
+ {
1277
+ "epoch": 0.78,
1278
+ "learning_rate": 8.985444646462617e-05,
1279
+ "loss": 15.2426,
1280
+ "step": 10600
1281
+ },
1282
+ {
1283
+ "epoch": 0.79,
1284
+ "learning_rate": 8.980422095207483e-05,
1285
+ "loss": 16.6418,
1286
+ "step": 10650
1287
+ },
1288
+ {
1289
+ "epoch": 0.79,
1290
+ "learning_rate": 8.975399543952346e-05,
1291
+ "loss": 13.3146,
1292
+ "step": 10700
1293
+ },
1294
+ {
1295
+ "epoch": 0.79,
1296
+ "learning_rate": 8.970376992697212e-05,
1297
+ "loss": 14.9333,
1298
+ "step": 10750
1299
+ },
1300
+ {
1301
+ "epoch": 0.8,
1302
+ "learning_rate": 8.965354441442075e-05,
1303
+ "loss": 14.4502,
1304
+ "step": 10800
1305
+ },
1306
+ {
1307
+ "epoch": 0.8,
1308
+ "learning_rate": 8.960331890186939e-05,
1309
+ "loss": 14.7886,
1310
+ "step": 10850
1311
+ },
1312
+ {
1313
+ "epoch": 0.8,
1314
+ "learning_rate": 8.955309338931805e-05,
1315
+ "loss": 15.0266,
1316
+ "step": 10900
1317
+ },
1318
+ {
1319
+ "epoch": 0.81,
1320
+ "learning_rate": 8.950286787676668e-05,
1321
+ "loss": 14.543,
1322
+ "step": 10950
1323
+ },
1324
+ {
1325
+ "epoch": 0.81,
1326
+ "learning_rate": 8.945264236421534e-05,
1327
+ "loss": 15.8078,
1328
+ "step": 11000
1329
+ },
1330
+ {
1331
+ "epoch": 0.82,
1332
+ "learning_rate": 8.940241685166397e-05,
1333
+ "loss": 13.6052,
1334
+ "step": 11050
1335
+ },
1336
+ {
1337
+ "epoch": 0.82,
1338
+ "learning_rate": 8.935219133911263e-05,
1339
+ "loss": 14.2995,
1340
+ "step": 11100
1341
+ },
1342
+ {
1343
+ "epoch": 0.82,
1344
+ "learning_rate": 8.930196582656126e-05,
1345
+ "loss": 15.732,
1346
+ "step": 11150
1347
+ },
1348
+ {
1349
+ "epoch": 0.83,
1350
+ "learning_rate": 8.925174031400991e-05,
1351
+ "loss": 14.0573,
1352
+ "step": 11200
1353
+ },
1354
+ {
1355
+ "epoch": 0.83,
1356
+ "learning_rate": 8.920151480145856e-05,
1357
+ "loss": 17.5941,
1358
+ "step": 11250
1359
+ },
1360
+ {
1361
+ "epoch": 0.83,
1362
+ "learning_rate": 8.91512892889072e-05,
1363
+ "loss": 14.7829,
1364
+ "step": 11300
1365
+ },
1366
+ {
1367
+ "epoch": 0.84,
1368
+ "learning_rate": 8.910106377635585e-05,
1369
+ "loss": 14.6669,
1370
+ "step": 11350
1371
+ },
1372
+ {
1373
+ "epoch": 0.84,
1374
+ "learning_rate": 8.905083826380448e-05,
1375
+ "loss": 14.3315,
1376
+ "step": 11400
1377
+ },
1378
+ {
1379
+ "epoch": 0.84,
1380
+ "learning_rate": 8.900061275125313e-05,
1381
+ "loss": 14.2639,
1382
+ "step": 11450
1383
+ },
1384
+ {
1385
+ "epoch": 0.85,
1386
+ "learning_rate": 8.895038723870176e-05,
1387
+ "loss": 14.3226,
1388
+ "step": 11500
1389
+ },
1390
+ {
1391
+ "epoch": 0.85,
1392
+ "learning_rate": 8.890016172615042e-05,
1393
+ "loss": 14.4975,
1394
+ "step": 11550
1395
+ },
1396
+ {
1397
+ "epoch": 0.86,
1398
+ "learning_rate": 8.884993621359907e-05,
1399
+ "loss": 14.8436,
1400
+ "step": 11600
1401
+ },
1402
+ {
1403
+ "epoch": 0.86,
1404
+ "learning_rate": 8.879971070104771e-05,
1405
+ "loss": 13.8481,
1406
+ "step": 11650
1407
+ },
1408
+ {
1409
+ "epoch": 0.86,
1410
+ "learning_rate": 8.874948518849635e-05,
1411
+ "loss": 12.8151,
1412
+ "step": 11700
1413
+ },
1414
+ {
1415
+ "epoch": 0.87,
1416
+ "learning_rate": 8.8699259675945e-05,
1417
+ "loss": 13.1659,
1418
+ "step": 11750
1419
+ },
1420
+ {
1421
+ "epoch": 0.87,
1422
+ "learning_rate": 8.864903416339364e-05,
1423
+ "loss": 15.0919,
1424
+ "step": 11800
1425
+ },
1426
+ {
1427
+ "epoch": 0.87,
1428
+ "learning_rate": 8.859880865084229e-05,
1429
+ "loss": 14.4382,
1430
+ "step": 11850
1431
+ },
1432
+ {
1433
+ "epoch": 0.88,
1434
+ "learning_rate": 8.854858313829093e-05,
1435
+ "loss": 14.0989,
1436
+ "step": 11900
1437
+ },
1438
+ {
1439
+ "epoch": 0.88,
1440
+ "learning_rate": 8.849835762573957e-05,
1441
+ "loss": 14.5763,
1442
+ "step": 11950
1443
+ },
1444
+ {
1445
+ "epoch": 0.89,
1446
+ "learning_rate": 8.844813211318822e-05,
1447
+ "loss": 13.4144,
1448
+ "step": 12000
1449
+ },
1450
+ {
1451
+ "epoch": 0.89,
1452
+ "learning_rate": 8.839790660063686e-05,
1453
+ "loss": 15.6018,
1454
+ "step": 12050
1455
+ },
1456
+ {
1457
+ "epoch": 0.89,
1458
+ "learning_rate": 8.83476810880855e-05,
1459
+ "loss": 14.7849,
1460
+ "step": 12100
1461
+ },
1462
+ {
1463
+ "epoch": 0.9,
1464
+ "learning_rate": 8.829745557553415e-05,
1465
+ "loss": 14.441,
1466
+ "step": 12150
1467
+ },
1468
+ {
1469
+ "epoch": 0.9,
1470
+ "learning_rate": 8.82472300629828e-05,
1471
+ "loss": 14.2135,
1472
+ "step": 12200
1473
+ },
1474
+ {
1475
+ "epoch": 0.9,
1476
+ "learning_rate": 8.819700455043144e-05,
1477
+ "loss": 17.1245,
1478
+ "step": 12250
1479
+ },
1480
+ {
1481
+ "epoch": 0.91,
1482
+ "learning_rate": 8.814677903788008e-05,
1483
+ "loss": 14.6629,
1484
+ "step": 12300
1485
+ },
1486
+ {
1487
+ "epoch": 0.91,
1488
+ "learning_rate": 8.809655352532873e-05,
1489
+ "loss": 16.6715,
1490
+ "step": 12350
1491
+ },
1492
+ {
1493
+ "epoch": 0.91,
1494
+ "learning_rate": 8.804632801277738e-05,
1495
+ "loss": 13.0133,
1496
+ "step": 12400
1497
+ },
1498
+ {
1499
+ "epoch": 0.92,
1500
+ "learning_rate": 8.799610250022601e-05,
1501
+ "loss": 14.1551,
1502
+ "step": 12450
1503
+ },
1504
+ {
1505
+ "epoch": 0.92,
1506
+ "learning_rate": 8.794587698767466e-05,
1507
+ "loss": 14.019,
1508
+ "step": 12500
1509
+ },
1510
+ {
1511
+ "epoch": 0.93,
1512
+ "learning_rate": 8.78956514751233e-05,
1513
+ "loss": 14.4279,
1514
+ "step": 12550
1515
+ },
1516
+ {
1517
+ "epoch": 0.93,
1518
+ "learning_rate": 8.784542596257195e-05,
1519
+ "loss": 12.5293,
1520
+ "step": 12600
1521
+ },
1522
+ {
1523
+ "epoch": 0.93,
1524
+ "learning_rate": 8.77952004500206e-05,
1525
+ "loss": 15.0403,
1526
+ "step": 12650
1527
+ },
1528
+ {
1529
+ "epoch": 0.94,
1530
+ "learning_rate": 8.774497493746924e-05,
1531
+ "loss": 13.8193,
1532
+ "step": 12700
1533
+ },
1534
+ {
1535
+ "epoch": 0.94,
1536
+ "learning_rate": 8.769474942491789e-05,
1537
+ "loss": 13.1564,
1538
+ "step": 12750
1539
+ },
1540
+ {
1541
+ "epoch": 0.94,
1542
+ "learning_rate": 8.764452391236652e-05,
1543
+ "loss": 14.6415,
1544
+ "step": 12800
1545
+ },
1546
+ {
1547
+ "epoch": 0.95,
1548
+ "learning_rate": 8.759429839981518e-05,
1549
+ "loss": 12.2339,
1550
+ "step": 12850
1551
+ },
1552
+ {
1553
+ "epoch": 0.95,
1554
+ "learning_rate": 8.754407288726381e-05,
1555
+ "loss": 12.1604,
1556
+ "step": 12900
1557
+ },
1558
+ {
1559
+ "epoch": 0.96,
1560
+ "learning_rate": 8.749384737471247e-05,
1561
+ "loss": 15.4939,
1562
+ "step": 12950
1563
+ },
1564
+ {
1565
+ "epoch": 0.96,
1566
+ "learning_rate": 8.744362186216111e-05,
1567
+ "loss": 13.9713,
1568
+ "step": 13000
1569
+ },
1570
+ {
1571
+ "epoch": 0.96,
1572
+ "learning_rate": 8.739339634960976e-05,
1573
+ "loss": 14.0986,
1574
+ "step": 13050
1575
+ },
1576
+ {
1577
+ "epoch": 0.97,
1578
+ "learning_rate": 8.73431708370584e-05,
1579
+ "loss": 13.6334,
1580
+ "step": 13100
1581
+ },
1582
+ {
1583
+ "epoch": 0.97,
1584
+ "learning_rate": 8.729294532450703e-05,
1585
+ "loss": 13.5201,
1586
+ "step": 13150
1587
+ },
1588
+ {
1589
+ "epoch": 0.97,
1590
+ "learning_rate": 8.724271981195569e-05,
1591
+ "loss": 14.3793,
1592
+ "step": 13200
1593
+ },
1594
+ {
1595
+ "epoch": 0.98,
1596
+ "learning_rate": 8.719249429940432e-05,
1597
+ "loss": 13.1741,
1598
+ "step": 13250
1599
+ },
1600
+ {
1601
+ "epoch": 0.98,
1602
+ "learning_rate": 8.714226878685298e-05,
1603
+ "loss": 11.7782,
1604
+ "step": 13300
1605
+ },
1606
+ {
1607
+ "epoch": 0.98,
1608
+ "learning_rate": 8.709204327430162e-05,
1609
+ "loss": 12.2758,
1610
+ "step": 13350
1611
+ },
1612
+ {
1613
+ "epoch": 0.99,
1614
+ "learning_rate": 8.704181776175027e-05,
1615
+ "loss": 13.1723,
1616
+ "step": 13400
1617
+ },
1618
+ {
1619
+ "epoch": 0.99,
1620
+ "learning_rate": 8.699159224919891e-05,
1621
+ "loss": 14.0858,
1622
+ "step": 13450
1623
+ },
1624
+ {
1625
+ "epoch": 1.0,
1626
+ "learning_rate": 8.694136673664755e-05,
1627
+ "loss": 11.2836,
1628
+ "step": 13500
1629
+ },
1630
+ {
1631
+ "epoch": 1.0,
1632
+ "learning_rate": 8.68911412240962e-05,
1633
+ "loss": 15.7226,
1634
+ "step": 13550
1635
+ },
1636
+ {
1637
+ "epoch": 1.0,
1638
+ "learning_rate": 8.684091571154484e-05,
1639
+ "loss": 15.8889,
1640
+ "step": 13600
1641
+ },
1642
+ {
1643
+ "epoch": 1.01,
1644
+ "learning_rate": 8.679069019899349e-05,
1645
+ "loss": 12.2185,
1646
+ "step": 13650
1647
+ },
1648
+ {
1649
+ "epoch": 1.01,
1650
+ "learning_rate": 8.674046468644213e-05,
1651
+ "loss": 11.4647,
1652
+ "step": 13700
1653
+ },
1654
+ {
1655
+ "epoch": 1.01,
1656
+ "learning_rate": 8.669023917389077e-05,
1657
+ "loss": 13.1238,
1658
+ "step": 13750
1659
+ },
1660
+ {
1661
+ "epoch": 1.02,
1662
+ "learning_rate": 8.664001366133942e-05,
1663
+ "loss": 11.909,
1664
+ "step": 13800
1665
+ },
1666
+ {
1667
+ "epoch": 1.02,
1668
+ "learning_rate": 8.658978814878806e-05,
1669
+ "loss": 12.5478,
1670
+ "step": 13850
1671
+ },
1672
+ {
1673
+ "epoch": 1.03,
1674
+ "learning_rate": 8.65395626362367e-05,
1675
+ "loss": 13.017,
1676
+ "step": 13900
1677
+ },
1678
+ {
1679
+ "epoch": 1.03,
1680
+ "learning_rate": 8.648933712368535e-05,
1681
+ "loss": 12.9134,
1682
+ "step": 13950
1683
+ },
1684
+ {
1685
+ "epoch": 1.03,
1686
+ "learning_rate": 8.6439111611134e-05,
1687
+ "loss": 13.3485,
1688
+ "step": 14000
1689
+ },
1690
+ {
1691
+ "epoch": 1.04,
1692
+ "learning_rate": 8.638888609858264e-05,
1693
+ "loss": 11.4706,
1694
+ "step": 14050
1695
+ },
1696
+ {
1697
+ "epoch": 1.04,
1698
+ "learning_rate": 8.633866058603128e-05,
1699
+ "loss": 11.1063,
1700
+ "step": 14100
1701
+ },
1702
+ {
1703
+ "epoch": 1.04,
1704
+ "learning_rate": 8.628843507347994e-05,
1705
+ "loss": 12.7408,
1706
+ "step": 14150
1707
+ },
1708
+ {
1709
+ "epoch": 1.05,
1710
+ "learning_rate": 8.623820956092857e-05,
1711
+ "loss": 12.0689,
1712
+ "step": 14200
1713
+ },
1714
+ {
1715
+ "epoch": 1.05,
1716
+ "learning_rate": 8.618798404837721e-05,
1717
+ "loss": 11.0724,
1718
+ "step": 14250
1719
+ },
1720
+ {
1721
+ "epoch": 1.05,
1722
+ "learning_rate": 8.613775853582586e-05,
1723
+ "loss": 12.5685,
1724
+ "step": 14300
1725
+ },
1726
+ {
1727
+ "epoch": 1.06,
1728
+ "learning_rate": 8.60875330232745e-05,
1729
+ "loss": 12.7776,
1730
+ "step": 14350
1731
+ },
1732
+ {
1733
+ "epoch": 1.06,
1734
+ "learning_rate": 8.603730751072315e-05,
1735
+ "loss": 11.3066,
1736
+ "step": 14400
1737
+ },
1738
+ {
1739
+ "epoch": 1.07,
1740
+ "learning_rate": 8.598708199817179e-05,
1741
+ "loss": 13.06,
1742
+ "step": 14450
1743
+ },
1744
+ {
1745
+ "epoch": 1.07,
1746
+ "learning_rate": 8.593685648562045e-05,
1747
+ "loss": 15.6523,
1748
+ "step": 14500
1749
+ },
1750
+ {
1751
+ "epoch": 1.07,
1752
+ "learning_rate": 8.588663097306908e-05,
1753
+ "loss": 12.019,
1754
+ "step": 14550
1755
+ },
1756
+ {
1757
+ "epoch": 1.08,
1758
+ "learning_rate": 8.583640546051774e-05,
1759
+ "loss": 11.0941,
1760
+ "step": 14600
1761
+ },
1762
+ {
1763
+ "epoch": 1.08,
1764
+ "learning_rate": 8.578617994796637e-05,
1765
+ "loss": 12.4755,
1766
+ "step": 14650
1767
+ },
1768
+ {
1769
+ "epoch": 1.08,
1770
+ "learning_rate": 8.573595443541502e-05,
1771
+ "loss": 13.7012,
1772
+ "step": 14700
1773
+ },
1774
+ {
1775
+ "epoch": 1.09,
1776
+ "learning_rate": 8.568572892286366e-05,
1777
+ "loss": 12.2024,
1778
+ "step": 14750
1779
+ },
1780
+ {
1781
+ "epoch": 1.09,
1782
+ "learning_rate": 8.56355034103123e-05,
1783
+ "loss": 12.4744,
1784
+ "step": 14800
1785
+ },
1786
+ {
1787
+ "epoch": 1.1,
1788
+ "learning_rate": 8.558527789776096e-05,
1789
+ "loss": 12.3234,
1790
+ "step": 14850
1791
+ },
1792
+ {
1793
+ "epoch": 1.1,
1794
+ "learning_rate": 8.553505238520959e-05,
1795
+ "loss": 12.5616,
1796
+ "step": 14900
1797
+ },
1798
+ {
1799
+ "epoch": 1.1,
1800
+ "learning_rate": 8.548482687265824e-05,
1801
+ "loss": 11.9559,
1802
+ "step": 14950
1803
+ },
1804
+ {
1805
+ "epoch": 1.11,
1806
+ "learning_rate": 8.543460136010688e-05,
1807
+ "loss": 12.0734,
1808
+ "step": 15000
1809
+ },
1810
+ {
1811
+ "epoch": 1.11,
1812
+ "learning_rate": 8.538437584755553e-05,
1813
+ "loss": 13.0341,
1814
+ "step": 15050
1815
+ },
1816
+ {
1817
+ "epoch": 1.11,
1818
+ "learning_rate": 8.533415033500418e-05,
1819
+ "loss": 12.7406,
1820
+ "step": 15100
1821
+ },
1822
+ {
1823
+ "epoch": 1.12,
1824
+ "learning_rate": 8.528392482245282e-05,
1825
+ "loss": 11.7258,
1826
+ "step": 15150
1827
+ },
1828
+ {
1829
+ "epoch": 1.12,
1830
+ "learning_rate": 8.523369930990147e-05,
1831
+ "loss": 11.8709,
1832
+ "step": 15200
1833
+ },
1834
+ {
1835
+ "epoch": 1.12,
1836
+ "learning_rate": 8.518347379735011e-05,
1837
+ "loss": 11.7021,
1838
+ "step": 15250
1839
+ },
1840
+ {
1841
+ "epoch": 1.13,
1842
+ "learning_rate": 8.513324828479875e-05,
1843
+ "loss": 13.2674,
1844
+ "step": 15300
1845
+ },
1846
+ {
1847
+ "epoch": 1.13,
1848
+ "learning_rate": 8.508302277224738e-05,
1849
+ "loss": 11.9099,
1850
+ "step": 15350
1851
+ },
1852
+ {
1853
+ "epoch": 1.14,
1854
+ "learning_rate": 8.503279725969604e-05,
1855
+ "loss": 11.7841,
1856
+ "step": 15400
1857
+ },
1858
+ {
1859
+ "epoch": 1.14,
1860
+ "learning_rate": 8.498257174714469e-05,
1861
+ "loss": 11.9573,
1862
+ "step": 15450
1863
+ },
1864
+ {
1865
+ "epoch": 1.14,
1866
+ "learning_rate": 8.493234623459333e-05,
1867
+ "loss": 11.7211,
1868
+ "step": 15500
1869
+ },
1870
+ {
1871
+ "epoch": 1.15,
1872
+ "learning_rate": 8.488212072204197e-05,
1873
+ "loss": 12.3513,
1874
+ "step": 15550
1875
+ },
1876
+ {
1877
+ "epoch": 1.15,
1878
+ "learning_rate": 8.483189520949062e-05,
1879
+ "loss": 11.0709,
1880
+ "step": 15600
1881
+ },
1882
+ {
1883
+ "epoch": 1.15,
1884
+ "learning_rate": 8.478166969693926e-05,
1885
+ "loss": 11.6544,
1886
+ "step": 15650
1887
+ },
1888
+ {
1889
+ "epoch": 1.16,
1890
+ "learning_rate": 8.47314441843879e-05,
1891
+ "loss": 11.8285,
1892
+ "step": 15700
1893
+ },
1894
+ {
1895
+ "epoch": 1.16,
1896
+ "learning_rate": 8.468121867183655e-05,
1897
+ "loss": 10.4208,
1898
+ "step": 15750
1899
+ },
1900
+ {
1901
+ "epoch": 1.17,
1902
+ "learning_rate": 8.46309931592852e-05,
1903
+ "loss": 10.7821,
1904
+ "step": 15800
1905
+ },
1906
+ {
1907
+ "epoch": 1.17,
1908
+ "learning_rate": 8.458076764673384e-05,
1909
+ "loss": 13.2724,
1910
+ "step": 15850
1911
+ },
1912
+ {
1913
+ "epoch": 1.17,
1914
+ "learning_rate": 8.45305421341825e-05,
1915
+ "loss": 10.9219,
1916
+ "step": 15900
1917
+ },
1918
+ {
1919
+ "epoch": 1.18,
1920
+ "learning_rate": 8.448031662163113e-05,
1921
+ "loss": 12.2532,
1922
+ "step": 15950
1923
+ },
1924
+ {
1925
+ "epoch": 1.18,
1926
+ "learning_rate": 8.443009110907977e-05,
1927
+ "loss": 11.0132,
1928
+ "step": 16000
1929
+ },
1930
+ {
1931
+ "epoch": 1.18,
1932
+ "learning_rate": 8.437986559652841e-05,
1933
+ "loss": 12.319,
1934
+ "step": 16050
1935
+ },
1936
+ {
1937
+ "epoch": 1.19,
1938
+ "learning_rate": 8.432964008397706e-05,
1939
+ "loss": 12.9871,
1940
+ "step": 16100
1941
+ },
1942
+ {
1943
+ "epoch": 1.19,
1944
+ "learning_rate": 8.42794145714257e-05,
1945
+ "loss": 12.0625,
1946
+ "step": 16150
1947
+ },
1948
+ {
1949
+ "epoch": 1.19,
1950
+ "learning_rate": 8.422918905887435e-05,
1951
+ "loss": 13.4629,
1952
+ "step": 16200
1953
+ },
1954
+ {
1955
+ "epoch": 1.2,
1956
+ "learning_rate": 8.4178963546323e-05,
1957
+ "loss": 10.9291,
1958
+ "step": 16250
1959
+ },
1960
+ {
1961
+ "epoch": 1.2,
1962
+ "learning_rate": 8.412873803377163e-05,
1963
+ "loss": 13.7719,
1964
+ "step": 16300
1965
+ },
1966
+ {
1967
+ "epoch": 1.21,
1968
+ "learning_rate": 8.407851252122029e-05,
1969
+ "loss": 11.3634,
1970
+ "step": 16350
1971
+ },
1972
+ {
1973
+ "epoch": 1.21,
1974
+ "learning_rate": 8.402828700866892e-05,
1975
+ "loss": 12.7941,
1976
+ "step": 16400
1977
+ },
1978
+ {
1979
+ "epoch": 1.21,
1980
+ "learning_rate": 8.397806149611758e-05,
1981
+ "loss": 11.8863,
1982
+ "step": 16450
1983
+ },
1984
+ {
1985
+ "epoch": 1.22,
1986
+ "learning_rate": 8.392783598356621e-05,
1987
+ "loss": 9.5225,
1988
+ "step": 16500
1989
+ },
1990
+ {
1991
+ "epoch": 1.22,
1992
+ "learning_rate": 8.387761047101485e-05,
1993
+ "loss": 12.983,
1994
+ "step": 16550
1995
+ },
1996
+ {
1997
+ "epoch": 1.22,
1998
+ "learning_rate": 8.382738495846351e-05,
1999
+ "loss": 11.8489,
2000
+ "step": 16600
2001
+ },
2002
+ {
2003
+ "epoch": 1.23,
2004
+ "learning_rate": 8.377715944591214e-05,
2005
+ "loss": 11.8122,
2006
+ "step": 16650
2007
+ },
2008
+ {
2009
+ "epoch": 1.23,
2010
+ "learning_rate": 8.37269339333608e-05,
2011
+ "loss": 12.3387,
2012
+ "step": 16700
2013
+ },
2014
+ {
2015
+ "epoch": 1.24,
2016
+ "learning_rate": 8.367670842080943e-05,
2017
+ "loss": 13.4648,
2018
+ "step": 16750
2019
+ },
2020
+ {
2021
+ "epoch": 1.24,
2022
+ "learning_rate": 8.362648290825809e-05,
2023
+ "loss": 10.2301,
2024
+ "step": 16800
2025
+ },
2026
+ {
2027
+ "epoch": 1.24,
2028
+ "learning_rate": 8.357625739570672e-05,
2029
+ "loss": 11.492,
2030
+ "step": 16850
2031
+ },
2032
+ {
2033
+ "epoch": 1.25,
2034
+ "learning_rate": 8.352603188315538e-05,
2035
+ "loss": 12.5997,
2036
+ "step": 16900
2037
+ },
2038
+ {
2039
+ "epoch": 1.25,
2040
+ "learning_rate": 8.347580637060402e-05,
2041
+ "loss": 11.5588,
2042
+ "step": 16950
2043
+ },
2044
+ {
2045
+ "epoch": 1.25,
2046
+ "learning_rate": 8.342558085805266e-05,
2047
+ "loss": 11.8627,
2048
+ "step": 17000
2049
+ },
2050
+ {
2051
+ "epoch": 1.26,
2052
+ "learning_rate": 8.337535534550131e-05,
2053
+ "loss": 13.2469,
2054
+ "step": 17050
2055
+ },
2056
+ {
2057
+ "epoch": 1.26,
2058
+ "learning_rate": 8.332512983294994e-05,
2059
+ "loss": 10.4327,
2060
+ "step": 17100
2061
+ },
2062
+ {
2063
+ "epoch": 1.27,
2064
+ "learning_rate": 8.32749043203986e-05,
2065
+ "loss": 12.7566,
2066
+ "step": 17150
2067
+ },
2068
+ {
2069
+ "epoch": 1.27,
2070
+ "learning_rate": 8.322467880784723e-05,
2071
+ "loss": 11.0729,
2072
+ "step": 17200
2073
+ },
2074
+ {
2075
+ "epoch": 1.27,
2076
+ "learning_rate": 8.317445329529588e-05,
2077
+ "loss": 12.3484,
2078
+ "step": 17250
2079
+ },
2080
+ {
2081
+ "epoch": 1.28,
2082
+ "learning_rate": 8.312422778274453e-05,
2083
+ "loss": 10.5193,
2084
+ "step": 17300
2085
+ },
2086
+ {
2087
+ "epoch": 1.28,
2088
+ "learning_rate": 8.307400227019317e-05,
2089
+ "loss": 12.2369,
2090
+ "step": 17350
2091
+ },
2092
+ {
2093
+ "epoch": 1.28,
2094
+ "learning_rate": 8.302377675764182e-05,
2095
+ "loss": 12.2976,
2096
+ "step": 17400
2097
+ },
2098
+ {
2099
+ "epoch": 1.29,
2100
+ "learning_rate": 8.297355124509046e-05,
2101
+ "loss": 12.3852,
2102
+ "step": 17450
2103
+ },
2104
+ {
2105
+ "epoch": 1.29,
2106
+ "learning_rate": 8.29233257325391e-05,
2107
+ "loss": 11.2137,
2108
+ "step": 17500
2109
+ },
2110
+ {
2111
+ "epoch": 1.29,
2112
+ "learning_rate": 8.287310021998775e-05,
2113
+ "loss": 11.609,
2114
+ "step": 17550
2115
+ },
2116
+ {
2117
+ "epoch": 1.3,
2118
+ "learning_rate": 8.282287470743639e-05,
2119
+ "loss": 13.3339,
2120
+ "step": 17600
2121
+ },
2122
+ {
2123
+ "epoch": 1.3,
2124
+ "learning_rate": 8.277264919488504e-05,
2125
+ "loss": 11.4263,
2126
+ "step": 17650
2127
+ },
2128
+ {
2129
+ "epoch": 1.31,
2130
+ "learning_rate": 8.272242368233368e-05,
2131
+ "loss": 12.6949,
2132
+ "step": 17700
2133
+ },
2134
+ {
2135
+ "epoch": 1.31,
2136
+ "learning_rate": 8.267219816978233e-05,
2137
+ "loss": 11.4767,
2138
+ "step": 17750
2139
+ },
2140
+ {
2141
+ "epoch": 1.31,
2142
+ "learning_rate": 8.262197265723097e-05,
2143
+ "loss": 12.2225,
2144
+ "step": 17800
2145
+ },
2146
+ {
2147
+ "epoch": 1.32,
2148
+ "learning_rate": 8.257174714467961e-05,
2149
+ "loss": 11.0755,
2150
+ "step": 17850
2151
+ },
2152
+ {
2153
+ "epoch": 1.32,
2154
+ "learning_rate": 8.252152163212826e-05,
2155
+ "loss": 11.9677,
2156
+ "step": 17900
2157
+ },
2158
+ {
2159
+ "epoch": 1.32,
2160
+ "learning_rate": 8.24712961195769e-05,
2161
+ "loss": 11.098,
2162
+ "step": 17950
2163
+ },
2164
+ {
2165
+ "epoch": 1.33,
2166
+ "learning_rate": 8.242107060702555e-05,
2167
+ "loss": 11.1102,
2168
+ "step": 18000
2169
+ },
2170
+ {
2171
+ "epoch": 1.33,
2172
+ "learning_rate": 8.237084509447419e-05,
2173
+ "loss": 11.4985,
2174
+ "step": 18050
2175
+ },
2176
+ {
2177
+ "epoch": 1.34,
2178
+ "learning_rate": 8.232061958192285e-05,
2179
+ "loss": 11.7356,
2180
+ "step": 18100
2181
+ },
2182
+ {
2183
+ "epoch": 1.34,
2184
+ "learning_rate": 8.227039406937148e-05,
2185
+ "loss": 11.3336,
2186
+ "step": 18150
2187
+ },
2188
+ {
2189
+ "epoch": 1.34,
2190
+ "learning_rate": 8.222016855682012e-05,
2191
+ "loss": 11.0448,
2192
+ "step": 18200
2193
+ },
2194
+ {
2195
+ "epoch": 1.35,
2196
+ "learning_rate": 8.216994304426877e-05,
2197
+ "loss": 10.9986,
2198
+ "step": 18250
2199
+ },
2200
+ {
2201
+ "epoch": 1.35,
2202
+ "learning_rate": 8.211971753171741e-05,
2203
+ "loss": 10.768,
2204
+ "step": 18300
2205
+ },
2206
+ {
2207
+ "epoch": 1.35,
2208
+ "learning_rate": 8.206949201916607e-05,
2209
+ "loss": 11.6844,
2210
+ "step": 18350
2211
+ },
2212
+ {
2213
+ "epoch": 1.36,
2214
+ "learning_rate": 8.20192665066147e-05,
2215
+ "loss": 11.5615,
2216
+ "step": 18400
2217
+ },
2218
+ {
2219
+ "epoch": 1.36,
2220
+ "learning_rate": 8.196904099406336e-05,
2221
+ "loss": 11.4019,
2222
+ "step": 18450
2223
+ },
2224
+ {
2225
+ "epoch": 1.36,
2226
+ "learning_rate": 8.191881548151199e-05,
2227
+ "loss": 12.1784,
2228
+ "step": 18500
2229
+ },
2230
+ {
2231
+ "epoch": 1.37,
2232
+ "learning_rate": 8.186858996896064e-05,
2233
+ "loss": 12.4565,
2234
+ "step": 18550
2235
+ },
2236
+ {
2237
+ "epoch": 1.37,
2238
+ "learning_rate": 8.181836445640927e-05,
2239
+ "loss": 11.0557,
2240
+ "step": 18600
2241
+ },
2242
+ {
2243
+ "epoch": 1.38,
2244
+ "learning_rate": 8.176813894385793e-05,
2245
+ "loss": 12.1892,
2246
+ "step": 18650
2247
+ },
2248
+ {
2249
+ "epoch": 1.38,
2250
+ "learning_rate": 8.171791343130658e-05,
2251
+ "loss": 12.0531,
2252
+ "step": 18700
2253
+ },
2254
+ {
2255
+ "epoch": 1.38,
2256
+ "learning_rate": 8.166768791875522e-05,
2257
+ "loss": 10.1791,
2258
+ "step": 18750
2259
+ },
2260
+ {
2261
+ "epoch": 1.39,
2262
+ "learning_rate": 8.161746240620386e-05,
2263
+ "loss": 11.2501,
2264
+ "step": 18800
2265
+ },
2266
+ {
2267
+ "epoch": 1.39,
2268
+ "learning_rate": 8.15672368936525e-05,
2269
+ "loss": 9.92,
2270
+ "step": 18850
2271
+ },
2272
+ {
2273
+ "epoch": 1.39,
2274
+ "learning_rate": 8.151701138110115e-05,
2275
+ "loss": 10.0603,
2276
+ "step": 18900
2277
+ },
2278
+ {
2279
+ "epoch": 1.4,
2280
+ "learning_rate": 8.146678586854978e-05,
2281
+ "loss": 10.9477,
2282
+ "step": 18950
2283
+ },
2284
+ {
2285
+ "epoch": 1.4,
2286
+ "learning_rate": 8.141656035599844e-05,
2287
+ "loss": 9.7579,
2288
+ "step": 19000
2289
+ },
2290
+ {
2291
+ "epoch": 1.41,
2292
+ "learning_rate": 8.136633484344708e-05,
2293
+ "loss": 11.243,
2294
+ "step": 19050
2295
+ },
2296
+ {
2297
+ "epoch": 1.41,
2298
+ "learning_rate": 8.131610933089573e-05,
2299
+ "loss": 11.0069,
2300
+ "step": 19100
2301
+ },
2302
+ {
2303
+ "epoch": 1.41,
2304
+ "learning_rate": 8.126588381834437e-05,
2305
+ "loss": 9.7387,
2306
+ "step": 19150
2307
+ },
2308
+ {
2309
+ "epoch": 1.42,
2310
+ "learning_rate": 8.121565830579302e-05,
2311
+ "loss": 11.4624,
2312
+ "step": 19200
2313
+ },
2314
+ {
2315
+ "epoch": 1.42,
2316
+ "learning_rate": 8.116543279324166e-05,
2317
+ "loss": 12.1299,
2318
+ "step": 19250
2319
+ },
2320
+ {
2321
+ "epoch": 1.42,
2322
+ "learning_rate": 8.11152072806903e-05,
2323
+ "loss": 12.2796,
2324
+ "step": 19300
2325
+ },
2326
+ {
2327
+ "epoch": 1.43,
2328
+ "learning_rate": 8.106498176813895e-05,
2329
+ "loss": 10.3295,
2330
+ "step": 19350
2331
+ },
2332
+ {
2333
+ "epoch": 1.43,
2334
+ "learning_rate": 8.101475625558759e-05,
2335
+ "loss": 10.0709,
2336
+ "step": 19400
2337
+ },
2338
+ {
2339
+ "epoch": 1.43,
2340
+ "learning_rate": 8.096453074303624e-05,
2341
+ "loss": 11.0725,
2342
+ "step": 19450
2343
+ },
2344
+ {
2345
+ "epoch": 1.44,
2346
+ "learning_rate": 8.091430523048488e-05,
2347
+ "loss": 10.7882,
2348
+ "step": 19500
2349
+ },
2350
+ {
2351
+ "epoch": 1.44,
2352
+ "learning_rate": 8.086407971793352e-05,
2353
+ "loss": 11.4124,
2354
+ "step": 19550
2355
+ },
2356
+ {
2357
+ "epoch": 1.45,
2358
+ "learning_rate": 8.081385420538217e-05,
2359
+ "loss": 10.4941,
2360
+ "step": 19600
2361
+ },
2362
+ {
2363
+ "epoch": 1.45,
2364
+ "learning_rate": 8.076362869283081e-05,
2365
+ "loss": 11.8687,
2366
+ "step": 19650
2367
+ },
2368
+ {
2369
+ "epoch": 1.45,
2370
+ "learning_rate": 8.071340318027946e-05,
2371
+ "loss": 11.3221,
2372
+ "step": 19700
2373
+ },
2374
+ {
2375
+ "epoch": 1.46,
2376
+ "learning_rate": 8.06631776677281e-05,
2377
+ "loss": 10.2167,
2378
+ "step": 19750
2379
+ },
2380
+ {
2381
+ "epoch": 1.46,
2382
+ "learning_rate": 8.061295215517675e-05,
2383
+ "loss": 10.5425,
2384
+ "step": 19800
2385
+ },
2386
+ {
2387
+ "epoch": 1.46,
2388
+ "learning_rate": 8.05627266426254e-05,
2389
+ "loss": 11.2982,
2390
+ "step": 19850
2391
+ },
2392
+ {
2393
+ "epoch": 1.47,
2394
+ "learning_rate": 8.051250113007403e-05,
2395
+ "loss": 12.0685,
2396
+ "step": 19900
2397
+ },
2398
+ {
2399
+ "epoch": 1.47,
2400
+ "learning_rate": 8.046227561752268e-05,
2401
+ "loss": 10.6613,
2402
+ "step": 19950
2403
+ },
2404
+ {
2405
+ "epoch": 1.48,
2406
+ "learning_rate": 8.041205010497132e-05,
2407
+ "loss": 10.8245,
2408
+ "step": 20000
2409
+ },
2410
+ {
2411
+ "epoch": 1.48,
2412
+ "eval_loss": 10.409339904785156,
2413
+ "eval_runtime": 890.9956,
2414
+ "eval_samples_per_second": 14.7,
2415
+ "eval_steps_per_second": 3.676,
2416
+ "eval_wer": 0.2624627273109067,
2417
+ "step": 20000
2418
+ }
2419
+ ],
2420
+ "max_steps": 100051,
2421
+ "num_train_epochs": 8,
2422
+ "total_flos": 0.0,
2423
+ "trial_name": null,
2424
+ "trial_params": null
2425
+ }
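
The trainer state above ends with the step-20000 evaluation (eval_loss ≈ 10.41, eval_wer ≈ 0.262). To pull those evaluation points out of a checkpoint programmatically, a minimal sketch (assuming the usual Hugging Face Trainer layout, where the per-step records sit under a "log_history" key) is:

import json

# Minimal sketch: list the evaluation entries recorded in a trainer_state.json.
# Assumes the standard Hugging Face Trainer layout with a "log_history" list.
with open("checkpoint-20000/trainer_state.json") as f:
    state = json.load(f)

for entry in state["log_history"]:
    if "eval_wer" in entry:
        print(entry["step"], entry["eval_loss"], entry["eval_wer"])
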
checkpoint-20000/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b64c669f66dd7a2e54d3001ce7e31c26cc60dd58136e8ce90e6055bd0ae15eb
3
+ size 3503
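
The binary checkpoint files in this commit (optimizer.pt, rng_state.pth, scheduler.pt, the .nemo archives, training_args.bin) are tracked with Git LFS, so the diff only shows the three-line pointer: the spec version, the sha256 object id, and the size in bytes. Until the objects are fetched with git-lfs, the file on disk is just that pointer; a small sketch for reading one:

# Minimal sketch: parse a Git LFS pointer file such as the one above.
def read_lfs_pointer(path):
    fields = {}
    with open(path) as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    return fields

ptr = read_lfs_pointer("checkpoint-20000/training_args.bin")
print(ptr["oid"], int(ptr["size"]))  # e.g. sha256:3b64c6... 3503
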
checkpoint-40000/optimizer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c1c09a7ddf632fa2b5485de6d094cf8a763affbefcb8dc5c93001a0539bad686
3
+ size 5154563651
checkpoint-40000/rng_state.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:31b1895952e1807b396d4e924fa1fb61ed026336fa2d9b568b14c899ec1ae878
3
+ size 14503
checkpoint-40000/scheduler.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b81037f0665e42c49d437ecf24e1e38406f2a8f8a1c463379f77ea33597052a
3
+ size 623
checkpoint-40000/stt_en_conformer_transducer_xlarge.nemo ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c822e20c23a0eb709dc03222743ce215a42db9863af172c34297cd8c402f9e4
3
+ size 2577971200
checkpoint-40000/trainer_state.json ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-40000/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b64c669f66dd7a2e54d3001ce7e31c26cc60dd58136e8ce90e6055bd0ae15eb
3
+ size 3503
checkpoint-60000/optimizer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:406e36deb47741922cd59f748cd1876112106ea059c820e699c269fe0d635c2b
3
+ size 5154563651
checkpoint-60000/rng_state.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3707b4b4d63eda9f45abb91e6157a5777abe5bcccebdf82df707bae7df65cf9e
3
+ size 14503
checkpoint-60000/scheduler.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4edddf9241e66e2708bca7527dec737063f80262825a1b055e50529066c54390
3
+ size 623
checkpoint-60000/stt_en_conformer_transducer_xlarge.nemo ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:be893728d43d533cf97573378f9587552441031cf01aa9fdc25c779e733140f1
3
+ size 2577971200
checkpoint-60000/trainer_state.json ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-60000/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b64c669f66dd7a2e54d3001ce7e31c26cc60dd58136e8ce90e6055bd0ae15eb
3
+ size 3503
checkpoint-80000/optimizer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:56861ad8a03582034a89047c1e6397a79297e194daab37dae36192eb72f16c4a
3
+ size 5154565443
checkpoint-80000/rng_state.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b36e92749442e712801d00e24ed95ea736e78f8ef065b6af0b801ae709dfb48d
3
+ size 14503
checkpoint-80000/scheduler.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:827a7ad0b8599273336e50134d47c6b281fcbf26c0ef32fd1bca5bf3db63fe69
3
+ size 623
checkpoint-80000/stt_en_conformer_transducer_xlarge.nemo ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d9af5c4d6859c9af2c18bca5723158554500ba93753fb4ffd4923e3e72011340
3
+ size 2577971200
checkpoint-80000/trainer_state.json ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-80000/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b64c669f66dd7a2e54d3001ce7e31c26cc60dd58136e8ce90e6055bd0ae15eb
3
+ size 3503
conf/conformer_transducer_bpe_dummy.yaml ADDED
@@ -0,0 +1,192 @@
1
+ # This config contains the default values for training a Conformer-Transducer ASR model, dummy size, with Transducer loss and sub-word encoding.
2
+
3
+ name: "Conformer-Transducer-BPE"
4
+
5
+ model:
6
+ sample_rate: 16000
7
+ compute_eval_loss: false # eval samples can be very long and exhaust memory. Disable computation of transducer loss during validation/testing with this flag.
8
+ log_prediction: true # enables logging sample predictions in the output during training
9
+ skip_nan_grad: false
10
+
11
+ model_defaults:
12
+ enc_hidden: ${model.encoder.d_model}
13
+ pred_hidden: 64
14
+ joint_hidden: 64
15
+
16
+ train_ds:
17
+ manifest_filepath: ???
18
+ sample_rate: ${model.sample_rate}
19
+ batch_size: 16 # you may increase batch_size if your memory allows
20
+ shuffle: true
21
+ num_workers: 8
22
+ pin_memory: true
23
+ use_start_end_token: false
24
+ trim_silence: false
25
+ max_duration: 16.7 # it is set for LibriSpeech, you may need to update it for your dataset
26
+ min_duration: 0.1
27
+ # tarred datasets
28
+ is_tarred: false
29
+ tarred_audio_filepaths: null
30
+ shuffle_n: 2048
31
+ # bucketing params
32
+ bucketing_strategy: "synced_randomized"
33
+ bucketing_batch_size: null
34
+
35
+ validation_ds:
36
+ manifest_filepath: ???
37
+ sample_rate: ${model.sample_rate}
38
+ batch_size: 16
39
+ shuffle: false
40
+ num_workers: 8
41
+ pin_memory: true
42
+ use_start_end_token: false
43
+
44
+ test_ds:
45
+ manifest_filepath: null
46
+ sample_rate: ${model.sample_rate}
47
+ batch_size: 16
48
+ shuffle: false
49
+ num_workers: 8
50
+ pin_memory: true
51
+ use_start_end_token: false
52
+
53
+ # You may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py
54
+ tokenizer:
55
+ dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe)
56
+ type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer)
57
+
58
+ preprocessor:
59
+ _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
60
+ sample_rate: ${model.sample_rate}
61
+ normalize: "per_feature"
62
+ window_size: 0.025
63
+ window_stride: 0.01
64
+ window: "hann"
65
+ features: 80
66
+ n_fft: 512
67
+ frame_splicing: 1
68
+ dither: 0.00001
69
+ pad_to: 0
70
+
71
+ spec_augment:
72
+ _target_: nemo.collections.asr.modules.SpectrogramAugmentation
73
+ freq_masks: 2 # set to zero to disable it
74
+ time_masks: 10 # set to zero to disable it
75
+ freq_width: 27
76
+ time_width: 0.05
77
+
78
+ encoder:
79
+ _target_: nemo.collections.asr.modules.ConformerEncoder
80
+ feat_in: ${model.preprocessor.features}
81
+ feat_out: -1 # you may set it if you need an output size different from the default d_model
82
+ n_layers: 2
83
+ d_model: 64
84
+
85
+ # Sub-sampling params
86
+ subsampling: striding # vggnet, striding, stacking or stacking_norm, dw_striding
87
+ subsampling_factor: 4 # must be power of 2 for striding and vggnet
88
+ subsampling_conv_channels: -1 # set to -1 to make it equal to the d_model
89
+ causal_downsampling: false
90
+
91
+ # Feed forward module's params
92
+ ff_expansion_factor: 4
93
+
94
+ # Multi-headed Attention Module's params
95
+ self_attention_model: rel_pos # rel_pos or abs_pos
96
+ n_heads: 8 # may need to be lower for smaller d_models
97
+ # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention
98
+ att_context_size: [-1, -1] # -1 means unlimited context
99
+ att_context_style: regular # regular or chunked_limited
100
+ xscaling: true # scales up the input embeddings by sqrt(d_model)
101
+ untie_biases: true # unties the biases of the TransformerXL layers
102
+ pos_emb_max_len: 5000
103
+
104
+ # Convolution module's params
105
+ conv_kernel_size: 5
106
+ conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups)
107
+ # conv_context_size can be "causal" or a list of two integers such that conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size
108
+ # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0]
109
+ conv_context_size: null
110
+
111
+ ### regularization
112
+ dropout: 0.1 # The dropout used in most of the Conformer Modules
113
+ dropout_emb: 0.0 # The dropout used for embeddings
114
+ dropout_att: 0.1 # The dropout for multi-headed attention modules
115
+
116
+ decoder:
117
+ _target_: nemo.collections.asr.modules.RNNTDecoder
118
+ normalization_mode: null # Currently only null is supported for export.
119
+ random_state_sampling: false # Random state sampling: https://arxiv.org/pdf/1910.11455.pdf
120
+ blank_as_pad: true # This flag must be set in order to support exporting of RNNT models + efficient inference.
121
+
122
+ prednet:
123
+ pred_hidden: ${model.model_defaults.pred_hidden}
124
+ pred_rnn_layers: 1
125
+ t_max: null
126
+ dropout: 0.2
127
+
128
+ joint:
129
+ _target_: nemo.collections.asr.modules.RNNTJoint
130
+ log_softmax: null # 'null' would set it automatically according to CPU/GPU device
131
+ preserve_memory: false # dramatically slows down training, but might preserve some memory
132
+
133
+ # Fuses the computation of prediction net + joint net + loss + WER calculation
134
+ # to be run on sub-batches of size `fused_batch_size`.
135
+ # When this flag is set to true, consider the `batch_size` of *_ds to be just `encoder` batch size.
136
+ # `fused_batch_size` is the actual batch size of the prediction net, joint net and transducer loss.
137
+ # Using small values here will preserve a lot of memory during training, but will make training slower as well.
138
+ # An optimal ratio of fused_batch_size : *_ds.batch_size is 1:1.
139
+ # However, to preserve memory, this ratio can be 1:8 or even 1:16.
140
+ # Extreme case of 1:B (i.e. fused_batch_size=1) should be avoided as training speed would be very slow.
141
+ fuse_loss_wer: true
142
+ fused_batch_size: 16
143
+
144
+ jointnet:
145
+ joint_hidden: ${model.model_defaults.joint_hidden}
146
+ activation: "relu"
147
+ dropout: 0.2
148
+
149
+ decoding:
150
+ strategy: "greedy_batch" # can be greedy, greedy_batch, beam, tsd, alsd.
151
+
152
+ # greedy strategy config
153
+ greedy:
154
+ max_symbols: 10
155
+
156
+ # beam strategy config
157
+ beam:
158
+ beam_size: 2
159
+ return_best_hypothesis: False
160
+ score_norm: true
161
+ tsd_max_sym_exp: 50 # for Time Synchronous Decoding
162
+ alsd_max_target_len: 2.0 # for Alignment-Length Synchronous Decoding
163
+
164
+ loss:
165
+ loss_name: "default"
166
+
167
+ warprnnt_numba_kwargs:
168
+ # FastEmit regularization: https://arxiv.org/abs/2010.11148
169
+ # You may enable FastEmit to reduce the latency of the model for streaming
170
+ fastemit_lambda: 0.0 # Recommended values are in the range [1e-4, 1e-2]; 0.001 is a good start.
171
+ clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only.
172
+
173
+ # Adds Gaussian noise to the gradients of the decoder to avoid overfitting
174
+ variational_noise:
175
+ start_step: 0
176
+ std: 0.0
177
+
178
+ optim:
179
+ name: adamw
180
+ lr: 5.0
181
+ # optimizer arguments
182
+ betas: [0.9, 0.98]
183
+ weight_decay: 1e-3
184
+
185
+ # scheduler setup
186
+ sched:
187
+ name: NoamAnnealing
188
+ d_model: ${model.encoder.d_model}
189
+ # scheduler config override
190
+ warmup_steps: 10000
191
+ warmup_ratio: null
192
+ min_lr: 1e-6
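
The dummy config above is a scaled-down variant (2 encoder layers, d_model 64) of the large/xlarge recipes that follow, useful for smoke tests. The ${...} entries are OmegaConf interpolations and the ??? entries are mandatory fields that must be supplied before training. A minimal sketch of filling them in (the manifest and tokenizer paths below are hypothetical placeholders, not files from this repo):

from omegaconf import OmegaConf

# Minimal sketch: load the dummy config and fill in the mandatory "???" fields.
cfg = OmegaConf.load("conf/conformer_transducer_bpe_dummy.yaml")
cfg.model.train_ds.manifest_filepath = "manifests/train.json"        # placeholder
cfg.model.validation_ds.manifest_filepath = "manifests/dev.json"     # placeholder
cfg.model.tokenizer.dir = "tokenizer_dir"                            # placeholder

# Interpolations such as ${model.encoder.d_model} resolve on access:
print(cfg.model.model_defaults.enc_hidden)  # -> 64 for this dummy config
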
conf/conformer_transducer_bpe_large.yaml ADDED
@@ -0,0 +1,212 @@
1
+ # This config contains the default values for training a Conformer-Transducer ASR model, large size (~120M), with Transducer loss and sub-word encoding.
2
+
3
+ # Architecture and training config:
4
+ # Default learning parameters in this config are set for effective batch size of 2K. To train it with smaller effective
5
+ # batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches.
6
+ # Here are the recommended configs for different variants of Conformer-Transducer, other parameters are the same as in this config file.
7
+ #
8
+ # +-------------+---------+---------+----------+--------------+--------------------------+
9
+ # | Model | d_model | n_heads | n_layers | weight_decay | pred_hidden/joint_hidden |
10
+ # +=============+=========+========+===========+==============+==========================+
11
+ # | Small (14M)| 176 | 4 | 16 | 0.0 | 320 |
12
+ # +-------------+---------+--------+-----------+--------------+--------------------------+
13
+ # | Medium (32M)| 256 | 4 | 16 | 1e-3 | 640 |
14
+ # +-------------+---------+--------+-----------+--------------+--------------------------+
15
+ # | Large (120M)| 512 | 8 | 17 | 1e-3 | 640 |
16
+ # +-----------------------------------------------------------+--------------------------+
17
+ #
18
+
19
+ # You may find more info about Conformer-Transducer here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#conformer-transducer
20
+ # Pre-trained models of Conformer-Transducer can be found here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/results.html
21
+ # The checkpoint of the large model trained on NeMo ASRSET with this recipe can be found here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_transducer_large
22
+
23
+ name: "Conformer-Transducer-BPE"
24
+
25
+ model:
26
+ sample_rate: 16000
27
+ compute_eval_loss: false # eval samples can be very long and exhaust memory. Disable computation of transducer loss during validation/testing with this flag.
28
+ log_prediction: true # enables logging sample predictions in the output during training
29
+ skip_nan_grad: false
30
+
31
+ model_defaults:
32
+ enc_hidden: ${model.encoder.d_model}
33
+ pred_hidden: 640
34
+ joint_hidden: 640
35
+
36
+ train_ds:
37
+ manifest_filepath: ???
38
+ sample_rate: ${model.sample_rate}
39
+ batch_size: 16 # you may increase batch_size if your memory allows
40
+ shuffle: true
41
+ num_workers: 8
42
+ pin_memory: true
43
+ use_start_end_token: false
44
+ trim_silence: false
45
+ max_duration: 16.7 # it is set for LibriSpeech, you may need to update it for your dataset
46
+ min_duration: 0.1
47
+ # tarred datasets
48
+ is_tarred: false
49
+ tarred_audio_filepaths: null
50
+ shuffle_n: 2048
51
+ # bucketing params
52
+ bucketing_strategy: "synced_randomized"
53
+ bucketing_batch_size: null
54
+
55
+ validation_ds:
56
+ manifest_filepath: ???
57
+ sample_rate: ${model.sample_rate}
58
+ batch_size: 16
59
+ shuffle: false
60
+ num_workers: 8
61
+ pin_memory: true
62
+ use_start_end_token: false
63
+
64
+ test_ds:
65
+ manifest_filepath: null
66
+ sample_rate: ${model.sample_rate}
67
+ batch_size: 16
68
+ shuffle: false
69
+ num_workers: 8
70
+ pin_memory: true
71
+ use_start_end_token: false
72
+
73
+ # You may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py
74
+ tokenizer:
75
+ dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe)
76
+ type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer)
77
+
78
+ preprocessor:
79
+ _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
80
+ sample_rate: ${model.sample_rate}
81
+ normalize: "per_feature"
82
+ window_size: 0.025
83
+ window_stride: 0.01
84
+ window: "hann"
85
+ features: 80
86
+ n_fft: 512
87
+ frame_splicing: 1
88
+ dither: 0.00001
89
+ pad_to: 0
90
+
91
+ spec_augment:
92
+ _target_: nemo.collections.asr.modules.SpectrogramAugmentation
93
+ freq_masks: 2 # set to zero to disable it
94
+ time_masks: 10 # set to zero to disable it
95
+ freq_width: 27
96
+ time_width: 0.05
97
+
98
+ encoder:
99
+ _target_: nemo.collections.asr.modules.ConformerEncoder
100
+ feat_in: ${model.preprocessor.features}
101
+ feat_out: -1 # you may set it if you need an output size different from the default d_model
102
+ n_layers: 17
103
+ d_model: 512
104
+
105
+ # Sub-sampling params
106
+ subsampling: striding # vggnet, striding, stacking or stacking_norm, dw_striding
107
+ subsampling_factor: 4 # must be power of 2 for striding and vggnet
108
+ subsampling_conv_channels: -1 # set to -1 to make it equal to the d_model
109
+ causal_downsampling: false
110
+
111
+ # Feed forward module's params
112
+ ff_expansion_factor: 4
113
+
114
+ # Multi-headed Attention Module's params
115
+ self_attention_model: rel_pos # rel_pos or abs_pos
116
+ n_heads: 8 # may need to be lower for smaller d_models
117
+ # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention
118
+ att_context_size: [-1, -1] # -1 means unlimited context
119
+ att_context_style: regular # regular or chunked_limited
120
+ xscaling: true # scales up the input embeddings by sqrt(d_model)
121
+ untie_biases: true # unties the biases of the TransformerXL layers
122
+ pos_emb_max_len: 5000
123
+
124
+ # Convolution module's params
125
+ conv_kernel_size: 31
126
+ conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups)
127
+ # conv_context_size can be "causal" or a list of two integers such that conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size
128
+ # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0]
129
+ conv_context_size: null
130
+
131
+ ### regularization
132
+ dropout: 0.1 # The dropout used in most of the Conformer Modules
133
+ dropout_emb: 0.0 # The dropout used for embeddings
134
+ dropout_att: 0.1 # The dropout for multi-headed attention modules
135
+
136
+ decoder:
137
+ _target_: nemo.collections.asr.modules.RNNTDecoder
138
+ normalization_mode: null # Currently only null is supported for export.
139
+ random_state_sampling: false # Random state sampling: https://arxiv.org/pdf/1910.11455.pdf
140
+ blank_as_pad: true # This flag must be set in order to support exporting of RNNT models + efficient inference.
141
+
142
+ prednet:
143
+ pred_hidden: ${model.model_defaults.pred_hidden}
144
+ pred_rnn_layers: 1
145
+ t_max: null
146
+ dropout: 0.2
147
+
148
+ joint:
149
+ _target_: nemo.collections.asr.modules.RNNTJoint
150
+ log_softmax: null # 'null' would set it automatically according to CPU/GPU device
151
+ preserve_memory: false # dramatically slows down training, but might preserve some memory
152
+
153
+ # Fuses the computation of prediction net + joint net + loss + WER calculation
154
+ # to be run on sub-batches of size `fused_batch_size`.
155
+ # When this flag is set to true, consider the `batch_size` of *_ds to be just `encoder` batch size.
156
+ # `fused_batch_size` is the actual batch size of the prediction net, joint net and transducer loss.
157
+ # Using small values here will preserve a lot of memory during training, but will make training slower as well.
158
+ # An optimal ratio of fused_batch_size : *_ds.batch_size is 1:1.
159
+ # However, to preserve memory, this ratio can be 1:8 or even 1:16.
160
+ # Extreme case of 1:B (i.e. fused_batch_size=1) should be avoided as training speed would be very slow.
161
+ fuse_loss_wer: true
162
+ fused_batch_size: 16
163
+
164
+ jointnet:
165
+ joint_hidden: ${model.model_defaults.joint_hidden}
166
+ activation: "relu"
167
+ dropout: 0.2
168
+
169
+ decoding:
170
+ strategy: "greedy_batch" # can be greedy, greedy_batch, beam, tsd, alsd.
171
+
172
+ # greedy strategy config
173
+ greedy:
174
+ max_symbols: 10
175
+
176
+ # beam strategy config
177
+ beam:
178
+ beam_size: 2
179
+ return_best_hypothesis: False
180
+ score_norm: true
181
+ tsd_max_sym_exp: 50 # for Time Synchronous Decoding
182
+ alsd_max_target_len: 2.0 # for Alignment-Length Synchronous Decoding
183
+
184
+ loss:
185
+ loss_name: "default"
186
+
187
+ warprnnt_numba_kwargs:
188
+ # FastEmit regularization: https://arxiv.org/abs/2010.11148
189
+ # You may enable FastEmit to reduce the latency of the model for streaming
190
+ fastemit_lambda: 0.0 # Recommended values are in the range [1e-4, 1e-2]; 0.001 is a good start.
191
+ clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only.
192
+
193
+ # Adds Gaussian noise to the gradients of the decoder to avoid overfitting
194
+ variational_noise:
195
+ start_step: 0
196
+ std: 0.0
197
+
198
+ optim:
199
+ name: adamw
200
+ lr: 5.0
201
+ # optimizer arguments
202
+ betas: [0.9, 0.98]
203
+ weight_decay: 1e-3
204
+
205
+ # scheduler setup
206
+ sched:
207
+ name: NoamAnnealing
208
+ d_model: ${model.encoder.d_model}
209
+ # scheduler config override
210
+ warmup_steps: 10000
211
+ warmup_ratio: null
212
+ min_lr: 1e-6
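
In these recipes the lr: 5.0 under optim is not a literal learning rate: with the NoamAnnealing scheduler it acts as a multiplier on the usual Noam curve, which warms up over warmup_steps and then decays with the inverse square root of the step. A rough sketch of the shape for the large config above (d_model=512, warmup_steps=10000); NeMo's exact implementation may differ in detail, e.g. in how min_lr is applied:

# Rough sketch of the Noam learning-rate curve behind the NoamAnnealing entry above.
def noam_lr(step, base_lr=5.0, d_model=512, warmup_steps=10000, min_lr=1e-6):
    step = max(step, 1)
    scale = d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
    return max(base_lr * scale, min_lr)

for s in (100, 10_000, 100_000):
    print(s, noam_lr(s))  # peaks around step 10k at roughly 2.2e-3
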
conf/conformer_transducer_bpe_xlarge.yaml ADDED
@@ -0,0 +1,196 @@
1
+ # This config contains the default values for training a Conformer-Transducer ASR model, XL size (~0.6B), with Transducer loss and sub-word encoding.
2
+
3
+ # You may find more info about Conformer-Transducer here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#conformer-transducer
4
+ # Pre-trained models of Conformer-Transducer can be found here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/results.html
5
+ # The checkpoint of the xlarge model trained on NeMo ASRSET with this recipe can be found here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_transducer_xlarge
6
+
7
+ name: "Conformer-Transducer-BPE"
8
+
9
+ model:
10
+ sample_rate: 16000
11
+ compute_eval_loss: false # eval samples can be very long and exhaust memory. Disable computation of transducer loss during validation/testing with this flag.
12
+ log_prediction: true # enables logging sample predictions in the output during training
13
+ skip_nan_grad: false
14
+
15
+ model_defaults:
16
+ enc_hidden: ${model.encoder.d_model}
17
+ pred_hidden: 640
18
+ joint_hidden: 640
19
+
20
+ train_ds:
21
+ manifest_filepath: ???
22
+ sample_rate: ${model.sample_rate}
23
+ batch_size: 16 # you may increase batch_size if your memory allows
24
+ shuffle: true
25
+ num_workers: 8
26
+ pin_memory: true
27
+ use_start_end_token: false
28
+ trim_silence: false
29
+ max_duration: 16.7 # it is set for LibriSpeech, you may need to update it for your dataset
30
+ min_duration: 0.1
31
+ # tarred datasets
32
+ is_tarred: false
33
+ tarred_audio_filepaths: null
34
+ shuffle_n: 2048
35
+ # bucketing params
36
+ bucketing_strategy: "synced_randomized"
37
+ bucketing_batch_size: null
38
+
39
+ validation_ds:
40
+ manifest_filepath: ???
41
+ sample_rate: ${model.sample_rate}
42
+ batch_size: 16
43
+ shuffle: false
44
+ num_workers: 8
45
+ pin_memory: true
46
+ use_start_end_token: false
47
+
48
+ test_ds:
49
+ manifest_filepath: null
50
+ sample_rate: ${model.sample_rate}
51
+ batch_size: 16
52
+ shuffle: false
53
+ num_workers: 8
54
+ pin_memory: true
55
+ use_start_end_token: false
56
+
57
+ # You may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py
58
+ tokenizer:
59
+ dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe)
60
+ type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer)
61
+
62
+ preprocessor:
63
+ _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
64
+ sample_rate: ${model.sample_rate}
65
+ normalize: "per_feature"
66
+ window_size: 0.025
67
+ window_stride: 0.01
68
+ window: "hann"
69
+ features: 80
70
+ n_fft: 512
71
+ frame_splicing: 1
72
+ dither: 0.00001
73
+ pad_to: 0
74
+
75
+ spec_augment:
76
+ _target_: nemo.collections.asr.modules.SpectrogramAugmentation
77
+ freq_masks: 2 # set to zero to disable it
78
+ time_masks: 10 # set to zero to disable it
79
+ freq_width: 27
80
+ time_width: 0.05
81
+
82
+ encoder:
83
+ _target_: nemo.collections.asr.modules.ConformerEncoder
84
+ feat_in: ${model.preprocessor.features}
85
+ feat_out: -1 # you may set it if you need an output size different from the default d_model
86
+ n_layers: 24
87
+ d_model: 1024
88
+
89
+ # Sub-sampling params
90
+ subsampling: striding # vggnet, striding, stacking or stacking_norm, dw_striding
91
+ subsampling_factor: 4 # must be power of 2 for striding and vggnet
92
+ subsampling_conv_channels: -1 # set to -1 to make it equal to the d_model
93
+ causal_downsampling: false
94
+
95
+ # Feed forward module's params
96
+ ff_expansion_factor: 4
97
+
98
+ # Multi-headed Attention Module's params
99
+ self_attention_model: rel_pos # rel_pos or abs_pos
100
+ n_heads: 8 # may need to be lower for smaller d_models
101
+ # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention
102
+ att_context_size: [-1, -1] # -1 means unlimited context
103
+ att_context_style: regular # regular or chunked_limited
104
+ xscaling: true # scales up the input embeddings by sqrt(d_model)
105
+ untie_biases: true # unties the biases of the TransformerXL layers
106
+ pos_emb_max_len: 5000
107
+
108
+ # Convolution module's params
109
+ conv_kernel_size: 5
110
+ conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups)
111
+ # conv_context_size can be "causal" or a list of two integers such that conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size
112
+ # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0]
113
+ conv_context_size: null
114
+
115
+ ### regularization
116
+ dropout: 0.1 # The dropout used in most of the Conformer Modules
117
+ dropout_emb: 0.0 # The dropout used for embeddings
118
+ dropout_att: 0.1 # The dropout for multi-headed attention modules
119
+
120
+ decoder:
121
+ _target_: nemo.collections.asr.modules.RNNTDecoder
122
+ normalization_mode: null # Currently only null is supported for export.
123
+ random_state_sampling: false # Random state sampling: https://arxiv.org/pdf/1910.11455.pdf
124
+ blank_as_pad: true # This flag must be set in order to support exporting of RNNT models + efficient inference.
125
+
126
+ prednet:
127
+ pred_hidden: ${model.model_defaults.pred_hidden}
128
+ pred_rnn_layers: 2
129
+ t_max: null
130
+ dropout: 0.1
131
+
132
+ joint:
133
+ _target_: nemo.collections.asr.modules.RNNTJoint
134
+ log_softmax: null # 'null' would set it automatically according to CPU/GPU device
135
+ preserve_memory: false # dramatically slows down training, but might preserve some memory
136
+
137
+ # Fuses the computation of prediction net + joint net + loss + WER calculation
138
+ # to be run on sub-batches of size `fused_batch_size`.
139
+ # When this flag is set to true, consider the `batch_size` of *_ds to be just `encoder` batch size.
140
+ # `fused_batch_size` is the actual batch size of the prediction net, joint net and transducer loss.
141
+ # Using small values here will preserve a lot of memory during training, but will make training slower as well.
142
+ # An optimal ratio of fused_batch_size : *_ds.batch_size is 1:1.
143
+ # However, to preserve memory, this ratio can be 1:8 or even 1:16.
144
+ # Extreme case of 1:B (i.e. fused_batch_size=1) should be avoided as training speed would be very slow.
145
+ fuse_loss_wer: true
146
+ fused_batch_size: 16
147
+
148
+ jointnet:
149
+ joint_hidden: ${model.model_defaults.joint_hidden}
150
+ activation: "relu"
151
+ dropout: 0.1
152
+
153
+ decoding:
154
+ strategy: "greedy_batch" # can be greedy, greedy_batch, beam, tsd, alsd.
155
+
156
+ # greedy strategy config
157
+ greedy:
158
+ max_symbols: 10
159
+
160
+ # beam strategy config
161
+ beam:
162
+ beam_size: 2
163
+ return_best_hypothesis: False
164
+ score_norm: true
165
+ tsd_max_sym_exp: 50 # for Time Synchronous Decoding
166
+ alsd_max_target_len: 2.0 # for Alignment-Length Synchronous Decoding
167
+
168
+ loss:
169
+ loss_name: "default"
170
+
171
+ warprnnt_numba_kwargs:
172
+ # FastEmit regularization: https://arxiv.org/abs/2010.11148
173
+ # You may enable FastEmit to reduce the latency of the model for streaming
174
+ fastemit_lambda: 0.0 # Recommended values are in the range [1e-4, 1e-2]; 0.001 is a good start.
175
+ clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only.
176
+
177
+ # Adds Gaussian noise to the gradients of the decoder to avoid overfitting
178
+ variational_noise:
179
+ start_step: 0
180
+ std: 0.0
181
+
182
+ optim:
183
+ name: adamw
184
+ lr: 5.0
185
+ # optimizer arguments
186
+ betas: [0.9, 0.98]
187
+ weight_decay: 1e-3
188
+
189
+ # scheduler setup
190
+ sched:
191
+ name: NoamAnnealing
192
+ d_model: ${model.encoder.d_model}
193
+ # scheduler config override
194
+ warmup_steps: 10000
195
+ warmup_ratio: null
196
+ min_lr: 1e-6
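
A note on the fuse_loss_wer / fused_batch_size comments that appear in all three Conformer configs above: with fusion enabled, the encoder still runs on the full *_ds.batch_size, while the prediction net, joint net and transducer loss are run over sub-batches of fused_batch_size to keep the joint tensor small. A small sketch of the splitting that the 1:1 vs 1:8 ratios in the comments refer to (an illustration of the idea, not NeMo's code):

# Minimal sketch of the sub-batching described by fuse_loss_wer / fused_batch_size:
# the joint/loss computation walks over one encoder batch in chunks.
def sub_batches(batch_size, fused_batch_size):
    for start in range(0, batch_size, fused_batch_size):
        yield start, min(start + fused_batch_size, batch_size)

print(list(sub_batches(16, 16)))  # 1:1 ratio -> a single pass
print(list(sub_batches(16, 2)))   # 1:8 ratio -> eight smaller joint tensors
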
conf/contextnet_rnnt.yaml ADDED
@@ -0,0 +1,472 @@
1
+ # This config contains the default values for training a modified ContextNet model with Transducer loss and BPE-based vocabulary.
2
+ # In contrast to the original ContextNet, the same number of filters is used throughout the model.
3
+ # Default learning parameters in this config are set for effective batch size of 1k on 32 GPUs.
4
+ # To train it with smaller batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches.
5
+
6
+ # It contains the default values for training a ContextNet ASR model, large size (~144M) with Transducer loss and sub-word encoding.
7
+
8
+ # Architecture and training config:
9
+ # Default learning parameters in this config are set for effective batch size of 1K. To train it with smaller effective
10
+ # batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches.
11
+ # Here are the recommended configs for different variants of ContextNet, other parameters are the same as in this config file.
12
+ #
13
+ # +-------------+---------+------------+
14
+ # | Model | filters | time_masks |
15
+ # +=============+=========+============+
16
+ # | Small (14M)| 256 | 2 |
17
+ # +-------------+---------+------------+
18
+ # | Medium (40M)| 512 | 5 |
19
+ # +-------------+---------+------------+
20
+ # | Large (145M)| 1024 | 10 |
21
+ # +-------------------------------------
22
+
23
+ name: &name "ContextNet-8x-Stride-RNNT"
24
+
25
+ model:
26
+ sample_rate: 16000
27
+ compute_eval_loss: false # eval samples can be very long and exhaust memory. Disable computation of transducer loss during validation/testing with this flag.
28
+
29
+ train_ds:
30
+ manifest_filepath: ???
31
+ sample_rate: ${model.sample_rate}
32
+ batch_size: 16 # Can be increased if memory allows or when using smaller model
33
+ trim_silence: false
34
+ max_duration: 16.7
35
+ shuffle: true
36
+ use_start_end_token: false
37
+ num_workers: 16
38
+ pin_memory: true
39
+ # tarred datasets
40
+ is_tarred: false
41
+ tarred_audio_filepaths: null
42
+ tarred_shard_strategy: "scatter"
43
+ shuffle_n: 2048
44
+ # bucketing params
45
+ bucketing_strategy: "synced_randomized"
46
+ bucketing_batch_size: null
47
+ validation_ds:
48
+ manifest_filepath: ???
49
+ sample_rate: ${model.sample_rate}
50
+ batch_size: 8
51
+ shuffle: false
52
+ use_start_end_token: false
53
+ num_workers: 16
54
+ pin_memory: true
55
+
56
+ test_ds:
57
+ manifest_filepath: null
58
+ sample_rate: ${model.sample_rate}
59
+ batch_size: 8
60
+ shuffle: false
61
+ use_start_end_token: false
62
+ num_workers: 16
63
+ pin_memory: true
64
+
65
+ model_defaults:
66
+ filters: 1024
67
+ repeat: 5
68
+ dropout: 0.1
69
+ separable: true
70
+ se: true
71
+ se_context_size: -1
72
+ kernel_size_factor: 1.0
73
+ # encoder / decoder / joint values
74
+ enc_hidden: 640
75
+ pred_hidden: 640
76
+ joint_hidden: 640
77
+
78
+ tokenizer:
79
+ dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe)
80
+ type: ??? # Can be either bpe or wpe
81
+
82
+ preprocessor:
83
+ _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
84
+ sample_rate: ${model.sample_rate}
85
+ normalize: "per_feature"
86
+ window_size: 0.025
87
+ window_stride: 0.01
88
+ window: "hann"
89
+ features: &n_mels 80
90
+ n_fft: 512
91
+ frame_splicing: 1
92
+ dither: 0.00001
93
+ pad_to: 16
94
+ stft_conv: false
95
+
96
+ spec_augment:
97
+ _target_: nemo.collections.asr.modules.SpectrogramAugmentation
98
+ freq_masks: 2 # should be kept at 2
99
+ time_masks: 10 # can be 5 for small-med models, 10 for larger models.
100
+ freq_width: 27
101
+ time_width: 0.05
102
+
103
+ encoder:
104
+ _target_: nemo.collections.asr.modules.ConvASREncoder
105
+ feat_in: *n_mels
106
+ activation: swish
107
+ conv_mask: true
108
+ init_mode: "tds_uniform"
109
+
110
+ jasper:
111
+ - filters: ${model.model_defaults.filters}
112
+ repeat: 1
113
+ kernel: [5]
114
+ stride: [1]
115
+ dilation: [1]
116
+ dropout: 0.0
117
+ residual: false
118
+ separable: ${model.model_defaults.separable}
119
+ se: ${model.model_defaults.se}
120
+ se_context_size: ${model.model_defaults.se_context_size}
121
+
122
+ - filters: ${model.model_defaults.filters}
123
+ repeat: ${model.model_defaults.repeat}
124
+ kernel: [5]
125
+ stride: [1]
126
+ dilation: [1]
127
+ dropout: ${model.model_defaults.dropout}
128
+ residual: true
129
+ separable: ${model.model_defaults.separable}
130
+ se: ${model.model_defaults.se}
131
+ se_context_size: ${model.model_defaults.se_context_size}
132
+ kernel_size_factor: ${model.model_defaults.kernel_size_factor}
133
+
134
+ - filters: ${model.model_defaults.filters}
135
+ repeat: ${model.model_defaults.repeat}
136
+ kernel: [5]
137
+ stride: [1]
138
+ dilation: [1]
139
+ dropout: ${model.model_defaults.dropout}
140
+ residual: true
141
+ separable: ${model.model_defaults.separable}
142
+ se: ${model.model_defaults.se}
143
+ se_context_size: ${model.model_defaults.se_context_size}
144
+ kernel_size_factor: ${model.model_defaults.kernel_size_factor}
145
+
146
+ - filters: ${model.model_defaults.filters}
147
+ repeat: ${model.model_defaults.repeat}
148
+ kernel: [5]
149
+ stride: [2]
150
+ dilation: [1]
151
+ dropout: ${model.model_defaults.dropout}
152
+ residual: true
153
+ separable: ${model.model_defaults.separable}
154
+ se: ${model.model_defaults.se}
155
+ se_context_size: ${model.model_defaults.se_context_size}
156
+ stride_last: true
157
+ residual_mode: "stride_add"
158
+ kernel_size_factor: ${model.model_defaults.kernel_size_factor}
159
+
160
+ - filters: ${model.model_defaults.filters}
161
+ repeat: ${model.model_defaults.repeat}
162
+ kernel: [5]
163
+ stride: [1]
164
+ dilation: [1]
165
+ dropout: ${model.model_defaults.dropout}
166
+ residual: true
167
+ separable: ${model.model_defaults.separable}
168
+ se: ${model.model_defaults.se}
169
+ se_context_size: ${model.model_defaults.se_context_size}
170
+ kernel_size_factor: ${model.model_defaults.kernel_size_factor}
171
+
172
+ - filters: ${model.model_defaults.filters}
173
+ repeat: ${model.model_defaults.repeat}
174
+ kernel: [5]
175
+ stride: [1]
176
+ dilation: [1]
177
+ dropout: ${model.model_defaults.dropout}
178
+ residual: true
179
+ separable: ${model.model_defaults.separable}
180
+ se: ${model.model_defaults.se}
181
+ se_context_size: ${model.model_defaults.se_context_size}
182
+ kernel_size_factor: ${model.model_defaults.kernel_size_factor}
183
+
184
+ - filters: ${model.model_defaults.filters}
185
+ repeat: ${model.model_defaults.repeat}
186
+ kernel: [5]
187
+ stride: [1]
188
+ dilation: [1]
189
+ dropout: ${model.model_defaults.dropout}
190
+ residual: true
191
+ separable: ${model.model_defaults.separable}
192
+ se: ${model.model_defaults.se}
193
+ se_context_size: ${model.model_defaults.se_context_size}
194
+ kernel_size_factor: ${model.model_defaults.kernel_size_factor}
195
+
196
+ - filters: ${model.model_defaults.filters}
197
+ repeat: ${model.model_defaults.repeat}
198
+ kernel: [5]
199
+ stride: [2] # *stride
200
+ dilation: [1]
201
+ dropout: ${model.model_defaults.dropout}
202
+ residual: true
203
+ separable: ${model.model_defaults.separable}
204
+ se: ${model.model_defaults.se}
205
+ se_context_size: ${model.model_defaults.se_context_size}
206
+ stride_last: true
207
+ residual_mode: "stride_add"
208
+ kernel_size_factor: ${model.model_defaults.kernel_size_factor}
209
+
210
+ - filters: ${model.model_defaults.filters}
211
+ repeat: ${model.model_defaults.repeat}
212
+ kernel: [5]
213
+ stride: [1]
214
+ dilation: [1]
215
+ dropout: ${model.model_defaults.dropout}
216
+ residual: true
217
+ separable: ${model.model_defaults.separable}
218
+ se: ${model.model_defaults.se}
219
+ se_context_size: ${model.model_defaults.se_context_size}
220
+ kernel_size_factor: ${model.model_defaults.kernel_size_factor}
221
+
222
+ - filters: ${model.model_defaults.filters}
223
+ repeat: ${model.model_defaults.repeat}
224
+ kernel: [5]
225
+ stride: [1]
226
+ dilation: [1]
227
+ dropout: ${model.model_defaults.dropout}
228
+ residual: true
229
+ separable: ${model.model_defaults.separable}
230
+ se: ${model.model_defaults.se}
231
+ se_context_size: ${model.model_defaults.se_context_size}
232
+ kernel_size_factor: ${model.model_defaults.kernel_size_factor}
233
+
234
+ - filters: ${model.model_defaults.filters}
235
+ repeat: ${model.model_defaults.repeat}
236
+ kernel: [5]
237
+ stride: [1]
238
+ dilation: [1]
239
+ dropout: ${model.model_defaults.dropout}
240
+ residual: true
241
+ separable: ${model.model_defaults.separable}
242
+ se: ${model.model_defaults.se}
243
+ se_context_size: ${model.model_defaults.se_context_size}
244
+ kernel_size_factor: ${model.model_defaults.kernel_size_factor}
245
+
246
+ - filters: ${model.model_defaults.filters}
247
+ repeat: ${model.model_defaults.repeat}
248
+ kernel: [5]
249
+ stride: [1]
250
+ dilation: [1]
251
+ dropout: ${model.model_defaults.dropout}
252
+ residual: true
253
+ separable: ${model.model_defaults.separable}
254
+ se: ${model.model_defaults.se}
255
+ se_context_size: ${model.model_defaults.se_context_size}
256
+ kernel_size_factor: ${model.model_defaults.kernel_size_factor}
257
+
258
+ - filters: ${model.model_defaults.filters}
259
+ repeat: ${model.model_defaults.repeat}
260
+ kernel: [5]
261
+ stride: [1]
262
+ dilation: [1]
263
+ dropout: ${model.model_defaults.dropout}
264
+ residual: true
265
+ separable: ${model.model_defaults.separable}
266
+ se: ${model.model_defaults.se}
267
+ se_context_size: ${model.model_defaults.se_context_size}
268
+ kernel_size_factor: ${model.model_defaults.kernel_size_factor}
269
+
270
+ - filters: ${model.model_defaults.filters}
271
+ repeat: ${model.model_defaults.repeat}
272
+ kernel: [5]
273
+ stride: [1]
274
+ dilation: [1]
275
+ dropout: ${model.model_defaults.dropout}
276
+ residual: true
277
+ separable: ${model.model_defaults.separable}
278
+ se: ${model.model_defaults.se}
279
+ se_context_size: ${model.model_defaults.se_context_size}
280
+ kernel_size_factor: ${model.model_defaults.kernel_size_factor}
281
+
282
+ - filters: ${model.model_defaults.filters}
283
+ repeat: ${model.model_defaults.repeat}
284
+ kernel: [5]
285
+ stride: [2] # stride
286
+ dilation: [1]
287
+ dropout: ${model.model_defaults.dropout}
288
+ residual: true
289
+ separable: ${model.model_defaults.separable}
290
+ se: ${model.model_defaults.se}
291
+ se_context_size: ${model.model_defaults.se_context_size}
292
+ stride_last: true
293
+ residual_mode: "stride_add"
294
+ kernel_size_factor: ${model.model_defaults.kernel_size_factor}
295
+
296
+ - filters: ${model.model_defaults.filters}
297
+ repeat: ${model.model_defaults.repeat}
298
+ kernel: [5]
299
+ stride: [1]
300
+ dilation: [1]
301
+ dropout: ${model.model_defaults.dropout}
302
+ residual: true
303
+ separable: ${model.model_defaults.separable}
304
+ se: ${model.model_defaults.se}
305
+ se_context_size: ${model.model_defaults.se_context_size}
306
+ kernel_size_factor: ${model.model_defaults.kernel_size_factor}
307
+
308
+ - filters: ${model.model_defaults.filters}
309
+ repeat: ${model.model_defaults.repeat}
310
+ kernel: [5]
311
+ stride: [1]
312
+ dilation: [1]
313
+ dropout: ${model.model_defaults.dropout}
314
+ residual: true
315
+ separable: ${model.model_defaults.separable}
316
+ se: ${model.model_defaults.se}
317
+ se_context_size: ${model.model_defaults.se_context_size}
318
+ kernel_size_factor: ${model.model_defaults.kernel_size_factor}
319
+
320
+ - filters: ${model.model_defaults.filters}
321
+ repeat: ${model.model_defaults.repeat}
322
+ kernel: [5]
323
+ stride: [1]
324
+ dilation: [1]
325
+ dropout: ${model.model_defaults.dropout}
326
+ residual: true
327
+ separable: ${model.model_defaults.separable}
328
+ se: ${model.model_defaults.se}
329
+ se_context_size: ${model.model_defaults.se_context_size}
330
+ kernel_size_factor: ${model.model_defaults.kernel_size_factor}
331
+
332
+ - filters: ${model.model_defaults.filters}
333
+ repeat: ${model.model_defaults.repeat}
334
+ kernel: [5]
335
+ stride: [1]
336
+ dilation: [1]
337
+ dropout: ${model.model_defaults.dropout}
338
+ residual: true
339
+ separable: ${model.model_defaults.separable}
340
+ se: ${model.model_defaults.se}
341
+ se_context_size: ${model.model_defaults.se_context_size}
342
+ kernel_size_factor: ${model.model_defaults.kernel_size_factor}
343
+
344
+ - filters: ${model.model_defaults.filters}
345
+ repeat: ${model.model_defaults.repeat}
346
+ kernel: [5]
347
+ stride: [1]
348
+ dilation: [1]
349
+ dropout: ${model.model_defaults.dropout}
350
+ residual: true
351
+ separable: ${model.model_defaults.separable}
352
+ se: ${model.model_defaults.se}
353
+ se_context_size: ${model.model_defaults.se_context_size}
354
+ kernel_size_factor: ${model.model_defaults.kernel_size_factor}
355
+
356
+ - filters: ${model.model_defaults.filters}
357
+ repeat: ${model.model_defaults.repeat}
358
+ kernel: [5]
359
+ stride: [1]
360
+ dilation: [1]
361
+ dropout: ${model.model_defaults.dropout}
362
+ residual: true
363
+ separable: ${model.model_defaults.separable}
364
+ se: ${model.model_defaults.se}
365
+ se_context_size: ${model.model_defaults.se_context_size}
366
+ kernel_size_factor: ${model.model_defaults.kernel_size_factor}
367
+
368
+ - filters: ${model.model_defaults.filters}
369
+ repeat: ${model.model_defaults.repeat}
370
+ kernel: [5]
371
+ stride: [1]
372
+ dilation: [1]
373
+ dropout: ${model.model_defaults.dropout}
374
+ residual: true
375
+ separable: ${model.model_defaults.separable}
376
+ se: ${model.model_defaults.se}
377
+ se_context_size: ${model.model_defaults.se_context_size}
378
+ kernel_size_factor: ${model.model_defaults.kernel_size_factor}
379
+
380
+ - filters: ${model.model_defaults.enc_hidden}
381
+ repeat: 1
382
+ kernel: [5]
383
+ stride: [1]
384
+ dilation: [1]
385
+ dropout: 0.0
386
+ residual: false
387
+ separable: ${model.model_defaults.separable}
388
+ se: ${model.model_defaults.se}
389
+ se_context_size: ${model.model_defaults.se_context_size}
390
+ kernel_size_factor: ${model.model_defaults.kernel_size_factor}
391
+
392
+
393
+ decoder:
394
+ _target_: nemo.collections.asr.modules.RNNTDecoder
395
+ normalization_mode: null # Currently only null is supported for export.
396
+ random_state_sampling: false # Random state sampling: https://arxiv.org/pdf/1910.11455.pdf
397
+ blank_as_pad: true # This flag must be set in order to support exporting of RNNT models + efficient inference.
398
+
399
+ prednet:
400
+ pred_hidden: ${model.model_defaults.pred_hidden}
401
+ pred_rnn_layers: 1 # only 1 layer LSTM networks are exportable.
402
+ t_max: null # Maximum possible target seq length used for Chrono Initialization - https://arxiv.org/abs/1804.11188. Disabled by default.
403
+ dropout: 0.1
404
+
405
+ joint:
406
+ _target_: nemo.collections.asr.modules.RNNTJoint
407
+ log_softmax: null # sets it according to cpu/gpu device
408
+ preserve_memory: false # dramatically slows down training, but might preserve some memory
409
+
410
+ # Fuses the computation of prediction net + joint net + loss + WER calculation
411
+ # to be run on sub-batches of size `fused_batch_size`.
412
+ # When this flag is set to true, consider the `batch_size` of *_ds to be just `encoder` batch size.
413
+ # `fused_batch_size` is the actual batch size of the prediction net, joint net and transducer loss.
414
+ # Using small values here will preserve a lot of memory during training, but will make training slower as well.
415
+ # An optimal ratio of fused_batch_size : *_ds.batch_size is 1:1.
416
+ # However, to preserve memory, this ratio can be 1:8 or even 1:16.
417
+ # Extreme case of 1:B (i.e. fused_batch_size=1) should be avoided as training speed would be very slow.
418
+ fuse_loss_wer: true
419
+ fused_batch_size: 16
420
+
421
+ jointnet:
422
+ joint_hidden: ${model.model_defaults.joint_hidden}
423
+ activation: "relu"
424
+ dropout: 0.1
425
+
426
+ # RNNT decoding strategy
427
+ decoding:
428
+ strategy: "greedy_batch" # can be greedy, greedy_batch, beam, tsd, alsd.
429
+
430
+ # greedy strategy config
431
+ greedy:
432
+ max_symbols: 10
433
+
434
+ # beam strategy config
435
+ beam:
436
+ beam_size: 4
437
+ score_norm: true
438
+ return_best_hypothesis: False
439
+ softmax_temperature: 1.0 # scale the logits by some temperature prior to softmax
440
+ tsd_max_sym_exp: 10 # for Time Synchronous Decoding, int > 0
441
+ alsd_max_target_len: 5.0 # for Alignment-Length Synchronous Decoding, float > 1.0
442
+ maes_num_steps: 2 # for modified Adaptive Expansion Search, int > 0
443
+ maes_prefix_alpha: 1 # for modified Adaptive Expansion Search, int > 0
444
+ maes_expansion_beta: 2 # for modified Adaptive Expansion Search, int >= 0
445
+ maes_expansion_gamma: 2.3 # for modified Adaptive Expansion Search, float >= 0
446
+
447
+ # RNNT loss config
448
+ loss:
449
+ loss_name: "default"
450
+
451
+ warprnnt_numba_kwargs:
452
+ # FastEmit regularization: https://arxiv.org/abs/2010.11148
453
+ fastemit_lambda: 0.001 # Values can be in the range [1e-4, 1e-2]. Generally, 0.001 is a good start.
454
+ clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only.
455
+
456
+ optim:
457
+ name: novograd
458
+ lr: 0.05
459
+
460
+ # optimizer arguments
461
+ betas: [0.9, 0.0]
462
+ weight_decay: 0.001
463
+
464
+ # scheduler setup
465
+ sched:
466
+ name: CosineAnnealing
467
+
468
+ # scheduler config override
469
+ warmup_steps: 5000
470
+ warmup_ratio: null
471
+ min_lr: 1e-6
472
+ last_epoch: -1
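
The "8x-Stride" in the model name matches the jasper block list above: three of the blocks use stride: [2] (with stride_last and residual_mode: "stride_add"), every other block uses stride 1, so the encoder reduces the frame rate by 2^3 = 8. With the preprocessor's 10 ms window_stride that is one encoder frame every 80 ms:

# Sanity-check sketch for the overall stride of the ContextNet encoder above.
window_stride_s = 0.01    # preprocessor window_stride
stride_two_blocks = 3     # blocks in the jasper list with stride: [2]
downsampling = 2 ** stride_two_blocks
print(downsampling)                                                    # -> 8
print(window_stride_s * downsampling * 1000, "ms per encoder frame")   # -> 80.0
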
conf/contextnet_rnnt_dummy.yaml ADDED
@@ -0,0 +1,197 @@
1
+ # This config contains the values for training a dummy ContextNet model with Transducer loss and BPE-based vocabulary.
2
+ # In contrast to the original ContextNet, the same number of filters is used throughout the model.
3
+ # To train it with smaller batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches.
4
+
5
+ # It contains the default values for training a ContextNet ASR model, dummy size, with Transducer loss and sub-word encoding.
6
+
7
+ name: &name "ContextNet-8x-Stride-RNNT"
8
+
9
+ model:
10
+ sample_rate: 16000
11
+ compute_eval_loss: false # eval samples can be very long and exhaust memory. Disable computation of transducer loss during validation/testing with this flag.
12
+
13
+ train_ds:
14
+ manifest_filepath: ???
15
+ sample_rate: ${model.sample_rate}
16
+ batch_size: 4 # Can be increased if memory allows or when using smaller model
17
+ trim_silence: false
18
+ max_duration: 16.7
19
+ shuffle: true
20
+ use_start_end_token: false
21
+ num_workers: 16
22
+ pin_memory: true
23
+ # tarred datasets
24
+ is_tarred: false
25
+ tarred_audio_filepaths: null
26
+ tarred_shard_strategy: "scatter"
27
+ shuffle_n: 2048
28
+ # bucketing params
29
+ bucketing_strategy: "synced_randomized"
30
+ bucketing_batch_size: null
31
+ validation_ds:
32
+ manifest_filepath: ???
33
+ sample_rate: ${model.sample_rate}
34
+ batch_size: 8
35
+ shuffle: false
36
+ use_start_end_token: false
37
+ num_workers: 16
38
+ pin_memory: true
39
+
40
+ test_ds:
41
+ manifest_filepath: null
42
+ sample_rate: ${model.sample_rate}
43
+ batch_size: 8
44
+ shuffle: false
45
+ use_start_end_token: false
46
+ num_workers: 16
47
+ pin_memory: true
48
+
49
+ model_defaults:
50
+ filters: 64
51
+ repeat: 1
52
+ dropout: 0.1
53
+ separable: true
54
+ se: true
55
+ se_context_size: -1
56
+ kernel_size_factor: 1.0
57
+ # encoder / decoder / joint values
58
+ enc_hidden: 64
59
+ pred_hidden: 64
60
+ joint_hidden: 64
61
+
62
+ tokenizer:
63
+ dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe)
64
+ type: ??? # Can be either bpe or wpe
65
+
66
+ preprocessor:
67
+ _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
68
+ sample_rate: ${model.sample_rate}
69
+ normalize: "per_feature"
70
+ window_size: 0.025
71
+ window_stride: 0.01
72
+ window: "hann"
73
+ features: &n_mels 80
74
+ n_fft: 512
75
+ frame_splicing: 1
76
+ dither: 0.00001
77
+ pad_to: 16
78
+ stft_conv: false
79
+
80
+ spec_augment:
81
+ _target_: nemo.collections.asr.modules.SpectrogramAugmentation
82
+ freq_masks: 2 # should be kept at 2
83
+ time_masks: 10 # can be 5 for small-med models, 10 for larger models.
84
+ freq_width: 27
85
+ time_width: 0.05
86
+
87
+ encoder:
88
+ _target_: nemo.collections.asr.modules.ConvASREncoder
89
+ feat_in: *n_mels
90
+ activation: swish
91
+ conv_mask: true
92
+ init_mode: "tds_uniform"
93
+
94
+ jasper:
95
+ - filters: ${model.model_defaults.filters}
96
+ repeat: 1
97
+ kernel: [5]
98
+ stride: [1]
99
+ dilation: [1]
100
+ dropout: 0.0
101
+ residual: false
102
+ separable: ${model.model_defaults.separable}
103
+ se: ${model.model_defaults.se}
104
+ se_context_size: ${model.model_defaults.se_context_size}
105
+
106
+ - filters: ${model.model_defaults.filters}
107
+ repeat: ${model.model_defaults.repeat}
108
+ kernel: [5]
109
+ stride: [1]
110
+ dilation: [1]
111
+ dropout: 0.0
112
+ residual: true
113
+ separable: ${model.model_defaults.separable}
114
+ se: ${model.model_defaults.se}
115
+ se_context_size: ${model.model_defaults.se_context_size}
116
+ kernel_size_factor: ${model.model_defaults.kernel_size_factor}
117
+
118
+ decoder:
119
+ _target_: nemo.collections.asr.modules.RNNTDecoder
120
+ normalization_mode: null # Currently only null is supported for export.
121
+ random_state_sampling: false # Random state sampling: https://arxiv.org/pdf/1910.11455.pdf
122
+ blank_as_pad: true # This flag must be set in order to support exporting of RNNT models + efficient inference.
123
+
124
+ prednet:
125
+ pred_hidden: ${model.model_defaults.pred_hidden}
126
+ pred_rnn_layers: 1 # only 1 layer LSTM networks are exportable.
127
+ t_max: null # Maximum possible target seq length used for Chrono Initialization - https://arxiv.org/abs/1804.11188. Disabled by default.
128
+ dropout: 0.1
129
+
130
+ joint:
131
+ _target_: nemo.collections.asr.modules.RNNTJoint
132
+ log_softmax: null # sets it according to cpu/gpu device
133
+ preserve_memory: false # dramatically slows down training, but might preserve some memory
134
+
135
+ # Fuses the computation of prediction net + joint net + loss + WER calculation
136
+ # to be run on sub-batches of size `fused_batch_size`.
137
+ # When this flag is set to true, consider the `batch_size` of *_ds to be just `encoder` batch size.
138
+ # `fused_batch_size` is the actual batch size of the prediction net, joint net and transducer loss.
139
+ # Using small values here will preserve a lot of memory during training, but will make training slower as well.
140
+ # An optimal ratio of fused_batch_size : *_ds.batch_size is 1:1.
141
+ # However, to preserve memory, this ratio can be 1:8 or even 1:16.
142
+ # Extreme case of 1:B (i.e. fused_batch_size=1) should be avoided as training speed would be very slow.
143
+ fuse_loss_wer: true
144
+ fused_batch_size: 16
145
+
146
+ jointnet:
147
+ joint_hidden: ${model.model_defaults.joint_hidden}
148
+ activation: "relu"
149
+ dropout: 0.1
150
+
151
+ # RNNT decoding strategy
152
+ decoding:
153
+ strategy: "greedy_batch" # can be greedy, greedy_batch, beam, tsd, alsd.
154
+
155
+ # greedy strategy config
156
+ greedy:
157
+ max_symbols: 10
158
+
159
+ # beam strategy config
160
+ beam:
161
+ beam_size: 4
162
+ score_norm: true
163
+ return_best_hypothesis: False
164
+ softmax_temperature: 1.0 # scale the logits by some temperature prior to softmax
165
+ tsd_max_sym_exp: 10 # for Time Synchronous Decoding, int > 0
166
+ alsd_max_target_len: 5.0 # for Alignment-Length Synchronous Decoding, float > 1.0
167
+ maes_num_steps: 2 # for modified Adaptive Expansion Search, int > 0
168
+ maes_prefix_alpha: 1 # for modified Adaptive Expansion Search, int > 0
169
+ maes_expansion_beta: 2 # for modified Adaptive Expansion Search, int >= 0
170
+ maes_expansion_gamma: 2.3 # for modified Adaptive Expansion Search, float >= 0
171
+
172
+ # RNNT loss config
173
+ loss:
174
+ loss_name: "default"
175
+
176
+ warprnnt_numba_kwargs:
177
+ # FastEmit regularization: https://arxiv.org/abs/2010.11148
178
+ fastemit_lambda: 0.001 # Values can be in range [1e-4, 1e-2]. Generally, 0.001 is a good start.
179
+ clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only.
180
+
181
+ optim:
182
+ name: novograd
183
+ lr: 0.05
184
+
185
+ # optimizer arguments
186
+ betas: [0.9, 0.0]
187
+ weight_decay: 0.001
188
+
189
+ # scheduler setup
190
+ sched:
191
+ name: CosineAnnealing
192
+
193
+ # scheduler config override
194
+ warmup_steps: 5000
195
+ warmup_ratio: null
196
+ min_lr: 1e-6
197
+ last_epoch: -1
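+
+ # Note: when this config is consumed by run_speech_recognition_rnnt.py in this repo, only the
+ # `model` section above is loaded and the `???` fields are filled in programmatically, roughly
+ # (a sketch; the tokenizer path below is a placeholder):
+ #   config = OmegaConf.load("conf/contextnet_rnnt_dummy.yaml").model
+ #   config.train_ds = config.validation_ds = config.test_ds = None  # HF `datasets` replaces the NeMo loaders
+ #   config.tokenizer.dir = "tokenizer/tokenizer_spe_bpe_v1024"      # placeholder path
+ #   config.tokenizer.type = "bpe"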
eval_results.json ADDED
@@ -0,0 +1,9 @@
1
+ {
2
+ "epoch": 7.38,
3
+ "eval_loss": 8.706663131713867,
4
+ "eval_runtime": 970.2156,
5
+ "eval_samples": 13098,
6
+ "eval_samples_per_second": 13.5,
7
+ "eval_steps_per_second": 3.376,
8
+ "eval_wer": 0.20430683297635546
9
+ }
models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .modeling_rnnt import RNNTBPEModel
models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (197 Bytes).
 
models/__pycache__/modeling_rnnt.cpython-39.pyc ADDED
Binary file (4.46 kB).
 
models/modeling_rnnt.py ADDED
@@ -0,0 +1,115 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from nemo.collections.asr.models import EncDecRNNTBPEModel
6
+ from omegaconf import DictConfig
7
+ from transformers.utils import ModelOutput
8
+
9
+
10
+ @dataclass
11
+ class RNNTOutput(ModelOutput):
12
+ """
13
+ Base class for RNNT outputs.
14
+ """
15
+
16
+ loss: Optional[torch.FloatTensor] = None
17
+ wer: Optional[float] = None
18
+ wer_num: Optional[float] = None
19
+ wer_denom: Optional[float] = None
20
+
21
+
22
+ # Adapted from https://github.com/NVIDIA/NeMo/blob/66c7677cd4a68d78965d4905dd1febbf5385dff3/nemo/collections/asr/models/rnnt_bpe_models.py#L33
23
+ class RNNTBPEModel(EncDecRNNTBPEModel):
24
+ def __init__(self, cfg: DictConfig):
25
+ super().__init__(cfg=cfg, trainer=None)
26
+
27
+ def encoding(
28
+ self, input_signal=None, input_signal_length=None, processed_signal=None, processed_signal_length=None
29
+ ):
30
+ """
31
+ Forward pass of the acoustic model. Note that for RNNT Models, the forward pass of the model is a 3 step process,
32
+ and this method only performs the first step - forward of the acoustic model.
33
+
34
+ Please refer to the `forward` in order to see the full `forward` step for training - which
35
+ performs the forward of the acoustic model, the prediction network and then the joint network.
36
+ Finally, it computes the loss and possibly computes the detokenized text via the `decoding` step.
37
+
38
+ Please refer to the `validation_step` in order to see the full `forward` step for inference - which
39
+ performs the forward of the acoustic model, the prediction network and then the joint network.
40
+ Finally, it computes the decoded tokens via the `decoding` step and possibly computes the batch metrics.
41
+
42
+ Args:
43
+ input_signal: Tensor that represents a batch of raw audio signals,
44
+ of shape [B, T]. T here represents timesteps, with 1 second of audio represented as
45
+ `self.sample_rate` number of floating point values.
46
+ input_signal_length: Vector of length B, that contains the individual lengths of the audio
47
+ sequences.
48
+ processed_signal: Tensor that represents a batch of processed audio signals,
49
+ of shape (B, D, T) that has undergone processing via some DALI preprocessor.
50
+ processed_signal_length: Vector of length B, that contains the individual lengths of the
51
+ processed audio sequences.
52
+
53
+ Returns:
54
+ A tuple of 2 elements -
55
+ 1) The log probabilities tensor of shape [B, T, D].
56
+ 2) The lengths of the acoustic sequence after propagation through the encoder, of shape [B].
57
+ """
58
+ has_input_signal = input_signal is not None and input_signal_length is not None
59
+ has_processed_signal = processed_signal is not None and processed_signal_length is not None
60
+ if (has_input_signal ^ has_processed_signal) is False:
61
+ raise ValueError(
62
+ f"{self} Arguments ``input_signal`` and ``input_signal_length`` are mutually exclusive "
63
+ " with ``processed_signal`` and ``processed_signal_length`` arguments."
64
+ )
65
+
66
+ if not has_processed_signal:
67
+ processed_signal, processed_signal_length = self.preprocessor(
68
+ input_signal=input_signal, length=input_signal_length,
69
+ )
70
+
71
+ # Spec augment is not applied during evaluation/testing
72
+ if self.spec_augmentation is not None and self.training:
73
+ processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length)
74
+
75
+ encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length)
76
+ return encoded, encoded_len
77
+
78
+ def forward(self, input_ids, input_lengths=None, labels=None, label_lengths=None):
79
+ # encoding() only performs encoder forward
80
+ encoded, encoded_len = self.encoding(input_signal=input_ids, input_signal_length=input_lengths)
81
+ del input_ids
82
+
83
+ # During training, loss must be computed, so decoder forward is necessary
84
+ decoder, target_length, states = self.decoder(targets=labels, target_length=label_lengths)
85
+
86
+ # If experimental fused Joint-Loss-WER is not used
87
+ if not self.joint.fuse_loss_wer:
88
+ # Compute full joint and loss
89
+ joint = self.joint(encoder_outputs=encoded, decoder_outputs=decoder)
90
+ loss_value = self.loss(
91
+ log_probs=joint, targets=labels, input_lengths=encoded_len, target_lengths=target_length
92
+ )
93
+ # Add auxiliary losses, if registered
94
+ loss_value = self.add_auxiliary_losses(loss_value)
95
+ wer = wer_num = wer_denom = None
96
+ if not self.training:
97
+ self.wer.update(encoded, encoded_len, labels, target_length)
98
+ wer, wer_num, wer_denom = self.wer.compute()
99
+ self.wer.reset()
100
+
101
+ else:
102
+ # If experimental fused Joint-Loss-WER is used
103
+ # Fused joint step
104
+ loss_value, wer, wer_num, wer_denom = self.joint(
105
+ encoder_outputs=encoded,
106
+ decoder_outputs=decoder,
107
+ encoder_lengths=encoded_len,
108
+ transcripts=labels,
109
+ transcript_lengths=label_lengths,
110
+ compute_wer=not self.training,
111
+ )
112
+ # Add auxiliary losses, if registered
113
+ loss_value = self.add_auxiliary_losses(loss_value)
114
+
115
+ return RNNTOutput(loss=loss_value, wer=wer, wer_num=wer_num, wer_denom=wer_denom)
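+
+
+ # Illustrative usage (a sketch, not invoked anywhere in this repo): `forward` expects exactly the
+ # keys produced by `NeMoDataCollator` in run_speech_recognition_rnnt.py. Assuming `cfg` is the
+ # `model` section of one of the YAML files in conf/ with the tokenizer fields filled in:
+ #
+ #   model = RNNTBPEModel(cfg=cfg)
+ #   outputs = model(
+ #       input_ids=audio,           # float tensor [B, T] of raw waveform samples
+ #       input_lengths=audio_lens,  # int tensor [B]
+ #       labels=token_ids,          # int tensor [B, U]
+ #       label_lengths=token_lens,  # int tensor [B]
+ #   )
+ #   outputs.loss, outputs.wer, outputs.wer_num, outputs.wer_denom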
process_asr_text_tokenizer.py ADDED
@@ -0,0 +1,221 @@
1
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # USAGE: python process_asr_text_tokenizer.py --manifest=<path to train manifest files, separated by commas> \
16
+ # --data_root="<output directory>" \
17
+ # --vocab_size=<number of tokens in vocabulary> \
18
+ # --tokenizer=<"spe" or "wpe"> \
19
+ # --log
20
+ # where <manifest> can be: train_clean_100, train_clean_360, train_other_500
21
+ # You can also put more than one data_set comma-separated:
22
+ # --manifest="train_clean_100,train_clean_360,train_other_500"
23
+ # or
24
+ # python process_asr_text_tokenizer.py --data_file=<path to train text file> \
25
+ # --data_root="<output directory>" \
26
+ # --vocab_size=<number of tokens in vocabulary> \
27
+ # --tokenizer=<"bpe" or "wpe"> \
28
+ # --log
29
+ # where <manifest> can be: train_clean_100, train_clean_360, train_other_500
30
+ # You can also put more than one data_set comma-separated:
31
+ # --manifest="train_clean_100,train_clean_360,train_other_500"
32
+ #
33
+ # Args:
34
+ # --manifest or --data_file: If your text data lies inside of an ASR manifest file,
35
+ # then use the --manifest path. If instead the text data is inside a file with separate lines
36
+ # corresponding to different text lines, then use --data_file.
37
+ # In either case, you can add commas to concatenate different manifests or different data files.
38
+ #
39
+ # --data_root: The output directory (whose subdirectories will be created if not present) where
40
+ # the tokenizers will be placed.
41
+ #
42
+ # --vocab_size: The size of the tokenizer vocabulary. Larger vocabularies can accommodate almost entire
43
+ # words, but the decoder size of any model will grow proportionally.
44
+ #
45
+ # --tokenizer: Can be either spe or wpe . spe refers to the Google sentencepiece library tokenizer.
46
+ # wpe refers to the HuggingFace BERT Word Piece tokenizer.
47
+ #
48
+ # --no_lower_case: When this flag is passed, it will force the tokenizer to create separate tokens for
49
+ # upper and lower case characters. By default, the script will turn all the text to lower case
50
+ # before tokenization (and if upper case characters are passed during training/inference, the
51
+ # tokenizer will emit a token equivalent to Out-Of-Vocabulary). Used primarily for the
52
+ # English language.
53
+ #
54
+ # --spe_type: The sentencepiece library has a few implementations of the tokenization technique, and
55
+ # spe_type refers to these implementations. Currently supported types are unigram, bpe, char, word.
56
+ # Defaults to bpe.
57
+ #
58
+ # --spe_character_coverage: The sentencepiece library considers how much of the original vocabulary it
59
+ # should cover in its "base set" of tokens (akin to the lower and upper case characters of the
60
+ # English language). For almost all languages with small base token sets (<1000 tokens), this
61
+ # should be kept at its default of 1.0. For languages with larger vocabularies (say Japanese,
62
+ # Mandarin, Korean etc), the suggested value is 0.9995.
63
+ #
64
+ # --spe_sample_size: If the dataset is too large, consider using a sampled dataset indicated by a
65
+ # positive integer. By default, any negative value (default = -1) will use the entire dataset.
66
+ #
67
+ # --spe_train_extremely_large_corpus: When training a sentencepiece tokenizer on very large amounts of text,
68
+ # sometimes the tokenizer will run out of memory or won't be able to process so much data in RAM.
69
+ # At some point you might receive the following error - "Input corpus too large, try with
70
+ # train_extremely_large_corpus=true". If your machine has large amounts of RAM, it might still be possible
71
+ # to build the tokenizer using the above flag. Will silently fail if it runs out of RAM.
72
+ #
73
+ # --spe_max_sentencepiece_length: Limits the maximum length that any SentencePiece subword can be.
74
+ # Using this will change the subword tokens generated.
75
+ #
76
+ # --spe_pad: Adds <pad> as special token.
77
+ #
78
+ # --spe_bos: Adds <s> as Beginning-of-Sentence special token.
79
+ #
80
+ # --spe_eos: Adds </s> as End-of-Sentence special token.
81
+ #
82
+ # --log: Whether the script should display log messages
83
+
84
+ import json
85
+ import logging
86
+ import os
87
+
88
+ import tokenizers
89
+
90
+ from nemo.collections.common.tokenizers.sentencepiece_tokenizer import create_spt_model
91
+
92
+
93
+ def __build_document_from_manifests(
94
+ data_root: str, manifests: str,
95
+ ):
96
+ if ',' in manifests:
97
+ manifests = manifests.split(',')
98
+ else:
99
+ manifests = [manifests]
100
+
101
+ document_dir = os.path.join(data_root, 'text_corpus')
102
+ if not os.path.exists(document_dir):
103
+ os.makedirs(document_dir)
104
+
105
+ document_path = os.path.join(document_dir, 'document.txt')
106
+
107
+ if os.path.exists(document_path):
108
+ logging.info('Corpus already exists at path : %s', document_path)
109
+ return document_path
110
+
111
+ num_lines = 0
112
+ with open(document_path, 'w') as out_writer:
113
+ for manifest in manifests:
114
+ with open(manifest, 'r') as in_reader:
115
+ for line in in_reader:
116
+ item = json.loads(line)
117
+ text = item['text']
118
+
119
+ out_writer.write(text + '\n')
120
+ out_writer.flush()
121
+
122
+ num_lines += 1
123
+
124
+ logging.info(f"Finished extracting manifest : {manifest}")
125
+
126
+ logging.info("Finished extracting all manifests ! Number of sentences : {}".format(num_lines))
127
+ return document_path
128
+
129
+
130
+ def __process_data(
131
+ text_path: str,
132
+ dst_folder: str,
133
+ vocab_size: int,
134
+ tokenizer_type: str,
135
+ spe_type: str,
136
+ spe_character_coverage: float,
137
+ spe_train_extremely_large_corpus: bool,
138
+ spe_sample_size: int,
139
+ spe_max_sentencepiece_length: int,
140
+ spe_bos: bool,
141
+ spe_eos: bool,
142
+ spe_pad: bool,
143
+ lower_case: bool,
144
+ ):
145
+ """
146
+ Builds a tokenizer (spe or wpe) from the given text corpus and saves it to dst_folder
147
+ Args:
148
+ text_path: source with text lines
149
+ dst_folder: where the tokenizer files will be stored
150
+ vocab_size: vocabulary size used in encoding the text
151
+ tokenizer_type: type of tokenization to perform - wpe or spe
152
+ spe_type: type of tokenization model used for spe.
153
+ spe_character_coverage: float value between 0 and 1 (as a percentage). For languages with a vast charset,
154
+ can be < 1.0, but for all other languages, it should be set as 1.0
155
+ spe_sample_size: int, default of -1. If positive integer is used, samples the dataset
156
+ by given sample size.
157
+ spe_train_extremely_large_corpus: bool. If dataset is too large, and user has sufficient RAM,
158
+ this flag can be set to try to train the tokenizer. Will silently fail if it runs out of RAM.
159
+ spe_max_sentencepiece_length: Limits the maximum length of the SentencePiece subword that can be constructed.
160
+ By default, no limit is placed.
161
+ spe_bos: Bool flag, whether to add <s> to SentencePiece tokenizer vocabulary.
162
+ spe_eos: Bool flag, whether to add </s> to SentencePiece tokenizer vocabulary.
163
+ spe_pad: Bool flag, whether to add <pad> to SentencePiece tokenizer vocabulary.
164
+ lower_case: whether to tokenize with lower case character set only (for english)
165
+
166
+ Returns:
167
+ """
168
+ if tokenizer_type == 'spe':
169
+
170
+ # Prepare directory of tokenizer
171
+ if spe_max_sentencepiece_length > 0:
172
+ tokenizer_dir = os.path.join(dst_folder, 'tokenizer_{}_{}_v{}_max_{}').format(
173
+ tokenizer_type, spe_type, vocab_size, spe_max_sentencepiece_length
174
+ )
175
+ else:
176
+ tokenizer_dir = os.path.join(dst_folder, 'tokenizer_{}_{}_v{}').format(
177
+ tokenizer_type, spe_type, vocab_size
178
+ )
179
+
180
+ if spe_pad:
181
+ tokenizer_dir = f'{tokenizer_dir}_pad'
182
+ if spe_bos:
183
+ tokenizer_dir = f'{tokenizer_dir}_bos'
184
+ if spe_eos:
185
+ tokenizer_dir = f'{tokenizer_dir}_eos'
186
+
187
+ if not os.path.exists(tokenizer_dir):
188
+ os.makedirs(tokenizer_dir)
189
+
190
+ if os.path.exists(os.path.join(tokenizer_dir, 'tokenizer.model')):
191
+ logging.warning("Model file already exists, overriding old model file !")
192
+ os.remove(os.path.join(tokenizer_dir, 'tokenizer.model'))
193
+
194
+ # Build tokenizer
195
+ tokenizer_path, vocab_path = create_spt_model(
196
+ data_file=text_path,
197
+ vocab_size=vocab_size,
198
+ sample_size=spe_sample_size,
199
+ do_lower_case=lower_case,
200
+ output_dir=tokenizer_dir,
201
+ tokenizer_type=spe_type,
202
+ character_coverage=spe_character_coverage,
203
+ train_extremely_large_corpus=spe_train_extremely_large_corpus,
204
+ max_sentencepiece_length=spe_max_sentencepiece_length,
205
+ bos=spe_bos,
206
+ eos=spe_eos,
207
+ pad=spe_pad,
208
+ )
209
+
210
+ else:
211
+ tokenizer_dir = os.path.join(dst_folder, 'tokenizer_{}_v{}').format(tokenizer_type, vocab_size)
212
+
213
+ if not os.path.exists(tokenizer_dir):
214
+ os.makedirs(tokenizer_dir)
215
+
216
+ tokenizer = tokenizers.BertWordPieceTokenizer(lowercase=lower_case)
217
+
218
+ tokenizer.train(text_path, vocab_size=vocab_size)
219
+ tokenizer.save_model(tokenizer_dir)
220
+
221
+ return tokenizer_dir
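+
+
+ # Illustrative programmatic use (a sketch; paths and values are placeholders): these two helpers are
+ # what run_speech_recognition_rnnt.py imports to build its tokenizer from a manifest of transcripts.
+ #
+ #   corpus = __build_document_from_manifests("tokenizer", "data/train.json")
+ #   tokenizer_dir = __process_data(corpus, "tokenizer", vocab_size=1024, tokenizer_type="spe",
+ #                                  spe_type="bpe", spe_character_coverage=1.0,
+ #                                  spe_train_extremely_large_corpus=False, spe_sample_size=-1,
+ #                                  spe_max_sentencepiece_length=-1, spe_bos=False, spe_eos=False,
+ #                                  spe_pad=False, lower_case=True)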
requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ transformers
2
+ datasets
3
+ jiwer
4
+ wandb
5
+ soundfile
6
+ librosa
7
+ bitsandbytes
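+ # Note: the training scripts also import nemo_toolkit (e.g. nemo_toolkit[asr]), torch, numpy and
+ # omegaconf; these are assumed to be installed separately and are not pinned here.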
run_ami.sh ADDED
@@ -0,0 +1,38 @@
1
+ #!/usr/bin/env bash
2
+ CUDA_VISIBLE_DEVICES=0 python run_speech_recognition_rnnt.py \
3
+ --config_path="conf/conformer_transducer_bpe_xlarge.yaml" \
4
+ --model_name_or_path="stt_en_conformer_transducer_xlarge" \
5
+ --dataset_name="speech-seq2seq/ami" \
6
+ --tokenizer_path="tokenizer" \
7
+ --vocab_size="1024" \
8
+ --num_train_epochs="7.38" \
9
+ --dataset_config_name="ihm" \
10
+ --train_split_name="train" \
11
+ --eval_split_name="validation" \
12
+ --test_split_name="test" \
13
+ --text_column_name="text" \
14
+ --output_dir="./" \
15
+ --run_name="rnnt-ami-baseline" \
16
+ --wandb_project="rnnt" \
17
+ --per_device_train_batch_size="8" \
18
+ --per_device_eval_batch_size="4" \
19
+ --logging_steps="50" \
20
+ --learning_rate="1e-4" \
21
+ --warmup_steps="500" \
22
+ --save_strategy="steps" \
23
+ --save_steps="20000" \
24
+ --evaluation_strategy="steps" \
25
+ --eval_steps="20000" \
26
+ --report_to="wandb" \
27
+ --preprocessing_num_workers="4" \
28
+ --fused_batch_size="8" \
29
+ --length_column_name="input_lengths" \
30
+ --do_lower_case="False" \
31
+ --fuse_loss_wer \
32
+ --group_by_length \
33
+ --overwrite_output_dir \
34
+ --do_train \
35
+ --do_eval \
36
+ --do_predict \
37
+ --push_to_hub \
38
+ --use_auth_token
run_speech_recognition_rnnt.py ADDED
@@ -0,0 +1,935 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2022 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning NVIDIA RNN-T models for speech recognition.
18
+ """
19
+ # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20
+ import copy
21
+ import logging
22
+ import os
23
+ import re
24
+ import sys
25
+ from dataclasses import dataclass, field
26
+
27
+ from tqdm import tqdm
28
+ import json
29
+ from typing import Optional, Dict, Union, List
30
+
31
+ import numpy as np
32
+ import torch
33
+ import torch.nn as nn
34
+
35
+ from omegaconf import OmegaConf, open_dict
36
+ from models import RNNTBPEModel
37
+ from nemo.core import adapter_mixins
38
+ from nemo.collections.common.parts.adapter_modules import LinearAdapterConfig
39
+
40
+ import datasets
41
+ from datasets import DatasetDict, load_dataset
42
+ import transformers
43
+ from transformers import (
44
+ HfArgumentParser,
45
+ Seq2SeqTrainingArguments,
46
+ set_seed,
47
+ Trainer,
48
+ TrainerCallback,
49
+ TrainingArguments,
50
+ TrainerState,
51
+ TrainerControl,
52
+ )
53
+ from transformers.trainer_pt_utils import get_parameter_names
54
+ from transformers.trainer_utils import get_last_checkpoint, is_main_process
55
+ from transformers.utils import check_min_version
56
+ from transformers.utils.versions import require_version
57
+
58
+ from process_asr_text_tokenizer import __process_data as nemo_process_data, \
59
+ __build_document_from_manifests as nemo_build_document_from_manifests
60
+
61
+ import bitsandbytes as bnb
62
+
63
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
64
+ check_min_version("4.17.0.dev0")
65
+
66
+ require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
67
+
68
+ logger = logging.getLogger(__name__)
69
+
70
+
71
+ @dataclass
72
+ class ModelArguments:
73
+ """
74
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
75
+ """
76
+
77
+ config_path: str = field(
78
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."},
79
+ )
80
+ model_name_or_path: Optional[str] = field(
81
+ default=None,
82
+ metadata={"help": "Path to pretrained model or model identifier from NVIDIA NeMo NGC."}
83
+ )
84
+ cache_dir: Optional[str] = field(
85
+ default=None,
86
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co or NVIDIA NeMo NGC."},
87
+ )
88
+ use_auth_token: bool = field(
89
+ default=False,
90
+ metadata={
91
+ "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
92
+ "with private models)."
93
+ },
94
+ )
95
+ manifest_path: str = field(
96
+ default="data",
97
+ metadata={
98
+ "help": "Manifest path."
99
+ },
100
+ )
101
+ tokenizer_path: str = field(
102
+ default="tokenizers",
103
+ metadata={
104
+ "help": "Tokenizer path."
105
+ },
106
+ )
107
+ vocab_size: int = field(
108
+ default=1024,
109
+ metadata={"help": "Tokenizer vocab size."}
110
+ )
111
+ tokenizer_type: str = field(
112
+ default="spe",
113
+ metadata={
114
+ "help": "Can be either spe or wpe. spe refers to the Google sentencepiece library tokenizer."
115
+ "wpe refers to the HuggingFace BERT Word Piece tokenizer."
116
+ },
117
+ )
118
+ spe_type: str = field(
119
+ default="bpe",
120
+ metadata={
121
+ "help": "Type of the SentencePiece model. Can be `bpe`, `unigram`, `char` or `word`."
122
+ "Used only if `tokenizer_type` == `spe`"
123
+ },
124
+ )
125
+ cutoff_freq: float = field(
126
+ default=0.001,
127
+ metadata={"help": "Drop the least frequent chars from the train set when building the tokenizer."}
128
+ )
129
+ fuse_loss_wer: bool = field(
130
+ default=True,
131
+ metadata={
132
+ "help": "Whether to fuse the computation of prediction net + joint net + loss + WER calculation to be run "
133
+ "on sub-batches of size `fused_batch_size`"
134
+ }
135
+ )
136
+ fused_batch_size: int = field(
137
+ default=8,
138
+ metadata={
139
+ "help": "`fused_batch_size` is the actual batch size of the prediction net, joint net and transducer loss."
140
+ "Using small values here will preserve a lot of memory during training, but will make training slower as well."
141
+ "An optimal ratio of fused_batch_size : per_device_train_batch_size is 1:1."
142
+ "However, to preserve memory, this ratio can be 1:8 or even 1:16."
143
+ }
144
+ )
145
+ final_decoding_strategy: str = field(
146
+ default="greedy_batch",
147
+ metadata={
148
+ "help": "Decoding strategy for final eval/prediction steps. One of: [`greedy`, `greedy_batch`, `beam`, "
149
+ "`tsd`, `alsd`]."
150
+ }
151
+ )
152
+ final_num_beams: int = field(
153
+ default=1,
154
+ metadata={
155
+ "help": "Number of beams for final eval/prediction steps. Increase beam size for better scores, "
156
+ "but it will take much longer for transcription!"
157
+ }
158
+ )
159
+ freeze_encoder: bool = field(
160
+ default=False,
161
+ metadata={"help": "Freeze the acoustic encoder of the model. Recommended when fine-tuning on small datasets."}
162
+ )
163
+ unfreeze_encoder: bool = field(
164
+ default=False,
165
+ metadata={"help": "Unfreeze the acoustic encoder of the model after first evaluation step."}
166
+ )
167
+ add_adapter: bool = field(
168
+ default=False,
169
+ metadata={"help": "Add an adapter layer to the encoder of the model."}
170
+ )
171
+ use_adam8bit: bool = field(
172
+ default=False,
173
+ metadata={"help": "Whether to use bitsandbytes 8bit AdamW optimiser."}
174
+ )
175
+
176
+
177
+ @dataclass
178
+ class DataTrainingArguments:
179
+ """
180
+ Arguments pertaining to what data we are going to input our model for training and eval.
181
+ """
182
+
183
+ dataset_name: str = field(
184
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
185
+ )
186
+ dataset_config_name: Optional[str] = field(
187
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
188
+ )
189
+ text_column: Optional[str] = field(
190
+ default=None,
191
+ metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
192
+ )
193
+ dataset_cache_dir: Optional[str] = field(
194
+ default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
195
+ )
196
+ overwrite_cache: bool = field(
197
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
198
+ )
199
+ preprocessing_num_workers: Optional[int] = field(
200
+ default=None,
201
+ metadata={"help": "The number of processes to use for the preprocessing."},
202
+ )
203
+ max_train_samples: Optional[int] = field(
204
+ default=None,
205
+ metadata={
206
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
207
+ "value if set."
208
+ },
209
+ )
210
+ max_eval_samples: Optional[int] = field(
211
+ default=None,
212
+ metadata={
213
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
214
+ "value if set."
215
+ },
216
+ )
217
+ max_predict_samples: Optional[int] = field(
218
+ default=None,
219
+ metadata={
220
+ "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
221
+ "value if set."
222
+ },
223
+ )
224
+ audio_column_name: str = field(
225
+ default="audio",
226
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
227
+ )
228
+ text_column_name: str = field(
229
+ default="text",
230
+ metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
231
+ )
232
+ max_duration_in_seconds: float = field(
233
+ default=20.0,
234
+ metadata={
235
+ "help": "Truncate training audio files that are longer than `max_duration_in_seconds` seconds to `max_duration_in_seconds`"
236
+ },
237
+ )
238
+ min_duration_in_seconds: float = field(
239
+ default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
240
+ )
241
+ max_eval_duration_in_seconds: float = field(
242
+ default=None,
243
+ metadata={
244
+ "help": "Truncate eval/test audio files that are longer than `max_eval_duration_in_seconds` seconds to `max_eval_duration_in_seconds`"
245
+ },
246
+ )
247
+ max_target_length: Optional[int] = field(
248
+ default=128,
249
+ metadata={
250
+ "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
251
+ "than this will be truncated, sequences shorter will be padded."
252
+ },
253
+ )
254
+ min_target_length: Optional[int] = field(
255
+ default=2,
256
+ metadata={
257
+ "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
258
+ "than this will be filtered."
259
+ },
260
+ )
261
+ preprocessing_only: bool = field(
262
+ default=False,
263
+ metadata={
264
+ "help": "Whether to only do data preprocessing and skip training. "
265
+ "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
266
+ "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
267
+ "so that the cached datasets can consequently be loaded in distributed training"
268
+ },
269
+ )
270
+ train_split_name: str = field(
271
+ default="train",
272
+ metadata={
273
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
274
+ },
275
+ )
276
+ eval_split_name: str = field(
277
+ default="validation",
278
+ metadata={
279
+ "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'validation'"
280
+ },
281
+ )
282
+ test_split_name: str = field(
283
+ default="test",
284
+ metadata={"help": "The name of the test data set split to use (via the datasets library). Defaults to 'test'"},
285
+ )
286
+ do_lower_case: bool = field(
287
+ default=True,
288
+ metadata={"help": "Whether the target text should be lower cased."},
289
+ )
290
+ wandb_project: str = field(
291
+ default="speech-recognition-rnnt",
292
+ metadata={"help": "The name of the wandb project."},
293
+ )
294
+
295
+
296
+ def build_tokenizer(model_args, data_args, manifests):
297
+ """
298
+ Function to build a NeMo tokenizer from manifest file(s).
299
+ Copied from https://github.com/NVIDIA/NeMo/blob/66c7677cd4a68d78965d4905dd1febbf5385dff3/scripts/tokenizers/process_asr_text_tokenizer.py#L268
300
+ """
301
+ data_root = model_args.tokenizer_path
302
+ if isinstance(manifests, list):
303
+ joint_manifests = ",".join(manifests)
304
+ else:
305
+ joint_manifests = manifests
306
+ vocab_size = model_args.vocab_size
307
+ tokenizer = model_args.tokenizer_type
308
+ spe_type = model_args.spe_type
309
+ if not 0 <= model_args.cutoff_freq < 1:
310
+ raise ValueError(f"`cutoff_freq` must be between zero and one, got {model_args.cutoff_freq}")
311
+ spe_character_coverage = 1 - model_args.cutoff_freq
312
+
313
+ logger.info("Building tokenizer...")
314
+ if not os.path.exists(data_root):
315
+ os.makedirs(data_root)
316
+
317
+ text_corpus_path = nemo_build_document_from_manifests(data_root, joint_manifests)
318
+
319
+ tokenizer_path = nemo_process_data(
320
+ text_corpus_path,
321
+ data_root,
322
+ vocab_size,
323
+ tokenizer,
324
+ spe_type,
325
+ lower_case=data_args.do_lower_case,
326
+ spe_character_coverage=spe_character_coverage,
327
+ spe_sample_size=-1,
328
+ spe_train_extremely_large_corpus=False,
329
+ spe_max_sentencepiece_length=-1,
330
+ spe_bos=False,
331
+ spe_eos=False,
332
+ spe_pad=False,
333
+ )
334
+
335
+ print("Serialized tokenizer at location :", tokenizer_path)
336
+ logger.info('Done!')
337
+
338
+ # Tokenizer path
339
+ if tokenizer == 'spe':
340
+ tokenizer_dir = os.path.join(data_root, f"tokenizer_spe_{spe_type}_v{vocab_size}")
341
+ tokenizer_type_cfg = "bpe"
342
+ else:
343
+ tokenizer_dir = os.path.join(data_root, f"tokenizer_wpe_v{vocab_size}")
344
+ tokenizer_type_cfg = "wpe"
345
+
346
+ return tokenizer_dir, tokenizer_type_cfg
347
+
348
+
349
+ def NeMoDataCollator(features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
350
+ """
351
+ Data collator that will dynamically pad the inputs received.
352
+ Since NeMo models don't have a HF processor defined (feature extractor + tokenizer), we'll pad by hand...
353
+ The padding idx is arbitrary: we provide the model with the input lengths and label lengths, from which
354
+ all the relevant padding information is inferred. Thus, we'll use the default np.pad padding idx (0).
355
+ """
356
+ # split inputs and labels since they have to be of different lengths
357
+ # and need different padding methods
358
+ input_ids = [feature["input_ids"] for feature in features]
359
+ labels = [feature["labels"] for feature in features]
360
+
361
+ # first, pad the audio inputs to max_len
362
+ input_lengths = [feature["input_lengths"] for feature in features]
363
+ max_input_len = max(input_lengths)
364
+ input_ids = [np.pad(input_val, (0, max_input_len - input_len), 'constant') for input_val, input_len in
365
+ zip(input_ids, input_lengths)]
366
+
367
+ # next, pad the target labels to max_len
368
+ label_lengths = [len(lab) for lab in labels]
369
+ max_label_len = max(label_lengths)
370
+ labels = [np.pad(lab, (0, max_label_len - lab_len), 'constant') for lab, lab_len in zip(labels, label_lengths)]
371
+
372
+ batch = {"input_lengths": input_lengths, "labels": labels, "label_lengths": label_lengths}
373
+
374
+ # return batch as a pt tensor (list -> np.array -> torch.tensor)
375
+ batch = {k: torch.tensor(np.array(v), requires_grad=False) for k, v in batch.items()}
376
+
377
+ # leave all ints as are, convert float64 to pt float
378
+ batch["input_ids"] = torch.tensor(np.array(input_ids, dtype=np.float32), requires_grad=False)
379
+
380
+ return batch
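+
+ # Example (illustrative only): two features of different lengths are padded to a common shape;
+ # the audio comes back as a float32 tensor and the lengths/labels as integer tensors.
+ #
+ #   feats = [
+ #       {"input_ids": np.zeros(16000), "input_lengths": 16000, "labels": [5, 17, 3]},
+ #       {"input_ids": np.zeros(8000), "input_lengths": 8000, "labels": [9, 2]},
+ #   ]
+ #   batch = NeMoDataCollator(feats)
+ #   # batch["input_ids"].shape == (2, 16000); batch["labels"].shape == (2, 3)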
381
+
382
+
383
+ def main():
384
+ # See all possible arguments in src/transformers/training_args.py
385
+ # or by passing the --help flag to this script.
386
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
387
+
388
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
389
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
390
+ # If we pass only one argument to the script and it's the path to a json file,
391
+ # let's parse it to get our arguments.
392
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
393
+ else:
394
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
395
+
396
+ # Set wandb project ID before instantiating the Trainer
397
+ os.environ["WANDB_PROJECT"] = data_args.wandb_project
398
+
399
+ # Detecting last checkpoint.
400
+ last_checkpoint = None
401
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
402
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
403
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
404
+ raise ValueError(
405
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
406
+ "Use --overwrite_output_dir to overcome."
407
+ )
408
+ elif last_checkpoint is not None:
409
+ logger.info(
410
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
411
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
412
+ )
413
+
414
+ # Setup logging
415
+ logging.basicConfig(
416
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
417
+ datefmt="%m/%d/%Y %H:%M:%S",
418
+ handlers=[logging.StreamHandler(sys.stdout)],
419
+ )
420
+ logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
421
+
422
+ # Log on each process the small summary:
423
+ logger.warning(
424
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
425
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
426
+ )
427
+ # Set the verbosity to info of the Transformers logger (on main process only):
428
+ if is_main_process(training_args.local_rank):
429
+ transformers.utils.logging.set_verbosity_info()
430
+ logger.info("Training/evaluation parameters %s", training_args)
431
+
432
+ # Set seed before initializing model.
433
+ set_seed(training_args.seed)
434
+
435
+ # load the model config (discarding optimiser and trainer attributes)
436
+ config = OmegaConf.load(model_args.config_path).model
437
+
438
+ # 4. Load dataset
439
+ raw_datasets = DatasetDict()
440
+
441
+ if training_args.do_train:
442
+ raw_datasets["train"] = load_dataset(
443
+ data_args.dataset_name,
444
+ data_args.dataset_config_name,
445
+ split=data_args.train_split_name,
446
+ cache_dir=data_args.dataset_cache_dir,
447
+ use_auth_token=True if model_args.use_auth_token else None,
448
+ )
449
+
450
+ if training_args.do_eval:
451
+ raw_datasets["eval"] = load_dataset(
452
+ data_args.dataset_name,
453
+ data_args.dataset_config_name,
454
+ split=data_args.eval_split_name,
455
+ cache_dir=data_args.dataset_cache_dir,
456
+ use_auth_token=True if model_args.use_auth_token else None,
457
+ )
458
+
459
+ if training_args.do_predict:
460
+ test_split = data_args.test_split_name.split("+")
461
+ for split in test_split:
462
+ raw_datasets[split] = load_dataset(
463
+ data_args.dataset_name,
464
+ data_args.dataset_config_name,
465
+ split=split,
466
+ cache_dir=data_args.dataset_cache_dir,
467
+ use_auth_token=True if model_args.use_auth_token else None,
468
+ )
469
+
470
+ if not training_args.do_train and not training_args.do_eval and not training_args.do_predict:
471
+ raise ValueError(
472
+ "Cannot skip training, evaluation and prediction all at once. At least one of "
473
+ "training, evaluation or prediction has to be done."
474
+ )
475
+
476
+ # if not training, there is no need to run multiple epochs
477
+ if not training_args.do_train:
478
+ training_args.num_train_epochs = 1
479
+
480
+ if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
481
+ raise ValueError(
482
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
483
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
484
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
485
+ )
486
+
487
+ if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
488
+ raise ValueError(
489
+ f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
490
+ "Make sure to set `--text_column_name` to the correct text column - one of "
491
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
492
+ )
493
+
494
+ # 6. Resample speech dataset ALWAYS
495
+ raw_datasets = raw_datasets.cast_column(
496
+ data_args.audio_column_name, datasets.features.Audio(sampling_rate=config.sample_rate)
497
+ )
498
+
499
+ # 7. Preprocessing the datasets.
500
+ # We need to read the audio files as arrays and tokenize the targets.
501
+ max_input_length = int(data_args.max_duration_in_seconds * config.sample_rate)
502
+ min_input_length = max(int(data_args.min_duration_in_seconds * config.sample_rate), 1)  # at least 1 sample, so the zero-length placeholder audio set in prepare_dataset is filtered out
503
+ max_eval_input_length = int(data_args.max_eval_duration_in_seconds * config.sample_rate) if data_args.max_eval_duration_in_seconds else None
504
+ max_target_length = data_args.max_target_length
505
+ min_target_length = data_args.min_target_length
506
+ audio_column_name = data_args.audio_column_name
507
+ num_workers = data_args.preprocessing_num_workers
508
+ text_column_name = data_args.text_column_name
509
+ do_lower_case = data_args.do_lower_case
510
+ dataset_name = data_args.dataset_name
511
+
512
+ # Define tokens to ignore/replace
513
+ tedlium_contractions = [" 's", " 't", " 're", " 've", " 'm", " 'll", " 'd", " 'clock", " 'all"]
514
+ gigaspeech_punctuation = {" <comma>": ",", " <period>": ".", " <questionmark>": "?", " <exclamationpoint>": "!"}
515
+ gigaspeech_disfluencies = ["<other>", "<sil>"]
516
+ swb_disfluencies = ["[noise]", "[laughter]", "[silence]", "<a_aside>", "<b_aside>", "<e_aside>", "[laughter-",
517
+ "[vocalized-noise]", "_1"]
518
+ swb_punctuations = ["{", "}", "[", "]-", "]"]
519
+ earnings_disfluencies = ["<crosstalk>", "<affirmative>", "<inaudible>", "inaudible", "<laugh>"]
520
+ ignore_segments = ["ignore_time_segment_in_scoring", "<noise>", "<music>", "[noise]", "[laughter]", "[silence]",
521
+ "[vocalized-noise]", "<crosstalk>", "<affirmative>", "<inaudible>", "<laugh>", "<other>",
522
+ "<sil>", ""]
523
+
524
+ if training_args.do_train and data_args.max_train_samples is not None:
525
+ raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
526
+
527
+ if training_args.do_eval and data_args.max_eval_samples is not None:
528
+ raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
529
+
530
+ if training_args.do_predict and data_args.max_predict_samples is not None:
531
+ for split in test_split:
532
+ raw_datasets[split] = raw_datasets[split].select(range(data_args.max_predict_samples))
533
+
534
+ # filter data where the targets are ignored in scoring
535
+ def is_target_labels(input_str):
536
+ return input_str.lower() not in ignore_segments
537
+
538
+ raw_datasets = raw_datasets.filter(
539
+ is_target_labels,
540
+ num_proc=num_workers,
541
+ input_columns=[text_column_name],
542
+ desc="filtering data where the targets are ignored in scoring",
543
+ )
544
+
545
+ def prepare_dataset(batch):
546
+ # pre-process audio
547
+ try:
548
+ sample = batch[audio_column_name]
549
+ except ValueError:
550
+ # E22: some samples are empty (no audio). Reading the empty audio array will trigger
551
+ # a soundfile ValueError. For now, we'll manually set these arrays to a zero array.
552
+ # They will be filtered in the subsequent filtering stage and so are
553
+ # explicitly ignored during training.
554
+ sample = {"array": np.array([0.]), "sampling_rate": config.sample_rate}
555
+
556
+ # NeMo RNNT model performs the audio preprocessing in the `.forward()` call
557
+ # => we only need to supply it with the raw audio values
558
+ batch["input_ids"] = sample["array"]
559
+ batch["input_lengths"] = len(sample["array"])
560
+
561
+ # 'Error correction' of targets
562
+ input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
563
+
564
+ # LibriSpeech ASR
565
+ if dataset_name == "librispeech_asr":
566
+ pass # no error correction necessary
567
+
568
+ # VoxPopuli
569
+ if dataset_name == "google/xtreme_s":
570
+ pass # no error correction necessary
571
+
572
+ # Common Voice 9
573
+ if dataset_name == "mozilla-foundation/common_voice_9_0":
574
+ if input_str.startswith('"') and input_str.endswith('"'):
575
+ # we can remove trailing quotation marks as they do not affect the transcription
576
+ input_str = input_str[1:-1]
577
+ # replace double quotation marks with single
578
+ input_str = input_str.replace('""', '"')
579
+
580
+ # TED-LIUM (Release 3)
581
+ if dataset_name == "LIUM/tedlium":
582
+ # delete the <unk> token from the text
583
+ input_str = input_str.replace("<unk>", "")
584
+ # replace spaced apostrophes with un-spaced (it 's -> it's)
585
+ for contraction in tedlium_contractions:
586
+ input_str = input_str.replace(contraction, contraction[1:])
587
+
588
+ # GigaSpeech
589
+ if dataset_name == "speechcolab/gigaspeech":
590
+ for disfluency in gigaspeech_disfluencies:
591
+ input_str = input_str.replace(disfluency, "")
592
+ # convert spelled out punctuation to symbolic form
593
+ for punctuation, replacement in gigaspeech_punctuation.items():
594
+ input_str = input_str.replace(punctuation, replacement)
595
+
596
+ # SWB: hide the path to the private HF dataset
597
+ if "switchboard" in dataset_name:
598
+ for disfluency in swb_disfluencies:
599
+ input_str = input_str.replace(disfluency, "")
600
+ # remove parenthesised text (test data only)
601
+ input_str = re.sub("[\(].*?[\)]", "", input_str)
602
+ for punctuation in swb_punctuations:
603
+ input_str = input_str.replace(punctuation, "")
604
+ # replace anomalous words with their correct transcriptions
605
+ split_str = input_str.split("/")
606
+ if len(split_str) > 1:
607
+ input_str = " ".join(
608
+ [" ".join([" ".join(i.split(" ")[:-1]) for i in split_str])] + [split_str[-1].split(" ")[-1]])
609
+
610
+ # Earnings 22: still figuring out best segmenting method. Thus, dataset name subject to change
611
+ if "earnings22" in dataset_name:
612
+ for disfluency in earnings_disfluencies:
613
+ input_str = input_str.replace(disfluency, "")
614
+
615
+ # SPGISpeech
616
+ if dataset_name == "kensho/spgispeech":
617
+ pass # no error correction necessary
618
+
619
+ # JIWER compliance (for WER/CER calc.)
620
+ # remove multiple spaces
621
+ input_str = re.sub(r"\s\s+", " ", input_str)
622
+ # strip trailing spaces
623
+ input_str = input_str.strip()
624
+
625
+ # We can't currently tokenize the dataset... we need the pre-processed text data in order to
626
+ # build our SPE tokenizer. Once we've defined our tokenizer, we can come back and
627
+ # tokenize the text. For now, just return the pre-processed text data
628
+ batch[text_column_name] = input_str
629
+ return batch
630
+
631
+ vectorized_datasets = raw_datasets.map(
632
+ prepare_dataset,
633
+ num_proc=num_workers,
634
+ desc="preprocess datasets",
635
+ )
636
+
637
+ # filter training data with inputs shorter than min_input_length or longer than max_input_length
638
+ def is_audio_in_length_range(length):
639
+ return length > min_input_length and length < max_input_length
640
+
641
+ if training_args.do_train:
642
+ vectorized_datasets["train"] = vectorized_datasets["train"].filter(
643
+ is_audio_in_length_range,
644
+ num_proc=num_workers,
645
+ input_columns=["input_lengths"],
646
+ )
647
+
648
+ if max_eval_input_length is not None:
649
+ # filter data with inputs longer than max_eval_input_length
650
+ def is_eval_audio_in_length_range(length):
651
+ return min_input_length < length < max_eval_input_length
652
+
653
+ vectorized_datasets = vectorized_datasets.filter(
654
+ is_eval_audio_in_length_range,
655
+ num_proc=num_workers,
656
+ input_columns=["input_lengths"],
657
+ )
658
+
659
+ def is_labels_non_zero(transcription):
660
+ return len(transcription) > 0
661
+
662
+ vectorized_datasets = vectorized_datasets.filter(
663
+ is_labels_non_zero,
664
+ num_proc=num_workers,
665
+ input_columns=[text_column_name],
666
+ )
667
+
668
+ # for large datasets it is advised to run the preprocessing on a
669
+ # single machine first with `args.preprocessing_only` since there will most likely
670
+ # be a timeout when running the script in distributed mode.
671
+ # In a second step `args.preprocessing_only` can then be set to `False` to load the
672
+ # cached dataset
673
+ if data_args.preprocessing_only:
674
+ cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
675
+ logger.info(f"Data preprocessing finished. Files cached at {cache}.")
676
+ return
677
+
678
+ # Function to build a NeMo tokenizer manifest from a HF dataset
679
+ # TODO: with a bit of hacking around we can probably bypass this step entirely
680
+ def build_manifest(ds, manifest_path):
681
+ with open(manifest_path, 'w') as fout:
682
+ for sample in tqdm(ds[text_column_name]):
683
+ # Write the metadata to the manifest
684
+ metadata = {
685
+ "text": sample
686
+ }
687
+ json.dump(metadata, fout)
688
+ fout.write('\n')
689
+
690
+ config.train_ds = config.validation_ds = config.test_ds = None
691
+
692
+ if not os.path.exists(model_args.manifest_path) and training_args.do_train:
693
+ os.makedirs(model_args.manifest_path)
694
+ manifest = os.path.join(model_args.manifest_path, "train.json")
695
+ logger.info(f"Building training manifest at {manifest}")
696
+ build_manifest(vectorized_datasets["train"], manifest)
697
+ else:
698
+ manifest = os.path.join(model_args.manifest_path, "train.json")
699
+ logger.info(f"Re-using training manifest at {manifest}")
700
+
701
+ tokenizer_dir, tokenizer_type_cfg = build_tokenizer(model_args, data_args, manifest)
702
+
703
+ # generalise the script later to load a pre-built tokenizer for eval only
704
+ config.tokenizer.dir = tokenizer_dir
705
+ config.tokenizer.type = tokenizer_type_cfg
706
+
707
+ if model_args.add_adapter:
708
+ # Utility method to check and update the model config
709
+ def update_model_config_to_support_adapter(model_cfg):
710
+ with open_dict(model_cfg):
711
+ adapter_metadata = adapter_mixins.get_registered_adapter(model_cfg.encoder._target_)
712
+ if adapter_metadata is not None:
713
+ model_cfg.encoder._target_ = adapter_metadata.adapter_class_path
714
+
715
+ logging.info("Updated encoder _target_ model: %s", model_cfg.encoder._target_)
716
+ return model_cfg
717
+
718
+ config = update_model_config_to_support_adapter(config)
719
+
720
+ # possibly fused-computation of prediction net + joint net + loss + WER calculation
721
+ config.joint.fuse_loss_wer = model_args.fuse_loss_wer
722
+ if model_args.fuse_loss_wer:
723
+ config.joint.fused_batch_size = model_args.fused_batch_size
724
+
725
+ if model_args.model_name_or_path is not None:
726
+ # load pre-trained model weights
727
+ model = RNNTBPEModel.from_pretrained(model_args.model_name_or_path, override_config_path=config, map_location="cpu")
728
+ model.save_name = model_args.model_name_or_path
729
+
730
+ pretrained_decoder = model.decoder.state_dict()
731
+ pretrained_joint = model.joint.state_dict()
732
+ model.change_vocabulary(new_tokenizer_dir=tokenizer_dir, new_tokenizer_type=tokenizer_type_cfg)
733
+
734
+ # TODO: add checks for loading decoder/joint state dict
735
+ model.decoder.load_state_dict(pretrained_decoder)
736
+ model.joint.load_state_dict(pretrained_joint)
737
+
738
+ else:
739
+ model = RNNTBPEModel(cfg=config)
740
+ model.save_name = model_args.config_path.split("/")[-1].split(".")[0]
741
+ model.change_vocabulary(new_tokenizer_dir=tokenizer_dir, new_tokenizer_type=tokenizer_type_cfg)
742
+
743
+ if model_args.add_adapter:
744
+ adapter_name = model_args.config_path.split("/")[-1].split(".")[0]
745
+ adapter_dim = model.cfg.encoder.d_model
746
+ adapter_activation = "swish"
747
+ adapter_norm_position = "post"
748
+ adapter_cfg = LinearAdapterConfig(
749
+ in_features=model.cfg.encoder.d_model,
750
+ # conformer specific model dim. Every layer emits this dim at its output.
751
+ dim=adapter_dim, # the bottleneck dimension of the adapter
752
+ activation=adapter_activation, # activation used in bottleneck block
753
+ norm_position=adapter_norm_position, # whether to use LayerNorm at the beginning or the end of the adapter
754
+ )
755
+ logger.info("Adapter config: %s", adapter_cfg)
756
+ model.add_adapter(name=adapter_name, cfg=adapter_cfg)
757
+ model.set_enabled_adapters(enabled=False) # disable all adapters
758
+ model.set_enabled_adapters(name=adapter_name, enabled=True) # enable only the current adapter we want to train
759
+
760
+ def enable_bn(m):
761
+ if type(m) == nn.BatchNorm1d:
762
+ m.train()
763
+ for param in m.parameters():
764
+ param.requires_grad_(True)
765
+
766
+ if model_args.freeze_encoder:
767
+ model.encoder.freeze()
768
+ model.encoder.apply(enable_bn)
769
+ logging.info("Model encoder has been frozen, and batch normalization has been unfrozen")
770
+
771
+ if model_args.add_adapter:
772
+ model.unfreeze_enabled_adapters()
773
+ logging.info("Model adapter has been unfrozen")
774
+
775
+ # now that we have our model and tokenizer defined, we can tokenize the text data
+ tokenizer = model.tokenizer.tokenizer.encode_as_ids
+
+ def tokenize_transcripts(batch):
+ batch["labels"] = tokenizer(batch[text_column_name])
+ return batch
+
+ vectorized_datasets = vectorized_datasets.map(tokenize_transcripts, num_proc=num_workers,
+ desc="Tokenizing datasets...",
+ remove_columns=next(iter(raw_datasets.values())).column_names)
+
+ def compute_metrics(pred):
+ # Tuple of WERs returned by the model during eval: (wer, wer_num, wer_denom)
+ wer_num = pred.predictions[1]
+ wer_denom = pred.predictions[2]
+ # compute the aggregate WER over the concatenated batches
+ wer = sum(wer_num) / sum(wer_denom)
+ return {"wer": wer}
+
+ class UnfreezeEncoderCallback(TrainerCallback):
+ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ model.encoder.unfreeze()
+ print("Model encoder has been unfrozen")
+
+ class NeMoTrainer(Trainer):
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
+ # If we are executing this function, we are process zero, so we don't check for that.
+ output_dir = output_dir if output_dir is not None else self.args.output_dir
+ os.makedirs(output_dir, exist_ok=True)
+ logger.info(f"Saving model checkpoint to {output_dir}")
+ # Save the trained model and its configuration as a NeMo archive using `save_to()`.
+ # It can then be reloaded using `restore_from()`
+ self.model.save_to(save_path=os.path.join(output_dir, model.save_name + ".nemo"))
+ # Good practice: save your training arguments together with the trained model
+ torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
+
+ # Initialize Trainer
+ trainer = NeMoTrainer(
+ model=model,
+ args=training_args,
+ compute_metrics=compute_metrics,
+ train_dataset=vectorized_datasets['train'] if training_args.do_train else None,
+ eval_dataset=vectorized_datasets['eval'] if training_args.do_eval else None,
+ data_collator=NeMoDataCollator,
+ callbacks=[UnfreezeEncoderCallback] if model_args.unfreeze_encoder else None,
+ )
+
+ # 8. Finally, we can start training
+
+ # Training
+ if training_args.do_train:
+
+ # use the last checkpoint if one exists
+ if last_checkpoint is not None:
+ checkpoint = last_checkpoint
+ elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
+ checkpoint = model_args.model_name_or_path
+ else:
+ checkpoint = None
+
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
+ trainer.save_model()
+
+ metrics = train_result.metrics
+ max_train_samples = (
+ data_args.max_train_samples
+ if data_args.max_train_samples is not None
+ else len(vectorized_datasets["train"])
+ )
+ metrics["train_samples"] = min(max_train_samples, len(vectorized_datasets["train"]))
+
+ trainer.log_metrics("train", metrics)
+ trainer.save_metrics("train", metrics)
+ trainer.save_state()
+
+ # Change decoding strategy for final eval/predict
+ if training_args.do_eval or training_args.do_predict:
+ # set beam search decoding config
+ beam_decoding_config = copy.deepcopy(trainer.model.cfg.decoding)
+ beam_decoding_config.strategy = model_args.final_decoding_strategy
+ beam_decoding_config.beam.beam_size = model_args.final_num_beams
+
+ trainer.model.change_decoding_strategy(beam_decoding_config)
+
+ results = {}
+ if training_args.do_eval:
+ logger.info(f"*** Running Final Evaluation ({model_args.final_decoding_strategy}) ***")
+
+ metrics = trainer.evaluate()
+ max_eval_samples = (
+ data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"])
+ )
+ metrics["eval_samples"] = min(max_eval_samples, len(vectorized_datasets["eval"]))
+
+ trainer.log_metrics("eval", metrics)
+ trainer.save_metrics("eval", metrics)
+
+ if training_args.do_predict:
+ logger.info(f"*** Running Final Prediction ({model_args.final_decoding_strategy}) ***")
+
+ for split in test_split:
+ predict_results = trainer.predict(
+ vectorized_datasets[split], metric_key_prefix=split, )
+ metrics = predict_results.metrics
+ max_predict_samples = (
+ data_args.max_predict_samples if data_args.max_predict_samples is not None else len(vectorized_datasets[split])
+ )
+ metrics[f"{split}_samples"] = min(max_predict_samples, len(vectorized_datasets[split]))
+
+ trainer.log_metrics(split, metrics)
+ trainer.save_metrics(split, metrics)
+
+ if "wandb" in training_args.report_to:
+ import wandb
+ metrics = {os.path.join(split, k[len(split)+1:]): v for k, v in metrics.items()}
+ wandb.log(metrics)
+
+ # re-evaluate on the test set, this time computing the CER
+ # running evaluation twice is wasteful, but far simpler than computing both metrics in a single pass
+ trainer.model.wer.use_cer = True
+ trainer.model.change_decoding_strategy(trainer.model.cfg.decoding)
+
+ for split in test_split:
+ predict_results = trainer.predict(
+ vectorized_datasets[split], metric_key_prefix=split, )
+ metrics = predict_results.metrics
+ # the returned metric is now the CER, but it is still returned under the "{split}_wer" key; rename it here
+ metrics = {f"{split}_cer": metrics[f"{split}_wer"]}
+
+ trainer.log_metrics(split, metrics)
+ trainer.save_metrics(split, metrics)
+
+ if "wandb" in training_args.report_to:
+ metrics = {os.path.join(split, k[len(split) + 1:]): v for k, v in metrics.items()}
+ wandb.log(metrics)
+
+ # Write model card and (optionally) push to hub
+ config_name = data_args.dataset_config_name if data_args.dataset_config_name is not None else "na"
+ kwargs = {
+ "finetuned_from": model_args.model_name_or_path,
+ "tasks": "speech-recognition",
+ "tags": ["automatic-speech-recognition", data_args.dataset_name],
+ "dataset_args": (
+ f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split:"
+ f" {data_args.eval_split_name}"
+ ),
+ "dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
+ }
+ if "common_voice" in data_args.dataset_name:
+ kwargs["language"] = config_name
+
+ if training_args.push_to_hub:
+ trainer.push_to_hub(**kwargs)
+ #else:
+ #trainer.create_model_card(**kwargs)
+
+ return results
+
+
+ if __name__ == "__main__":
+ main()
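For reference, a checkpoint written by `NeMoTrainer._save` above can be reloaded directly with NeMo. The snippet below is an illustrative sketch and not part of the committed script: it assumes the saved `.nemo` archive restores as a standard `EncDecRNNTBPEModel`, the checkpoint and audio paths are placeholders, and the exact `transcribe()` signature depends on the installed NeMo version.

import nemo.collections.asr as nemo_asr

# _save stores the model as "<output_dir>/<save_name>.nemo"; this concrete path is a placeholder
checkpoint_path = "./conformer-transducer-xl-cv9/checkpoint-20000/stt_en_conformer_transducer_xlarge.nemo"
model = nemo_asr.models.EncDecRNNTBPEModel.restore_from(restore_path=checkpoint_path, map_location="cpu")

# greedy transcription of a local audio file (placeholder path)
hypotheses = model.transcribe(paths2audio_files=["example.wav"], batch_size=1)
print(hypotheses)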
scripts/run_batch_size_sweep.yaml ADDED
@@ -0,0 +1,61 @@
+ command:
+ - python3
+ - ${program}
+ - --use_auth_token
+ - --do_eval
+ - --group_by_length
+ - --overwrite_output_dir
+ - --fp16
+ - --do_lower_case
+ - --do_eval
+ - --do_train
+ - --fuse_loss_wer
+ - ${args}
+ method: grid
+ metric:
+ goal: minimize
+ name: train/train_loss
+ parameters:
+ config_path:
+ value: conf/conformer_transducer_bpe_xlarge.yaml
+ dataset_config_name:
+ value: clean
+ dataset_name:
+ value: librispeech_asr
+ max_steps:
+ value: 50
+ model_name_or_path:
+ value: stt_en_conformer_transducer_xlarge
+ output_dir:
+ value: ./sweep_output_dir
+ gradient_accumulation_steps:
+ values:
+ - 1
+ - 2
+ per_device_train_batch_size:
+ values:
+ - 8
+ - 16
+ fused_batch_size:
+ values:
+ - 4
+ - 8
+ - 16
+ per_device_eval_batch_size:
+ value: 4
+ preprocessing_num_workers:
+ value: 1
+ train_split_name:
+ value: train.100[:500]
+ eval_split_name:
+ value: validation[:100]
+ tokenizer_path:
+ value: tokenizer
+ vocab_size:
+ value: 1024
+ wandb_project:
+ value: rnnt-debug
+ logging_steps:
+ value: 5
+ program: run_speech_recognition_rnnt.py
+ project: rnnt-debug
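The file above is a Weights & Biases sweep configuration for the batch-size grid search. A minimal sketch of registering and running it from Python is shown below; it is not part of the commit, assumes `wandb` and `pyyaml` are installed and the user is logged in, and my understanding is that an agent started without a function runs the `program`/`command` defined in the config (the CLI equivalent would be `wandb sweep` followed by `wandb agent`).

import yaml
import wandb

# load the sweep definition committed above
with open("scripts/run_batch_size_sweep.yaml") as f:
    sweep_config = yaml.safe_load(f)

# register the sweep, then start an agent that launches run_speech_recognition_rnnt.py
# with the sampled hyperparameters (assumption: no function is needed for command-based sweeps)
sweep_id = wandb.sweep(sweep=sweep_config, project=sweep_config.get("project", "rnnt-debug"))
wandb.agent(sweep_id)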
scripts/run_common_voice_9.sh ADDED
@@ -0,0 +1,38 @@
+ #!/usr/bin/env bash
+ CUDA_VISIBLE_DEVICES=1 python run_speech_recognition_rnnt.py \
+ --config_path="conf/conformer_transducer_bpe_xlarge.yaml" \
+ --model_name_or_path="stt_en_conformer_transducer_xlarge" \
+ --dataset_name="mozilla-foundation/common_voice_9_0" \
+ --tokenizer_path="tokenizer" \
+ --vocab_size="1024" \
+ --num_train_epochs="0.90" \
+ --dataset_config_name="en" \
+ --train_split_name="train" \
+ --eval_split_name="validation" \
+ --test_split_name="test" \
+ --text_column_name="sentence" \
+ --output_dir="./conformer-transducer-xl-cv9" \
+ --run_name="rnnt-cv9-baseline" \
+ --wandb_project="rnnt" \
+ --per_device_train_batch_size="8" \
+ --per_device_eval_batch_size="4" \
+ --logging_steps="50" \
+ --learning_rate="1e-4" \
+ --warmup_steps="500" \
+ --save_strategy="steps" \
+ --save_steps="20000" \
+ --evaluation_strategy="steps" \
+ --eval_steps="20000" \
+ --report_to="wandb" \
+ --preprocessing_num_workers="4" \
+ --fused_batch_size="8" \
+ --length_column_name="input_lengths" \
+ --do_lower_case="False" \
+ --fuse_loss_wer \
+ --group_by_length \
+ --overwrite_output_dir \
+ --do_train \
+ --do_eval \
+ --do_predict \
+ --push_to_hub \
+ --use_auth_token