#!/usr/bin/env python3
"""Evaluate a CTC speech-recognition model on Common Voice test samples.

Usage:
    script.py MODEL_ID LANG LANG_PHONEME NUM_SAMPLES

Streams up to NUM_SAMPLES utterances from the Common Voice ``test`` split for
LANG, transcribes each one with MODEL_ID, and prints WER and CER between the
decoded model predictions and the reference sentences (tokenized with the
model's tokenizer, then decoded again with ``group_tokens=False``).
Pass an empty string as LANG_PHONEME to tokenize references without a
``phonemizer_lang`` argument.
"""
import itertools
import sys

import torch
import torchaudio.functional as F
from datasets import load_dataset, load_metric
from transformers import AutoModelForCTC, AutoProcessor

# Rate every clip is resampled to before being passed to the processor.
TARGET_SAMPLING_RATE = 16_000


def _resample(sample):
    """Return the sample's audio resampled to TARGET_SAMPLING_RATE as a numpy array."""
    audio = sample["audio"]
    waveform = torch.tensor(audio["array"])
    # Use the clip's declared rate instead of assuming 48 kHz.
    return F.resample(waveform, audio["sampling_rate"], TARGET_SAMPLING_RATE).numpy()


def main(argv=None):
    """Run the evaluation; ``argv`` defaults to ``sys.argv``."""
    argv = sys.argv if argv is None else argv
    model_id, lang, lang_phoneme = argv[1], argv[2], argv[3]
    num_samples = int(argv[4])

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForCTC.from_pretrained(model_id).to(device)
    processor = AutoProcessor.from_pretrained(model_id)

    ds = load_dataset("common_voice", lang, split="test", streaming=True)
    wer_metric = load_metric("wer")
    cer_metric = load_metric("cer")

    # Tokenizer kwargs do not depend on the sample, so build them once.
    tokenizer_kwargs = {"phonemizer_lang": lang_phoneme} if lang_phoneme else {}

    targets_ids = []
    predictions_ids = []
    # islice stops cleanly if the split has fewer than num_samples items,
    # whereas next() on an exhausted iterator would raise StopIteration.
    for sample in itertools.islice(ds, num_samples):
        input_values = processor(
            _resample(sample),
            sampling_rate=TARGET_SAMPLING_RATE,
            return_tensors="pt",
        ).input_values

        with torch.no_grad():
            logits = model(input_values.to(device)).logits

        prediction_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(prediction_ids)

        print(f"Correct: {sample['sentence']}")
        print(f"Predict: {transcription}")
        print(20 * '-')

        # Batch size is 1, so keep only the first (and only) row of ids.
        predictions_ids.append(prediction_ids[0].tolist())
        targets_ids.append(
            processor.tokenizer(sample["sentence"], **tokenizer_kwargs).input_ids
        )

    if not predictions_ids:
        sys.exit("No samples were evaluated; nothing to score.")

    print("Compute metrics.....")

    transcriptions = processor.batch_decode(predictions_ids)
    # group_tokens=False: references are plain token ids, not CTC frame outputs,
    # so repeated tokens must not be collapsed when decoding them.
    targets_str = processor.batch_decode(targets_ids, group_tokens=False)

    wer_score = wer_metric.compute(predictions=transcriptions, references=targets_str)
    cer_score = cer_metric.compute(predictions=transcriptions, references=targets_str)

    print("wer", wer_score)
    print("cer", cer_score)


if __name__ == "__main__":
    main()