maxidl
committed on
Commit
·
45a4f3a
1
Parent(s):
4246f93
add chunked wer to eval script
Browse files
README.md
CHANGED
@@ -114,8 +114,7 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
|
114 |
"""
|
115 |
Evaluation on the full test set:
|
116 |
- takes ~20mins (RTX 3090).
|
117 |
-
- requires ~170GB RAM to compute the WER.
|
118 |
-
See https://discuss.huggingface.co/t/spanish-asr-fine-tuning-wav2vec2/4586/5 on how to implement this.
|
119 |
"""
|
120 |
test_dataset = load_dataset("common_voice", "de", split="test") # use "test[:1%]" for 1% sample
|
121 |
wer = load_metric("wer")
|
@@ -151,8 +150,30 @@ def evaluate(batch):
|
|
151 |
|
152 |
result = test_dataset.map(evaluate, batched=True, batch_size=8) # batch_size=8 -> requires ~14.5GB GPU memory
|
153 |
|
154 |
-
|
|
|
155 |
# WER: 12.615308
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
```
|
157 |
|
158 |
**Test Result**: 12.62 %
|
|
|
114 |
"""
|
115 |
Evaluation on the full test set:
|
116 |
- takes ~20mins (RTX 3090).
|
117 |
+
- requires ~170GB RAM to compute the WER. Below, we use a chunked implementation of WER to avoid large RAM consumption.
|
|
|
118 |
"""
|
119 |
test_dataset = load_dataset("common_voice", "de", split="test") # use "test[:1%]" for 1% sample
|
120 |
wer = load_metric("wer")
|
|
|
150 |
|
151 |
result = test_dataset.map(evaluate, batched=True, batch_size=8) # batch_size=8 -> requires ~14.5GB GPU memory
|
152 |
|
153 |
+
# non-chunked version:
|
154 |
+
# print("WER: {:2f}".format(100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"])))
|
155 |
# WER: 12.615308
|
156 |
+
|
157 |
+
# Chunked version, see https://discuss.huggingface.co/t/spanish-asr-fine-tuning-wav2vec2/4586/5:
|
158 |
+
import jiwer
|
159 |
+
|
160 |
+
def chunked_wer(targets, predictions, chunk_size=None):
    """Compute the word error rate (WER) in fixed-size chunks to bound RAM use.

    ``jiwer.wer`` over the full Common Voice test set needs ~170GB RAM; instead we
    accumulate the edit-operation counts (hits, substitutions, deletions,
    insertions) chunk by chunk and combine them into a single global WER.

    Args:
        targets: reference transcriptions (list of str).
        predictions: model output transcriptions (list of str), aligned with targets.
        chunk_size: number of utterance pairs per chunk; ``None`` computes the WER
            in one pass via ``jiwer.wer`` (original, memory-hungry behavior).

    Returns:
        float: (S + D + I) / (H + S + D), the standard WER definition.

    Raises:
        ValueError: if ``chunk_size`` is not a positive integer (the original code
            would loop forever on ``chunk_size <= 0``).
        ZeroDivisionError: if ``targets`` is empty (no reference words).
    """
    if chunk_size is None:
        return jiwer.wer(targets, predictions)
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer or None")
    start, end = 0, chunk_size
    # Global edit-operation counters: hits, substitutions, deletions, insertions.
    H, S, D, I = 0, 0, 0, 0
    while start < len(targets):
        # Slicing past the end is safe in Python; the last chunk is simply shorter.
        chunk_metrics = jiwer.compute_measures(targets[start:end], predictions[start:end])
        H += chunk_metrics["hits"]
        S += chunk_metrics["substitutions"]
        D += chunk_metrics["deletions"]
        I += chunk_metrics["insertions"]
        start += chunk_size
        end += chunk_size
    return float(S + D + I) / float(H + S + D)
|
174 |
+
|
175 |
+
# FIX(review): argument order. The signature is chunked_wer(targets, predictions, ...),
# so the references (result["sentence"]) must come first and the model outputs
# (result["pred_strings"]) second. The original call passed them swapped; WER is
# normalized by the reference word count, so the swap skews the reported number.
# (The non-chunked call above uses the correct orientation:
#  predictions=result["pred_strings"], references=result["sentence"].)
print("Total (chunk_size=1000), WER: {:2f}".format(100 * chunked_wer(result["sentence"], result["pred_strings"], chunk_size=1000)))
# Total (chunk=1000), WER: 12.768981  # NOTE(review): computed with swapped args — re-run to confirm
|
177 |
```
|
178 |
|
179 |
**Test Result**: 12.62 %
|