File size: 18,367 Bytes
2d8da09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This script interpolates two ARPA N-gram language models (LMs),
optionally calculates the perplexity of the resulting LM, and builds a binary KenLM model from it.

Minimal usage example to interpolate two N-gram language models with weights:
alpha * ngram_a + beta * ngram_b = 2 * ngram_a + 1 * ngram_b

python3 ngram_merge.py  --kenlm_bin_path /workspace/nemo/decoders/kenlm/build/bin \
                    --ngram_bin_path /workspace/nemo/decoders/ngram-1.3.14/src/bin \
                    --arpa_a /path/ngram_a.kenlm.tmp.arpa \
                    --alpha 2 \
                    --arpa_b /path/ngram_b.kenlm.tmp.arpa \
                    --beta 1 \
                    --out_path /path/out


Merge two N-gram language models and calculate the perplexity of the merged model on test_file
(test_file may also be a comma-separated list of files).
python3 ngram_merge.py  --kenlm_bin_path /workspace/nemo/decoders/kenlm/build/bin \
                    --ngram_bin_path /workspace/nemo/decoders/ngram-1.3.14/src/bin \
                    --arpa_a /path/ngram_a.kenlm.tmp.arpa \
                    --alpha 0.5 \
                    --arpa_b /path/ngram_b.kenlm.tmp.arpa \
                    --beta 0.5 \
                    --out_path /path/out \
                    --nemo_model_file /path/to/model_tokenizer.nemo \
                    --test_file /path/to/test_manifest.json \
                    --force
