import argparse

import torch
from datasets import Audio, load_dataset
from evaluate import load
from huggingface_hub import login
from transformers import (
    WhisperFeatureExtractor,
    WhisperForConditionalGeneration,
    WhisperProcessor,
    WhisperTokenizer,
)

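# Evaluate a (possibly fine-tuned) Whisper checkpoint on an Arabic speech
# test set and report the word error rate (WER).
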
parser = argparse.ArgumentParser()

parser.add_argument("--model_name", type=str, default="openai/whisper-tiny")
parser.add_argument("--hf_token", type=str)
parser.add_argument("--dataset_name", type=str, default="google/fleurs")
parser.add_argument("--split", type=str, default="test")
parser.add_argument("--subset", type=str)

args = parser.parse_args()

try:
    login(args.hf_token)
except Exception as err:
    # `raise` needs an exception instance, not a bare f-string.
    raise RuntimeError(f"Can't log in; please set --hf_token (got {args.hf_token!r})") from err

dataset_name = args.dataset_name
model_name = args.model_name
subset = args.subset

# Common Voice-style datasets keep the reference text in "sentence";
# FLEURS stores it in "transcription".
text_column = "sentence"
if dataset_name == "google/fleurs":
    text_column = "transcription"

print(f"Evaluating {args.model_name} on {args.dataset_name} [{subset}]") |
|
|
|
|
|
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name)
model = WhisperForConditionalGeneration.from_pretrained(model_name)

# `token=True` replaces the deprecated `use_auth_token=True` argument in
# recent versions of `datasets`.
test_dataset = load_dataset(dataset_name, subset, split=args.split, token=True)
processor = WhisperProcessor.from_pretrained(model_name, language="Arabic", task="transcribe")
tokenizer = WhisperTokenizer.from_pretrained(model_name, language="Arabic", task="transcribe")

# Whisper expects 16 kHz audio; resample on the fly.
test_dataset = test_dataset.cast_column("audio", Audio(sampling_rate=16000))

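# Note: WhisperProcessor already bundles a feature extractor and a tokenizer;
# the standalone instances above are kept for direct use in prepare_dataset.
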
def prepare_dataset(batch):
    audio = batch["audio"]

    # Compute log-Mel spectrogram input features from the raw waveform.
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # Encode the reference transcription into label token ids.
    batch["labels"] = tokenizer(batch[text_column]).input_ids
    return batch

test_dataset = test_dataset.map(prepare_dataset)

# Fall back to CPU when no GPU is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()

# Pin the decoder prompt to Arabic transcription so Whisper does not
# auto-detect the language.
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ar", task="transcribe")

def map_to_result(batch):
    with torch.no_grad():
        input_features = torch.tensor(batch["input_features"], device=device).unsqueeze(0)
        pred_ids = model.generate(input_features)

    batch["pred_str"] = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]
    batch["text"] = processor.decode(batch["labels"], skip_special_tokens=True)

    return batch

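# Note: .map() below transcribes one utterance at a time; batching inputs to
# model.generate() would be markedly faster on large test sets.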
results = test_dataset.map(map_to_result)

wer = load("wer")

# wer.compute returns a fraction, e.g. 0.250 means a 25% word error rate.
print("Test WER: {:.3f}".format(wer.compute(predictions=results["pred_str"], references=results["text"])))
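
# Example invocation (the filename is illustrative; "ar_eg" is the FLEURS
# config for Egyptian Arabic):
#   python eval_whisper.py --model_name openai/whisper-tiny \
#       --dataset_name google/fleurs --subset ar_eg --split test \
#       --hf_token <your_token>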