Commit
•
6652985
1
Parent(s):
7350459
Update README.md
Browse files
README.md
CHANGED
@@ -56,19 +56,35 @@ This code snippet shows how to evaluate **Wav2Vec2-Large-Tedlium** on the TEDLIU
|
|
56 |
|
57 |
```python
|
58 |
from datasets import load_dataset
|
59 |
-
from transformers import
|
60 |
import torch
|
61 |
from jiwer import wer
|
|
|
62 |
tedlium_eval = load_dataset("LIUM/tedlium", "release3", split="test")
|
63 |
-
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
def map_to_pred(batch):
|
66 |
input_values = processor(batch["audio"]["array"], return_tensors="pt", padding="longest").input_values
|
67 |
with torch.no_grad():
|
68 |
-
|
69 |
-
|
70 |
-
transcription =
|
71 |
-
batch["transcription"] = transcription
|
72 |
return batch
|
|
|
73 |
result = tedlium_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["speech"])
|
74 |
-
print("WER:", wer(result["text"], result["transcription"]))
|
|
|
|
56 |
|
57 |
```python
|
58 |
from datasets import load_dataset
|
59 |
+
from transformers import AutoProcessor, SpeechEncoderDecoderModel
|
60 |
import torch
|
61 |
from jiwer import wer
|
62 |
+
|
63 |
tedlium_eval = load_dataset("LIUM/tedlium", "release3", split="test")
|
64 |
+
|
65 |
+
def filter_ds(text):
|
66 |
+
return text != "ignore_time_segment_in_scoring"
|
67 |
+
|
68 |
+
# remove samples ignored from scoring
|
69 |
+
tedlium_eval = tedlium_eval.map(filter_ds, input_columns=["text"])
|
70 |
+
|
71 |
+
model = SpeechEncoderDecoderModel.from_pretrained("sanchit-gandhi/wav2vec2-2-bart-large-tedlium").to("cuda")
|
72 |
+
processor = AutoProcessor.from_pretrained("sanchit-gandhi/wav2vec2-2-bart-large-tedlium")
|
73 |
+
|
74 |
+
gen_kwargs = {
|
75 |
+
"max_length": 200,
|
76 |
+
"num_beams": 5,
|
77 |
+
"length_penalty": 1.2
|
78 |
+
}
|
79 |
+
|
80 |
def map_to_pred(batch):
|
81 |
input_values = processor(batch["audio"]["array"], return_tensors="pt", padding="longest").input_values
|
82 |
with torch.no_grad():
|
83 |
+
generated = model.generate(input_values.to("cuda"), **gen_kwargs)
|
84 |
+
decoded = processor.batch_decode(generated, skip_special_tokens=True)
|
85 |
+
batch["transcription"] = decoded[0]
|
|
|
86 |
return batch
|
87 |
+
|
88 |
result = tedlium_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["speech"])
|
89 |
+
print("WER:", wer(result["text"], result["transcription"]))
|
90 |
+
```
|