File size: 2,875 Bytes
9cc251d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
import librosa
from datasets import load_dataset, Audio
from transformers import WhisperProcessor, WhisperFeatureExtractor, WhisperTokenizer, WhisperForConditionalGeneration
from huggingface_hub import login
import argparse
from evaluate import load

# Command-line options: which checkpoint to evaluate, on which dataset/split,
# and the Hugging Face Hub token used for authentication.
my_parser = argparse.ArgumentParser()

my_parser.add_argument("--model_name", "-model_name", type=str, action="store", default="openai/whisper-tiny")
my_parser.add_argument("--hf_token", "-hf_token", type=str, action="store")
my_parser.add_argument("--dataset_name", "-dataset_name", type=str, action="store", default="google/fleurs")
my_parser.add_argument("--split", "-split", type=str, action="store", default="test")
my_parser.add_argument("--subset", "-subset", type=str, action="store")

args = my_parser.parse_args()

# Authenticate with the Hugging Face Hub using the token from the command line.
try:
    login(args.hf_token)
except Exception as err:
    # Raise a real exception type (a bare string cannot be raised) and chain
    # the underlying cause. The token value is deliberately kept out of the
    # message so the secret does not end up in logs or tracebacks.
    raise RuntimeError("Can't login please set --hf_token") from err


# Pull the run configuration out of the parsed command-line arguments.
dataset_name = args.dataset_name
model_name = args.model_name
subset = args.subset

# "google/fleurs" keeps its reference text under "transcription"; every other
# dataset is read from the "sentence" column.
text_column = "transcription" if dataset_name == "google/fleurs" else "sentence"

print(f"Evaluating {args.model_name} on {args.dataset_name} [{subset}]")

# Options shared by the processor and the tokenizer.
_whisper_options = {"language": "Arabic", "task": "transcribe"}

feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name)
model = WhisperForConditionalGeneration.from_pretrained(model_name)

# NOTE(review): `use_auth_token` appears to be deprecated in newer `datasets`
# releases in favour of `token` - confirm against the pinned version.
test_dataset = load_dataset(
    dataset_name,
    subset,
    split=args.split,
    use_auth_token=True,
)
processor = WhisperProcessor.from_pretrained(model_name, **_whisper_options)
tokenizer = WhisperTokenizer.from_pretrained(model_name, **_whisper_options)

# Cast the audio column so it is served at a 16 kHz sampling rate.
test_dataset = test_dataset.cast_column("audio", Audio(sampling_rate=16000))

# Preprocessing the datasets.
def prepare_dataset(batch):
    """Attach model inputs and tokenized labels to one dataset example.

    Reads ``batch["audio"]`` (its ``"array"`` and ``"sampling_rate"`` entries)
    and ``batch[text_column]``, then stores the results back on ``batch``.

    Args:
        batch: A single dataset example.

    Returns:
        The same ``batch`` with ``"input_features"`` and ``"labels"`` added.
    """
    clip = batch["audio"]

    # Input features computed by the feature extractor; only the first
    # (and here only) item of the returned batch is kept.
    extracted = feature_extractor(clip["array"], sampling_rate=clip["sampling_rate"])
    batch["input_features"] = extracted.input_features[0]

    # Reference text encoded to label ids by the tokenizer.
    encoded = tokenizer(batch[text_column])
    batch["labels"] = encoded.input_ids

    return batch

# Run the per-example preprocessing over the whole evaluation split.
test_dataset = test_dataset.map(prepare_dataset)

# Move the model to the GPU for generation.
model = model.to("cuda")

# Fix the decoder prompt so generation is set to Arabic transcription.
_decoder_prompt_ids = processor.get_decoder_prompt_ids(
    language="ar",
    task="transcribe",
)
model.config.forced_decoder_ids = _decoder_prompt_ids

def map_to_result(batch):
    """Transcribe one prepared example and attach prediction and reference text.

    Args:
        batch: A dataset example holding ``"input_features"`` and ``"labels"``.

    Returns:
        The same ``batch`` with two added entries: ``"pred_str"`` (the decoded
        model output) and ``"text"`` (the labels decoded back to a string).
        Both are decoded with special tokens skipped.
    """
    with torch.no_grad():
        # Create the input on whatever device the model currently lives on
        # instead of hard-coding "cuda", so the placement decision is made in
        # exactly one place. unsqueeze(0) adds the leading batch dimension.
        input_features = torch.tensor(
            batch["input_features"], device=model.device
        ).unsqueeze(0)
        pred_ids = model.generate(input_features)

    batch["pred_str"] = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]
    batch["text"] = processor.decode(batch["labels"], skip_special_tokens=True)

    return batch
# Generate a transcription for every example in the evaluation split.
results = test_dataset.map(map_to_result)

# Score the predictions against the decoded references with word error rate.
wer = load("wer")
_wer_score = wer.compute(
    predictions=results["pred_str"],
    references=results["text"],
)
print("Test WER: {:.3f}".format(_wer_score))