File size: 2,360 Bytes
5f72cc6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import csv
import torch
import torchaudio
import numpy as np
import evaluate
from transformers import HubertForCTC, Wav2Vec2Processor
# Inference settings
batch_size = 8
device = "cuda:0"  # or "cpu"
# Use half precision only on CUDA. float16 kernels on CPU are often missing
# or very slow, so fall back to float32 when another device is selected.
torch_dtype = torch.float16 if device.startswith("cuda") else torch.float32
# Target audio sample rate in Hz (assumed to match the checkpoint's feature
# extractor config -- confirm if the checkpoint changes).
sampling_rate = 16_000
model_name = "/home/yehor/ext-ml-disk/asr/hubert-training/models/final-85500"
testset_file = "/home/yehor/ext-ml-disk/asr/w2v2-bert-training/eval/rows_no_defis.csv"
# Load the test dataset: one dict per CSV row, keyed by the header columns.
# Decode explicitly as UTF-8 so transcripts do not depend on the system
# locale. Pass newline="" as the csv module requires, so quoted fields with
# embedded newlines are parsed correctly.
with open(testset_file, encoding="utf-8", newline="") as f:
    samples = list(csv.DictReader(f))
# Load the fine-tuned CTC model and its processor (feature extractor +
# tokenizer) from the same checkpoint directory.
model_load_kwargs = {
    "device_map": device,
    "torch_dtype": torch_dtype,
    # Optional speed-up when flash-attn is installed:
    # "attn_implementation": "flash_attention_2",
}
asr_model = HubertForCTC.from_pretrained(model_name, **model_load_kwargs)
processor = Wav2Vec2Processor.from_pretrained(model_name)
# A util function to make batches
def make_batches(iterable, n=1):
lx = len(iterable)
for ndx in range(0, lx, n):
yield iterable[ndx : min(ndx + n, lx)]
def _load_audio(path, target_sr):
    """Load *path* as a mono 1-D float32 NumPy array sampled at *target_sr* Hz.

    The file's native rate goes into a local name, so the module-level
    ``sampling_rate`` setting is never overwritten.
    """
    waveform, file_sr = torchaudio.load(path, backend="sox")
    # torchaudio returns (channels, frames). Averaging over channels matches
    # squeeze(0) for mono files. For multi-channel files, squeeze(0) would
    # leave a 2-D array that the feature extractor cannot batch.
    waveform = waveform.mean(dim=0)
    # Resample rather than feeding audio at the wrong rate while claiming
    # it is target_sr.
    if file_sr != target_sr:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=file_sr, new_freq=target_sr
        )
    return waveform.numpy()


# Temporary variables
predictions_all = []
references_all = []

# Inference in the batched mode
for batch in make_batches(samples, batch_size):
    paths = [it["path"] for it in batch]
    references = [it["text"] for it in batch]

    # Extract audio: mono, resampled to the configured model rate
    audio_inputs = [_load_audio(path, sampling_rate) for path in paths]

    # Transcribe the audio. The rate passed here now truly describes
    # audio_inputs because _load_audio resampled them.
    inputs = processor(
        audio_inputs, sampling_rate=sampling_rate, padding=True
    ).input_values
    features = torch.tensor(np.array(inputs), dtype=torch_dtype).to(device)
    with torch.inference_mode():
        logits = asr_model(features).logits
    # Greedy CTC decoding: take the most likely token id at each frame
    predicted_ids = torch.argmax(logits, dim=-1)
    predictions = processor.batch_decode(predicted_ids)

    # Log outputs
    print("---")
    print("Predictions:")
    print(predictions)
    print("References:")
    print(references)
    print("---")

    # Add predictions and references
    predictions_all.extend(predictions)
    references_all.extend(references)
# Score all collected transcripts with word error rate (WER) and character
# error rate (CER), both rounded to four decimals.
wer, cer = (evaluate.load(metric_name) for metric_name in ("wer", "cer"))
scoring_kwargs = dict(predictions=predictions_all, references=references_all)
wer_value = round(wer.compute(**scoring_kwargs), 4)
cer_value = round(cer.compute(**scoring_kwargs), 4)

# Print results
print("Final:")
print(f"WER: {wer_value} | CER: {cer_value}")
|