elgeish committed on
Commit
e904c32
1 Parent(s): 32f1217

use torchaudio (faster than librosa)

Browse files
Files changed (1) hide show
  1. README.md +20 -9
README.md CHANGED
@@ -24,7 +24,7 @@ model-index:
24
  metrics:
25
  - name: Test WER
26
  type: wer
27
- value: 26.60
28
  ---
29
 
30
  # Wav2Vec2-Large-XLSR-53-Arabic
@@ -39,22 +39,27 @@ When using this model, make sure that your speech input is sampled at 16kHz.
39
  The model can be used directly (without a language model) as follows:
40
 
41
  ```python
42
- import librosa
43
  import torch
 
44
  from datasets import load_dataset
45
  from lang_trans.arabic import buckwalter
46
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
47
 
48
  dataset = load_dataset("common_voice", "ar", split="test[:10]")
49
- processor = Wav2Vec2Processor.from_pretrained("elgeish/wav2vec2-large-xlsr-53-arabic")
50
- model = Wav2Vec2ForCTC.from_pretrained("elgeish/wav2vec2-large-xlsr-53-arabic")
51
- model.eval()
 
 
52
 
53
  def prepare_example(example):
54
- example["speech"], _ = librosa.load(example["path"], sr=16000)
 
55
  return example
56
 
57
  dataset = dataset.map(prepare_example)
 
 
58
 
59
  def predict(batch):
60
  inputs = processor(batch["speech"], sampling_rate=16000, return_tensors="pt", padding=True)
@@ -113,17 +118,23 @@ The model can be evaluated as follows on the Arabic test data of Common Voice:
113
 
114
  ```python
115
  import jiwer
116
- import librosa
117
  import torch
 
118
  from datasets import load_dataset
119
  from lang_trans.arabic import buckwalter
120
  from transformers import set_seed, Wav2Vec2ForCTC, Wav2Vec2Processor
121
 
122
  set_seed(42)
123
  test_split = load_dataset("common_voice", "ar", split="test")
 
 
 
 
 
124
 
125
  def prepare_example(example):
126
- example["speech"], _ = librosa.load(example["path"], sr=16000)
 
127
  return example
128
 
129
  test_split = test_split.map(prepare_example)
@@ -159,7 +170,7 @@ metrics = jiwer.compute_measures(
159
  print(f"WER: {metrics['wer']:.2%}")
160
  ```
161
 
162
- **Test Result**: 26.60%
163
 
164
  ## Training
165
 
 
24
  metrics:
25
  - name: Test WER
26
  type: wer
27
+ value: 26.55
28
  ---
29
 
30
  # Wav2Vec2-Large-XLSR-53-Arabic
 
39
  The model can be used directly (without a language model) as follows:
40
 
41
  ```python
 
42
  import torch
43
+ import torchaudio
44
  from datasets import load_dataset
45
  from lang_trans.arabic import buckwalter
46
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
47
 
48
  dataset = load_dataset("common_voice", "ar", split="test[:10]")
49
+ resamplers = { # all three sampling rates exist in test split
50
+ 48000: torchaudio.transforms.Resample(48000, 16000),
51
+ 44100: torchaudio.transforms.Resample(44100, 16000),
52
+ 32000: torchaudio.transforms.Resample(32000, 16000),
53
+ }
54
 
55
  def prepare_example(example):
56
+ speech, sampling_rate = torchaudio.load(example["path"])
57
+ example["speech"] = resamplers[sampling_rate](speech).squeeze().numpy()
58
  return example
59
 
60
  dataset = dataset.map(prepare_example)
61
+ processor = Wav2Vec2Processor.from_pretrained("elgeish/wav2vec2-large-xlsr-53-arabic")
62
+ model = Wav2Vec2ForCTC.from_pretrained("elgeish/wav2vec2-large-xlsr-53-arabic").eval()
63
 
64
  def predict(batch):
65
  inputs = processor(batch["speech"], sampling_rate=16000, return_tensors="pt", padding=True)
 
118
 
119
  ```python
120
  import jiwer
 
121
  import torch
122
+ import torchaudio
123
  from datasets import load_dataset
124
  from lang_trans.arabic import buckwalter
125
  from transformers import set_seed, Wav2Vec2ForCTC, Wav2Vec2Processor
126
 
127
  set_seed(42)
128
  test_split = load_dataset("common_voice", "ar", split="test")
129
+ resamplers = { # all three sampling rates exist in test split
130
+ 48000: torchaudio.transforms.Resample(48000, 16000),
131
+ 44100: torchaudio.transforms.Resample(44100, 16000),
132
+ 32000: torchaudio.transforms.Resample(32000, 16000),
133
+ }
134
 
135
  def prepare_example(example):
136
+ speech, sampling_rate = torchaudio.load(example["path"])
137
+ example["speech"] = resamplers[sampling_rate](speech).squeeze().numpy()
138
  return example
139
 
140
  test_split = test_split.map(prepare_example)
 
170
  print(f"WER: {metrics['wer']:.2%}")
171
  ```
172
 
173
+ **Test Result**: 26.55%
174
 
175
  ## Training
176