File size: 944 Bytes
7e4b981
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Union

from fairseq.data.dictionary import Dictionary

from .decoder_config import DecoderConfig, FlashlightDecoderConfig
from .base_decoder import BaseDecoder


def Decoder(
    cfg: Union[DecoderConfig, FlashlightDecoderConfig], tgt_dict: Dictionary
) -> BaseDecoder:

    if cfg.type == "viterbi":
        from .viterbi_decoder import ViterbiDecoder

        return ViterbiDecoder(tgt_dict)
    if cfg.type == "kenlm":
        from .flashlight_decoder import KenLMDecoder

        return KenLMDecoder(cfg, tgt_dict)
    if cfg.type == "fairseqlm":
        from .flashlight_decoder import FairseqLMDecoder

        return FairseqLMDecoder(cfg, tgt_dict)
    raise NotImplementedError(f"Invalid decoder name: {cfg.name}")