# hubert-uk / run_demo.py
# Batched CTC inference with a fine-tuned HuBERT model, scored with WER/CER.
# Committed by Yehor in 5f72cc6 ("Create run_demo.py"), 2.36 kB.
import csv
import torch
import torchaudio
import numpy as np
import evaluate
from transformers import HubertForCTC, Wav2Vec2Processor
# Number of clips transcribed per forward pass.
batch_size = 8
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Half precision is meant for the GPU path; on CPU it is slow and some ops may
# lack float16 kernels, so use float32 there.
torch_dtype = torch.float16 if device.startswith("cuda") else torch.float32
# Target sample rate (Hz) for audio handed to the processor.
sampling_rate = 16_000
# Local checkpoint directory holding both the CTC model and its processor.
model_name = "/home/yehor/ext-ml-disk/asr/hubert-training/models/final-85500"
# CSV listing the evaluation clips.
testset_file = "/home/yehor/ext-ml-disk/asr/w2v2-bert-training/eval/rows_no_defis.csv"
# Load the test dataset. Each row must provide the "path" and "text" columns
# that the inference loop reads.
# Explicit UTF-8 keeps the (non-ASCII) transcripts independent of the system
# locale; "-sig" strips a BOM that would otherwise corrupt the first header
# name. newline="" is what the csv module requires for correct quoting.
with open(testset_file, encoding="utf-8-sig", newline="") as f:
    samples = list(csv.DictReader(f))
# Load the processor and the CTC model from the same checkpoint directory so
# that the tokenizer vocabulary matches the model's output head.
processor = Wav2Vec2Processor.from_pretrained(model_name)
# The weights are placed on `device` in `torch_dtype` at load time.
# (attn_implementation="flash_attention_2" may be passed here if supported.)
asr_model = HubertForCTC.from_pretrained(
    model_name,
    torch_dtype=torch_dtype,
    device_map=device,
)
# A util function to make batches
def make_batches(iterable, n=1):
lx = len(iterable)
for ndx in range(0, lx, n):
yield iterable[ndx : min(ndx + n, lx)]
# Accumulators for the whole test set; they are scored once after the loop.
predictions_all = []
references_all = []


def _load_mono_audio(path, target_rate):
    """Load ``path`` as a 1-D float32 numpy array sampled at ``target_rate`` Hz.

    Multi-channel files are averaged down to mono, and files stored at a
    different rate are resampled. This way every clip really has the rate the
    processor is told it has.
    """
    waveform, file_rate = torchaudio.load(path, backend="sox")
    # torchaudio returns (channels, frames). Averaging over channels matches
    # squeeze(0) for mono files and yields a valid 1-D signal for stereo ones.
    waveform = waveform.mean(dim=0)
    if file_rate != target_rate:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=file_rate, new_freq=target_rate
        )
    return waveform.numpy()


# Inference in batched mode
for batch in make_batches(samples, batch_size):
    paths = [it["path"] for it in batch]
    references = [it["text"] for it in batch]

    # Decode every clip to mono at the configured rate. The file's own rate
    # stays local to the helper, so the global `sampling_rate` is never
    # overwritten.
    audio_inputs = [_load_mono_audio(path, sampling_rate) for path in paths]

    # Transcribe the audio
    inputs = processor(
        audio_inputs,
        sampling_rate=sampling_rate,
        padding=True,
        return_tensors="pt",
    )
    features = inputs.input_values.to(device=device, dtype=torch_dtype)
    # Forward a padding mask only if the feature extractor produced one; its
    # config decides whether the model should see a mask at all.
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)

    with torch.inference_mode():
        logits = asr_model(features, attention_mask=attention_mask).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    predictions = processor.batch_decode(predicted_ids)

    # Log outputs
    print("---")
    print("Predictions:")
    print(predictions)
    print("References:")
    print(references)
    print("---")

    # Add predictions and references
    predictions_all.extend(predictions)
    references_all.extend(references)
# Score the collected transcripts against the references. Word and character
# error rates are computed over the whole test set and rounded to 4 places.
wer, cer = evaluate.load("wer"), evaluate.load("cer")
metric_inputs = {"predictions": predictions_all, "references": references_all}
wer_value = round(wer.compute(**metric_inputs), 4)
cer_value = round(cer.compute(**metric_inputs), 4)

# Report the final scores (two lines: a header, then both metrics).
print("Final:", f"WER: {wer_value} | CER: {cer_value}", sep="\n")