File size: 1,252 Bytes
def3e88 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
from typing import List
import torch
import argparse
import shutil
import tempfile
from speechbrain.pretrained import EncoderDecoderASR
def asr_model_inference(model: EncoderDecoderASR, audios: List[str]) -> List[str]:
    """Transcribe every audio file in *audios* with *model*.

    All files share one temporary directory, which is handed to
    ``model.load_audio`` as its ``savedir``. The directory is removed when
    transcription finishes, including when an error is raised part-way.

    Args:
        model: A loaded SpeechBrain ``EncoderDecoderASR`` model.
        audios: Audio files to transcribe (each is passed to
            ``model.load_audio``).

    Returns:
        The predicted transcription for each input, in input order.
    """
    # TemporaryDirectory guarantees cleanup even if process_audio raises;
    # the previous mkdtemp()/rmtree() pair leaked the directory on error.
    with tempfile.TemporaryDirectory() as tmp_dir:
        return [process_audio(model, audio, tmp_dir) for audio in audios]
def process_audio(model: EncoderDecoderASR, audio: str, savedir: str) -> str:
    """Return the transcription of a single audio file.

    Args:
        model: A loaded SpeechBrain ``EncoderDecoderASR`` model.
        audio: Audio file to transcribe, passed to ``model.load_audio``.
        savedir: Directory forwarded to ``model.load_audio`` as ``savedir``.

    Returns:
        The predicted words for the file (entry 0 of the batch output).
    """
    signal = model.load_audio(audio, savedir=savedir)
    # transcribe_batch operates on batches, so wrap the single waveform as a
    # batch of one whose relative length (1.0) covers the whole signal.
    words, _tokens = model.transcribe_batch(
        signal.unsqueeze(0), torch.tensor([1.0])
    )
    return words[0]
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Transcribe an audio file with a local SpeechBrain EncoderDecoderASR model."
    )
    parser.add_argument(
        "-I",
        dest="audio_file",
        required=True,
        help="audio file to transcribe",
    )
    # Previously hard-coded to "cpu"; kept as the default for compatibility.
    parser.add_argument(
        "--device",
        default="cpu",
        help='device passed to SpeechBrain via run_opts, e.g. "cpu" or "cuda:0" (default: %(default)s)',
    )
    args = parser.parse_args()
    # Load the model from the local ./inference directory, which is also
    # used as the savedir.
    asr_model = EncoderDecoderASR.from_hparams(
        source="./inference",
        hparams_file="hyperparams.yaml",
        savedir="inference",
        run_opts={"device": args.device},
    )
    print(asr_model_inference(asr_model, [args.audio_file]))