GreenRaptor commited on
Commit
468ac2e
·
1 Parent(s): 9e43625

Create mms_infer.py

Browse files
Files changed (1) hide show
  1. mms_infer.py +52 -0
mms_infer.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python -u
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import soundfile as sf
9
+ import tempfile
10
+ from pathlib import Path
11
+ import os
12
+ import subprocess
13
+ import sys
14
+ import re
15
+
16
+ def parser():
17
+ parser = argparse.ArgumentParser(description="ASR inference script for MMS model")
18
+ parser.add_argument("--model", type=str, help="path to ASR model", required=True)
19
+ parser.add_argument("--audio", type=str, help="path to audio file", required=True, nargs='+')
20
+ parser.add_argument("--lang", type=str, help="audio language", required=True)
21
+ parser.add_argument("--format", type=str, choices=["none", "letter"], default="letter")
22
+ return parser.parse_args()
23
+
24
+ def process(args):
25
+ with tempfile.TemporaryDirectory() as tmpdir:
26
+ print(">>> preparing tmp manifest dir ...", file=sys.stderr)
27
+ tmpdir = Path(tmpdir)
28
+ with open(tmpdir / "dev.tsv", "w") as fw:
29
+ fw.write("/\n")
30
+ for audio in args.audio:
31
+ nsample = sf.SoundFile(audio).frames
32
+ fw.write(f"{audio}\t{nsample}\n")
33
+ with open(tmpdir / "dev.uid", "w") as fw:
34
+ fw.write(f"{audio}\n"*len(args.audio))
35
+ with open(tmpdir / "dev.ltr", "w") as fw:
36
+ fw.write("d u m m y | d u m m y\n"*len(args.audio))
37
+ with open(tmpdir / "dev.wrd", "w") as fw:
38
+ fw.write("dummy dummy\n"*len(args.audio))
39
+ cmd = f"""
40
+ PYTHONPATH=. PREFIX=INFER HYDRA_FULL_ERROR=1 python examples/speech_recognition/new/infer.py -m --config-dir examples/mms/asr/config/ --config-name infer_common decoding.type=viterbi dataset.max_tokens=4000000 distributed_training.distributed_world_size=1 "common_eval.path='{args.model}'" task.data={tmpdir} dataset.gen_subset="{args.lang}:dev" common_eval.post_process={args.format} decoding.results_path={tmpdir}
41
+ """
42
+ print(">>> loading model & running inference ...", file=sys.stderr)
43
+ subprocess.run(cmd, shell=True, stdout=subprocess.DEVNULL,)
44
+ with open(tmpdir/"hypo.word") as fr:
45
+ for ii, hypo in enumerate(fr):
46
+ hypo = re.sub("\(\S+\)$", "", hypo).strip()
47
+ print(f'===============\nInput: {args.audio[ii]}\nOutput: {hypo}')
48
+
49
+
50
+ if __name__ == "__main__":
51
+ args = parser()
52
+ process(args)