Spaces:
Runtime error
Runtime error
GreenRaptor
commited on
Commit
·
468ac2e
1
Parent(s):
9e43625
Create mms_infer.py
Browse files- mms_infer.py +52 -0
mms_infer.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python -u
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the MIT license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
import soundfile as sf
|
9 |
+
import tempfile
|
10 |
+
from pathlib import Path
|
11 |
+
import os
|
12 |
+
import subprocess
|
13 |
+
import sys
|
14 |
+
import re
|
15 |
+
|
16 |
+
def parser():
|
17 |
+
parser = argparse.ArgumentParser(description="ASR inference script for MMS model")
|
18 |
+
parser.add_argument("--model", type=str, help="path to ASR model", required=True)
|
19 |
+
parser.add_argument("--audio", type=str, help="path to audio file", required=True, nargs='+')
|
20 |
+
parser.add_argument("--lang", type=str, help="audio language", required=True)
|
21 |
+
parser.add_argument("--format", type=str, choices=["none", "letter"], default="letter")
|
22 |
+
return parser.parse_args()
|
23 |
+
|
24 |
+
def process(args):
|
25 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
26 |
+
print(">>> preparing tmp manifest dir ...", file=sys.stderr)
|
27 |
+
tmpdir = Path(tmpdir)
|
28 |
+
with open(tmpdir / "dev.tsv", "w") as fw:
|
29 |
+
fw.write("/\n")
|
30 |
+
for audio in args.audio:
|
31 |
+
nsample = sf.SoundFile(audio).frames
|
32 |
+
fw.write(f"{audio}\t{nsample}\n")
|
33 |
+
with open(tmpdir / "dev.uid", "w") as fw:
|
34 |
+
fw.write(f"{audio}\n"*len(args.audio))
|
35 |
+
with open(tmpdir / "dev.ltr", "w") as fw:
|
36 |
+
fw.write("d u m m y | d u m m y\n"*len(args.audio))
|
37 |
+
with open(tmpdir / "dev.wrd", "w") as fw:
|
38 |
+
fw.write("dummy dummy\n"*len(args.audio))
|
39 |
+
cmd = f"""
|
40 |
+
PYTHONPATH=. PREFIX=INFER HYDRA_FULL_ERROR=1 python examples/speech_recognition/new/infer.py -m --config-dir examples/mms/asr/config/ --config-name infer_common decoding.type=viterbi dataset.max_tokens=4000000 distributed_training.distributed_world_size=1 "common_eval.path='{args.model}'" task.data={tmpdir} dataset.gen_subset="{args.lang}:dev" common_eval.post_process={args.format} decoding.results_path={tmpdir}
|
41 |
+
"""
|
42 |
+
print(">>> loading model & running inference ...", file=sys.stderr)
|
43 |
+
subprocess.run(cmd, shell=True, stdout=subprocess.DEVNULL,)
|
44 |
+
with open(tmpdir/"hypo.word") as fr:
|
45 |
+
for ii, hypo in enumerate(fr):
|
46 |
+
hypo = re.sub("\(\S+\)$", "", hypo).strip()
|
47 |
+
print(f'===============\nInput: {args.audio[ii]}\nOutput: {hypo}')
|
48 |
+
|
49 |
+
|
50 |
+
if __name__ == "__main__":
|
51 |
+
args = parser()
|
52 |
+
process(args)
|