"""

import argparse
import os
import subprocess
import sys
from typing import Tuple

import kenlm_utils
import torch

import nemo.collections.asr as nemo_asr
from nemo.collections.asr.modules.rnnt import RNNTDecoder
from nemo.collections.asr.parts.submodules.ctc_beam_decoding import DEFAULT_TOKEN_OFFSET
from nemo.utils import logging


class NgramMerge:
    """
    Wrapper around the OpenGrm NGram command-line tools used to interpolate two ARPA language models.

    Every method runs an executable located in ``ngram_bin_path`` (``ngramread``, ``ngrammerge``,
    ``ngramprint`` or ``ngramperplexity``) through :mod:`subprocess`.
    """

    def __init__(self, ngram_bin_path):
        # Directory holding the OpenGrm NGram executables.
        self.ngram_bin_path = ngram_bin_path

    def ngrammerge(self, arpa_a: str, alpha: float, arpa_b: str, beta: float, arpa_c: str, force: bool) -> str:
        """
        Interpolate two models with ``ngrammerge`` and write the result in binary ``.mod`` format.

        The inputs are ``<arpa_a>.mod`` and ``<arpa_b>.mod`` (as produced by :meth:`arpa2mod`); the result is
        written to ``<arpa_c>.mod``. ``arpa_c`` itself is only used to derive the output name; the ARPA file
        is produced later by :meth:`make_arpa`.

        Args:
            arpa_a (str): Path to the first ARPA file (its ``.mod`` sibling is read).
            alpha (float): Interpolation weight for the first model.
            arpa_b (str): Path to the second ARPA file (its ``.mod`` sibling is read).
            beta (float): Interpolation weight for the second model.
            arpa_c (str): Path of the merged ARPA file; the merged model is written to ``arpa_c + ".mod"``.
            force (bool): Re-run the merge even if ``arpa_c + ".mod"`` already exists.

        Returns:
            str: Path to the merged ``.mod`` model (``arpa_c + ".mod"``).
        """
        mod_a = arpa_a + ".mod"
        mod_b = arpa_b + ".mod"
        mod_c = arpa_c + ".mod"
        if os.path.isfile(mod_c) and not force:
            logging.info("File " + mod_c + " exists. Skipping.")
        else:
            sh_args = [
                os.path.join(self.ngram_bin_path, "ngrammerge"),
                "--alpha=" + str(alpha),
                "--beta=" + str(beta),
                "--normalize",
                # "--use_smoothing",
                mod_a,
                mod_b,
                mod_c,
            ]
            # The tool's own output goes straight to this process' stdout/stderr; only the
            # CompletedProcess summary is logged.
            logging.info(
                "\n"
                + str(subprocess.run(sh_args, capture_output=False, text=True, stdout=sys.stdout, stderr=sys.stderr,))
                + "\n",
            )
        return mod_c

    def arpa2mod(self, arpa_path: str, force: bool):
        """
        Convert an ARPA model to the binary ``.mod`` format with ``ngramread --ARPA``.

        The output is written next to the input as ``arpa_path + ".mod"``.

        Args:
            arpa_path (str): Path to the ARPA n-gram model.
            force (bool): Convert even if ``arpa_path + ".mod"`` already exists.

        Returns:
            str | subprocess.CompletedProcess: A "File ... exists. Skipping." message when the ``.mod`` file
            exists and ``force`` is False; otherwise the ``CompletedProcess`` of the ``ngramread`` call (its
            output streams are forwarded to this process' stdout/stderr).
        """
        mod_path = arpa_path + ".mod"
        if os.path.isfile(mod_path) and not force:
            return "File " + mod_path + " exists. Skipping."
        else:
            sh_args = [
                os.path.join(self.ngram_bin_path, "ngramread"),
                "--ARPA",
                arpa_path,
                mod_path,
            ]
            return subprocess.run(sh_args, capture_output=False, text=True, stdout=sys.stdout, stderr=sys.stderr,)

    def merge(
        self, arpa_a: str, alpha: float, arpa_b: str, beta: float, out_path: str, force: bool
    ) -> Tuple[str, str]:
        """
        Convert both ARPA models to ``.mod`` and interpolate them with ``ngrammerge``.

        The merged file name is ``<basename(arpa_a)>-<alpha>-<basename(arpa_b)>-<beta>.arpa`` inside ``out_path``.

        Args:
            arpa_a (str): Path to the first ARPA language model file.
            alpha (float): Interpolation weight for the first model.
            arpa_b (str): Path to the second ARPA language model file.
            beta (float): Interpolation weight for the second model.
            out_path (str): Directory where the merged model is written.
            force (bool): Whether to overwrite existing intermediate/output files.

        Returns:
            Tuple[str, str]: The path to the merged ``.mod`` model and the path of the (not yet created)
            merged ARPA file.
        """
        logging.info("\n" + str(self.arpa2mod(arpa_a, force)) + "\n")

        logging.info("\n" + str(self.arpa2mod(arpa_b, force)) + "\n")
        arpa_c = os.path.join(out_path, f"{os.path.split(arpa_a)[1]}-{alpha}-{os.path.split(arpa_b)[1]}-{beta}.arpa",)
        mod_c = self.ngrammerge(arpa_a, alpha, arpa_b, beta, arpa_c, force)
        return mod_c, arpa_c

    def perplexity(self, ngram_mod: str, test_far: str) -> str:
        """
        Run ``ngramperplexity`` for a model on a compiled FAR test set.

        Args:
            ngram_mod (str): Path to the ``.mod`` n-gram model.
            test_far (str): Path to the FAR file with the test data.

        Returns:
            str: The last five lines of the tool's stdout.

        Raises:
            AssertionError: If ``ngramperplexity`` exits with a non-zero code. Raised explicitly (not via
            ``assert``) so the check also runs under ``python -O``.

        Example:
            >>> NgramMerge("/path/to/ngram/bin").perplexity("/path/to/ngram_model", "/path/to/test_file")
            'Perplexity: 123.45'
        """
        sh_args = [
            os.path.join(self.ngram_bin_path, "ngramperplexity"),
            "--v=1",
            ngram_mod,
            test_far,
        ]
        proc = subprocess.run(sh_args, text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout, stderr = proc.stdout, proc.stderr
        if proc.returncode != 0:
            command = " ".join(sh_args)
            raise AssertionError(
                f"Exit_code must be 0.\n bash command: {command} \n stdout: {stdout} \n stderr: {stderr}"
            )
        # stdout ends with a newline, so the final split element is ""; keep the five lines before it
        # (presumably the summary that ngramperplexity prints last).
        perplexity_out = "\n".join(stdout.split("\n")[-6:-1])
        return perplexity_out

    def make_arpa(self, ngram_mod: str, ngram_arpa: str, force: bool):
        """
        Convert a binary ``.mod`` n-gram model to ARPA format with ``ngramprint --ARPA``.

        Args:
            ngram_mod (str): Path to the binary ``.mod`` model.
            ngram_arpa (str): Path of the ARPA file to write.
            force (bool): Regenerate the ARPA file even if it already exists.

        Returns:
            None if ``ngram_arpa`` exists and ``force`` is False; otherwise the ``subprocess.CompletedProcess``
            of the ``ngramprint`` call. The exit code is not checked here.

        Raises:
            FileNotFoundError: If the ``ngramprint`` executable cannot be found.
        """
        if os.path.isfile(ngram_arpa) and not force:
            logging.info("File " + ngram_arpa + " exists. Skipping.")
            return None
        else:
            sh_args = [
                os.path.join(self.ngram_bin_path, "ngramprint"),
                "--ARPA",
                ngram_mod,
                ngram_arpa,
            ]
            return subprocess.run(sh_args, capture_output=False, text=True, stdout=sys.stdout, stderr=sys.stderr,)

    def test_perplexity(self, mod_c: str, symbols: str, test_txt: str, nemo_model_file: str, tmp_path: str) -> str:
        """
        Compile a text test set into a FAR file and compute the perplexity of a model on it.

        Args:
            mod_c (str): Path to the ``.mod`` n-gram model.
            symbols (str): Path to the symbol table file.
            test_txt (str): Path to the test text file.
            nemo_model_file (str): Path to the NeMo model whose tokenizer encodes the test text.
            tmp_path (str): Directory where the test FAR file is created.

        Returns:
            str: The perplexity summary returned by :meth:`perplexity`.

        Example:
            >>> nm.test_perplexity("/path/to/ngram_model", "/path/to/symbol_table", "/path/to/test_file", "/path/to/model.nemo", "/path/to/tmp_dir")
            'Perplexity: 123.45'
        """

        test_far = farcompile(symbols, test_txt, tmp_path, nemo_model_file)
        res_p = self.perplexity(mod_c, test_far)
        return res_p


def farcompile(symbols: str, text_file: str, tmp_path: str, nemo_model_file: str) -> str:
    """
    Compile a text file into a FAR archive with ``farcompilestrings``.

    The text is encoded with the tokenizer returned by ``kenlm_utils.setup_tokenizer(nemo_model_file)`` and
    streamed into the stdin of ``farcompilestrings``; the tool's stdout is written to
    ``<tmp_path>/<basename(text_file)>.far``.

    Args:
        symbols (str): Path to the symbol table file (passed as ``--symbols``).
        text_file (str): Path to the text file to compile.
        tmp_path (str): Directory where the FAR file is created.
        nemo_model_file (str): Path to the NeMo model file (.nemo) whose tokenizer is used.

    Returns:
        str: Path to the resulting FAR file.

    Raises:
        AssertionError: If ``farcompilestrings`` exits with a non-zero code.

    Example:
        >>> farcompile("/path/to/symbol_table", "/path/to/text_file", "/path/to/tmp_dir", "/path/to/model.nemo")
    """
    test_far = os.path.join(tmp_path, os.path.split(text_file)[1] + ".far")

    sh_args = [
        "farcompilestrings",
        "--generate_keys=10",
        "--fst_type=compact",
        "--symbols=" + symbols,
        "--keep_symbols",
    ]

    tokenizer, encoding_level, is_aggregate_tokenizer = kenlm_utils.setup_tokenizer(nemo_model_file)

    # Redirect stdout to the FAR file ourselves instead of building a shell command with "> file":
    # paths containing spaces or shell metacharacters are then passed to the tool unchanged.
    with open(test_far, "wb") as far_out:
        ps = subprocess.Popen(sh_args, stdin=subprocess.PIPE, stdout=far_out, stderr=sys.stderr)

        # Tokenize text_file and write the encoded lines into the tool's (binary) stdin.
        kenlm_utils.iter_files(
            source_path=[text_file],
            dest_path=ps.stdin,
            tokenizer=tokenizer,
            encoding_level=encoding_level,
            is_aggregate_tokenizer=is_aggregate_tokenizer,
            verbose=1,
        )
        # Closes stdin and waits; stdout/stderr are not piped, so both values are None.
        stdout, stderr = ps.communicate()

    exit_code = ps.returncode
    if exit_code != 0:
        command = " ".join(sh_args + [">", test_far])
        raise AssertionError(
            f"Exit_code must be 0.\n bash command: {command} \n stdout: {stdout} \n stderr: {stderr}"
        )
    return test_far


def make_kenlm(kenlm_bin_path: str, ngram_arpa: str, force: bool):
    """
    Build a KenLM binary (trie) model from an ARPA file with KenLM's ``build_binary``.

    The model is written next to the ARPA file as ``ngram_arpa + ".kenlm"``.

    Args:
    - kenlm_bin_path (str): Directory containing the KenLM ``build_binary`` executable.
    - ngram_arpa (str): Path to the input ARPA file.
    - force (bool): Rebuild the KenLM model even if it already exists.

    Returns:
    - None when the model already exists and ``force`` is False; otherwise the
      ``subprocess.CompletedProcess`` of the ``build_binary`` call (its exit code is not checked).

    Raises:
    - FileNotFoundError: If the ``build_binary`` executable cannot be found.
    """
    kenlm_model = ngram_arpa + ".kenlm"
    if os.path.isfile(kenlm_model) and not force:
        logging.info("File " + kenlm_model + " exists. Skipping.")
        return None

    build_binary = os.path.join(kenlm_bin_path, "build_binary")
    command = [build_binary, "trie", "-i", ngram_arpa, kenlm_model]
    # The tool writes directly to this process' stdout/stderr.
    return subprocess.run(command, capture_output=False, text=True, stdout=sys.stdout, stderr=sys.stderr)


def make_symbol_list(nemo_model_file, symbols, force):
    """
    Write a symbol table for the vocabulary of an ASR model.

    Each vocabulary index ``i`` is encoded as the character ``chr(i + DEFAULT_TOKEN_OFFSET)`` and written as
    one ``"<symbol> <i>"`` line. The table is what ``farcompile`` receives through ``--symbols``.

    Args:
        nemo_model_file (str): Path to a ``.nemo`` file, or the name of a pretrained NeMo model.
        symbols (str): Path of the symbol table file to write.
        force (bool): Rewrite the symbol table even if the file already exists.

    Returns:
        None
    """
    if os.path.isfile(symbols) and not force:
        logging.info("File " + symbols + " exists. Skipping.")
        return

    cpu = torch.device('cpu')
    if nemo_model_file.endswith('.nemo'):
        asr_model = nemo_asr.models.ASRModel.restore_from(nemo_model_file, map_location=cpu)
    else:
        logging.warning(
            "nemo_model_file does not end with .nemo, therefore trying to load a pretrained model with this name."
        )
        asr_model = nemo_asr.models.ASRModel.from_pretrained(nemo_model_file, map_location=cpu)

    decoder = asr_model.decoder
    # RNNT decoders: the vocabulary size is taken from blank_idx; other decoders expose the list directly.
    vocab_size = decoder.blank_idx if isinstance(decoder, RNNTDecoder) else len(decoder.vocabulary)

    table_lines = [chr(token_id + DEFAULT_TOKEN_OFFSET) + " " + str(token_id) + "\n" for token_id in range(vocab_size)]
    with open(symbols, "w", encoding="utf-8") as sym_file:
        sym_file.writelines(table_lines)


def main(
    kenlm_bin_path: str,
    ngram_bin_path: str,
    arpa_a: str,
    alpha: float,
    arpa_b: str,
    beta: float,
    out_path: str,
    test_file: str,
    symbols: str,
    nemo_model_file: str,
    force: bool,
) -> None:
    """
    Entry point: merge two ARPA language models, optionally report perplexity on test files, and write the
    merged model as ARPA and KenLM binary.

    Args:
    - kenlm_bin_path (str): Directory containing the KenLM executables.
    - ngram_bin_path (str): Directory containing the OpenGrm NGram executables.
    - arpa_a (str): Path to the first ARPA format language model.
    - alpha (float): Weight given to the first language model during merging.
    - arpa_b (str): Path to the second ARPA format language model.
    - beta (float): Weight given to the second language model during merging.
    - out_path (str): Directory where the output files are saved (created if missing).
    - test_file (str): Comma-separated list of files on which perplexity is calculated (optional).
    - symbols (str): Path to the symbol table; generated into ``out_path`` from ``nemo_model_file`` if not given.
    - nemo_model_file (str): Path to the NeMo model file (or pretrained model name); needed for perplexity.
    - force (bool): If True, overwrite existing files, otherwise skip the operations.

    Returns:
    - None
    """
    # The merged .mod/ARPA/KenLM files (and generated symbols / FAR files) are written into out_path;
    # create it so the external tools do not fail on a missing directory.
    if out_path:
        os.makedirs(out_path, exist_ok=True)

    nm = NgramMerge(ngram_bin_path)
    mod_c, arpa_c = nm.merge(arpa_a, alpha, arpa_b, beta, out_path, force)

    if test_file and nemo_model_file:
        if not symbols:
            symbols = os.path.join(out_path, os.path.split(nemo_model_file)[1] + ".syms")
            make_symbol_list(nemo_model_file, symbols, force)
        for test_f in test_file.split(","):
            test_p = nm.test_perplexity(mod_c, symbols, test_f, nemo_model_file, out_path)
            logging.info("Perplexity summary " + test_f + " : " + test_p)
    elif test_file:
        # The test text must be tokenized with the model's tokenizer, so perplexity cannot be computed.
        logging.warning("test_file is set but nemo_model_file is not; skipping perplexity calculation.")

    logging.info("Making ARPA and Kenlm model " + arpa_c)
    out = nm.make_arpa(mod_c, arpa_c, force)
    if out:
        logging.info("\n" + str(out) + "\n")

    out = make_kenlm(kenlm_bin_path, arpa_c, force)
    if out:
        logging.info("\n" + str(out) + "\n")


def _parse_args():
    parser = argparse.ArgumentParser(
        description="Interpolate ARPA N-gram language models and make KenLM binary model to be used with beam search decoder of ASR models."
    )
    parser.add_argument(
        "--kenlm_bin_path", required=True, type=str, help="The path to the bin folder of KenLM library.",
    )  # Use /workspace/nemo/decoders/kenlm/build/bin if installed it with scripts/asr_language_modeling/ngram_lm/install_beamsearch_decoders.sh
    parser.add_argument(
        "--ngram_bin_path", required=True, type=str, help="The path to the bin folder of OpenGrm Ngram library.",
    )  # Use /workspace/nemo/decoders/ngram-1.3.14/src/bin if installed it with scripts/installers/install_opengrm.sh
    parser.add_argument("--arpa_a", required=True, type=str, help="Path to the arpa_a")
    parser.add_argument("--alpha", required=True, type=float, help="Weight of arpa_a")
    parser.add_argument("--arpa_b", required=True, type=str, help="Path to the arpa_b")
    parser.add_argument("--beta", required=True, type=float, help="Weight of arpa_b")
    parser.add_argument(
        "--out_path", required=True, type=str, help="Path to write tmp and resulted files.",
    )
    parser.add_argument(
        "--test_file",
        required=False,
        type=str,
        default=None,
        help="Path to test file to count perplexity if provided.",
    )
    parser.add_argument(
        "--symbols",
        required=False,
        type=str,
        default=None,
        help="Path to symbols (.syms) file . Could be calculated if it is not provided. Use as: --symbols /path/to/earnest.syms",
    )
    parser.add_argument(
        "--nemo_model_file",
        required=False,
        type=str,
        default=None,
        help="The path to '.nemo' file of the ASR model, or name of a pretrained NeMo model",
    )
    parser.add_argument("--force", "-f", action="store_true", help="Whether to recompile and rewrite all files")
    return parser.parse_args()


if __name__ == "__main__":
    # Parse the command line and forward every option to main() as a keyword argument.
    cli_options = vars(_parse_args())
    main(**cli_options)