import os
import subprocess
import sys
from dataclasses import dataclass, field
from glob import glob
from typing import List

from omegaconf import MISSING

from nemo.core.config import hydra_runner
from nemo.utils import logging
from scripts.asr_language_modeling.ngram_lm import kenlm_utils

"""
NeMo's beam search decoders only support char-level encodings. In order to make them work with BPE-level encodings,
we use a trick to encode the sub-word tokens of the training data as unicode characters and train a char-level KenLM.
"""
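# Illustration of the trick above (the offset is a hypothetical value, not necessarily
# the exact kenlm_utils internals): each sub-word token id is shifted by a fixed offset
# and written out as a single unicode character, e.g. token id 42 with an offset of 100
# would be emitted as chr(100 + 42), so KenLM sees one "character" per sub-word token.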
@dataclass
class TrainKenlmConfig:
    """
    Train an N-gram language model with KenLM to be used with the beam search decoders of ASR models.
    """

    train_paths: List[str] = MISSING  # List of training files or folders with text data

    nemo_model_file: str = MISSING  # Path to a '.nemo' model file, or the name of a pretrained model
    kenlm_model_file: str = MISSING  # Path where the binary KenLM model is written
    ngram_length: int = MISSING  # Order of the N-gram LM
    kenlm_bin_path: str = MISSING  # Path to the 'bin' folder of the KenLM installation

    preserve_arpa: bool = False  # Whether to keep the intermediate ARPA file after the binary model is built
    ngram_prune: List[int] = field(default_factory=lambda: [0])  # Count thresholds for pruning, one per N-gram order
    cache_path: str = ""  # Optional directory for caching the tokenized training files
    verbose: int = 1  # Verbosity level passed to kenlm_utils.iter_files
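# Example invocation via hydra overrides (script name and paths are illustrative):
#   python train_kenlm.py nemo_model_file=<path to .nemo file> \
#       train_paths=[<path to training corpus>] kenlm_bin_path=<kenlm>/build/bin \
#       kenlm_model_file=<output path> ngram_length=3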
@hydra_runner(config_path=None, config_name='TrainKenlmConfig', schema=TrainKenlmConfig)
def main(args: TrainKenlmConfig):
    train_paths = kenlm_utils.get_train_list(args.train_paths)

    if isinstance(args.ngram_prune, str):
        args.ngram_prune = [args.ngram_prune]

    tokenizer, encoding_level, is_aggregate_tokenizer = kenlm_utils.setup_tokenizer(args.nemo_model_file)

    arpa_file = f"{args.kenlm_model_file}.tmp.arpa"

    """ LMPLZ ARGUMENT SETUP """
    kenlm_args = [
        os.path.join(args.kenlm_bin_path, 'lmplz'),
        "-o",
        str(args.ngram_length),
        "--arpa",
        arpa_file,
    ]
    if encoding_level == "subword":
        # BPE-encoded corpora typically lack the count-of-count statistics that KenLM's
        # default discount estimation relies on, so fall back to fixed discounts.
        # Appending the flag only when needed also avoids passing an empty string
        # argument to lmplz.
        kenlm_args.append("--discount_fallback")
    kenlm_args += ["--prune"] + [str(n) for n in args.ngram_prune]
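    # For example, with ngram_length=3, a subword model and the default pruning, the
    # assembled command resembles (paths illustrative):
    #   <kenlm_bin_path>/lmplz -o 3 --arpa <kenlm_model_file>.tmp.arpa --discount_fallback --prune 0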
    if args.cache_path:
        if not os.path.exists(args.cache_path):
            os.makedirs(args.cache_path, exist_ok=True)

        """ DATASET SETUP """
        encoded_train_files = []
        for file_num, train_file in enumerate(train_paths):
            logging.info(f"Encoding the train file '{train_file}' number {file_num + 1} out of {len(train_paths)} ...")

            cached_files = glob(os.path.join(args.cache_path, os.path.split(train_file)[1]) + "*")
            encoded_train_file = os.path.join(args.cache_path, os.path.split(train_file)[1] + f"_{file_num}.tmp.txt")
            if cached_files and cached_files[0] != encoded_train_file:
                # Reuse a previously cached encoding under the name expected by this run.
                os.rename(cached_files[0], encoded_train_file)
                logging.info(f"Renamed '{cached_files[0]}' to '{encoded_train_file}'")

            encoded_train_files.append(encoded_train_file)

        # Tokenize and encode the training files into the cache directory.
        kenlm_utils.iter_files(
            source_path=train_paths,
            dest_path=encoded_train_files,
            tokenizer=tokenizer,
            encoding_level=encoding_level,
            is_aggregate_tokenizer=is_aggregate_tokenizer,
            verbose=args.verbose,
        )

        # Equivalent to `cat <encoded files> | lmplz ...`: stream the cached files
        # into lmplz through a pipe.
        first_process_args = ["cat"] + encoded_train_files
        first_process = subprocess.Popen(first_process_args, stdout=subprocess.PIPE, stderr=sys.stderr)

        logging.info(f"Running lmplz command \n\n{' '.join(kenlm_args)}\n\n")
        kenlm_p = subprocess.run(
            kenlm_args,
            stdin=first_process.stdout,
            text=True,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )
        first_process.wait()

    else:
        logging.info(f"Running lmplz command \n\n{' '.join(kenlm_args)}\n\n")
        kenlm_p = subprocess.Popen(kenlm_args, stdin=subprocess.PIPE, stdout=sys.stdout, stderr=sys.stderr)

        # Without a cache, encode the training text on the fly and write it straight
        # to lmplz's stdin.
        kenlm_utils.iter_files(
            source_path=train_paths,
            dest_path=kenlm_p.stdin,
            tokenizer=tokenizer,
            encoding_level=encoding_level,
            is_aggregate_tokenizer=is_aggregate_tokenizer,
            verbose=args.verbose,
        )

        kenlm_p.communicate()

    if kenlm_p.returncode != 0:
        raise RuntimeError("Training KenLM was not successful!")

    """ BINARY BUILD """

    kenlm_args = [
        os.path.join(args.kenlm_bin_path, "build_binary"),
        "trie",  # Build the trie-based binary format, which is smaller than the default probing format
        arpa_file,
        args.kenlm_model_file,
    ]
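    # The resulting command resembles (paths illustrative):
    #   <kenlm_bin_path>/build_binary trie <kenlm_model_file>.tmp.arpa <kenlm_model_file>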
    logging.info(f"Running build_binary command \n\n{' '.join(kenlm_args)}\n\n")
    ret = subprocess.run(kenlm_args, text=True, stdout=sys.stdout, stderr=sys.stderr)

    if ret.returncode != 0:
        raise RuntimeError("Creating the binary KenLM model was not successful!")

    if not args.preserve_arpa:
        os.remove(arpa_file)
        logging.info(f"Deleted the arpa file '{arpa_file}'.")
if __name__ == '__main__':
    main()