File size: 1,252 Bytes
def3e88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from typing import List
import torch
import argparse
import shutil
import tempfile
from speechbrain.pretrained import EncoderDecoderASR


def asr_model_inference(model: EncoderDecoderASR, audios: List[str]) -> List[str]:
    """Transcribe every audio file in ``audios`` and return the texts.

    A scratch directory is created for the duration of the call and handed
    to ``process_audio`` as its ``savedir``. The directory is removed when
    the call finishes, including when transcription of a file raises.

    Args:
        model: ASR model forwarded unchanged to ``process_audio``.
        audios: Audio file locations, transcribed one at a time in order.

    Returns:
        One transcription per entry of ``audios``, in the same order. An
        empty ``audios`` yields an empty list.
    """
    # The context manager guarantees cleanup on both the success and the
    # error path; a bare mkdtemp()/rmtree() pair leaks the directory whenever
    # process_audio raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        return [process_audio(model, audio, tmp_dir) for audio in audios]

def process_audio(model: EncoderDecoderASR, audio: str, savedir:str) -> str:
    """Transcribe a single audio file and return the recognised words.

    Args:
        model: ASR model used both to load the audio and to transcribe it.
        audio: Location of the audio file, passed to ``model.load_audio``.
        savedir: Forwarded to ``model.load_audio`` as its ``savedir``
            argument.

    Returns:
        The first (and only) entry of the words returned by
        ``model.transcribe_batch`` for the one-item batch built here.
    """
    signal = model.load_audio(audio, savedir=savedir)

    # transcribe_batch works on batches, so wrap the lone waveform in a
    # leading dimension and pair it with a single length entry of 1.0.
    single_batch = signal.unsqueeze(0)
    lengths = torch.tensor([1.0])

    # The call returns a (words, tokens) pair; only the words are needed.
    words, _ = model.transcribe_batch(single_batch, lengths)
    return words[0]


if __name__ == "__main__":
    # Command line: one required option, -I, naming the audio file to
    # transcribe.
    cli = argparse.ArgumentParser()
    cli.add_argument("-I", dest="audio_file", required=True)
    options = cli.parse_args()

    # Build the model from the local "./inference" source, running on CPU.
    recognizer = EncoderDecoderASR.from_hparams(
        source="./inference",
        hparams_file="hyperparams.yaml",
        savedir="inference",
        run_opts={"device": "cpu"},
    )

    # asr_model_inference takes a list of files, so wrap the single path;
    # the resulting list of transcriptions is printed as-is.
    transcripts = asr_model_inference(recognizer, [options.audio_file])
    print(transcripts)