import argparse
import json
import os

from datasets import load_dataset
from evaluate import load
from tqdm import tqdm
from transformers.models.whisper.english_normalizer import EnglishTextNormalizer

from faster_whisper import WhisperModel

parser = argparse.ArgumentParser(description="WER benchmark")
parser.add_argument(
    "--audio_numb",
    type=int,
    default=None,
    help="Number of validation audio files to evaluate."
    " Defaults to None, which evaluates every audio file in the split.",
)
args = parser.parse_args()

model_path = "large-v3"
model = WhisperModel(model_path, device="cuda")

# load the dataset with streaming mode
dataset = load_dataset("librispeech_asr", "clean", split="validation", streaming=True)

# define the evaluation metric
wer_metric = load("wer")

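# build the English text normalizer from the spelling mapping stored next to this script
with open(
    os.path.join(os.path.dirname(__file__), "normalizer.json"), "r", encoding="utf-8"
) as f:
    normalizer = EnglishTextNormalizer(json.load(f))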


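def inference(batch):
    """Transcribe every audio sample in the batch and keep its reference text."""
    batch["transcription"] = []
    for sample in batch["audio"]:
        # transcribe() returns a lazy generator of segments plus an info object
        segments, _ = model.transcribe(sample["array"], language="en")
        # joining the segment texts consumes the generator, which runs the decoding
        batch["transcription"].append("".join(segment.text for segment in segments))
    batch["reference"] = batch["text"]
    return batch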


# in streaming mode, map() is lazy: inference runs batch by batch during iteration
dataset = dataset.map(function=inference, batched=True, batch_size=16)

all_transcriptions = []
all_references = []

# iterate over the dataset (this triggers inference) and collect the results
for i, result in tqdm(enumerate(dataset), desc="Evaluating..."):
    all_transcriptions.append(result["transcription"])
    all_references.append(result["reference"])
    # stop once the requested number of audio files has been collected
    if args.audio_numb and i == (args.audio_numb - 1):
        break

# normalize predictions and references
all_transcriptions = [normalizer(transcription) for transcription in all_transcriptions]
all_references = [normalizer(reference) for reference in all_references]

# compute the WER metric, scaled to a percentage
wer = 100 * wer_metric.compute(
    predictions=all_transcriptions, references=all_references
)
print(f"WER: {wer:.3f}")