# This module is from [WeNet](https://github.com/wenet-e2e/wenet).

# ## Citations

# ```bibtex
# @inproceedings{yao2021wenet,
#   title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
#   author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
#   booktitle={Proc. Interspeech},
#   year={2021},
#   address={Brno, Czech Republic },
#   organization={IEEE}
# }

# @article{zhang2022wenet,
#   title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
#   author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
#   journal={arXiv preprint arXiv:2203.15455},
#   year={2022}
# }
# ```

import torch
from modules.wenet_extractor.transducer.joint import TransducerJoint
from modules.wenet_extractor.transducer.predictor import (
    ConvPredictor,
    EmbeddingPredictor,
    RNNPredictor,
)
from modules.wenet_extractor.transducer.transducer import Transducer
from modules.wenet_extractor.transformer.asr_model import ASRModel
from modules.wenet_extractor.transformer.cmvn import GlobalCMVN
from modules.wenet_extractor.transformer.ctc import CTC
from modules.wenet_extractor.transformer.decoder import (
    BiTransformerDecoder,
    TransformerDecoder,
)
from modules.wenet_extractor.transformer.encoder import (
    ConformerEncoder,
    TransformerEncoder,
)
from modules.wenet_extractor.squeezeformer.encoder import SqueezeformerEncoder
from modules.wenet_extractor.efficient_conformer.encoder import (
    EfficientConformerEncoder,
)
from modules.wenet_extractor.paraformer.paraformer import Paraformer
from modules.wenet_extractor.cif.predictor import Predictor
from modules.wenet_extractor.utils.cmvn import load_cmvn


def init_model(configs):
    """Build a speech recognition model from a configuration mapping.

    An encoder, an attention decoder and a CTC module are always created.
    They are then wrapped in one of three model classes, selected by which
    top-level key is present in ``configs``:

    * ``"predictor"``  -> ``Transducer`` (adds a predictor and a joint network)
    * ``"paraformer"`` -> ``Paraformer`` (adds a CIF ``Predictor``)
    * neither          -> ``ASRModel``

    ``configs`` and its nested dictionaries are only read, never modified, so
    the same mapping can be passed to this function more than once.

    Args:
        configs (dict): Model configuration. Keys read here:

            * ``cmvn_file`` (optional) and ``is_json_cmvn``: global CMVN
              statistics. CMVN is skipped when ``cmvn_file`` is missing or
              ``None``; ``is_json_cmvn`` is only read when it is set.
            * ``input_dim`` / ``output_dim``: passed as the encoder input
              size and as the vocabulary size.
            * ``encoder``: ``"conformer"`` (default), ``"squeezeformer"`` or
              ``"efficientConformer"``; any other value selects
              ``TransformerEncoder``.
            * ``decoder``: ``"transformer"``; any other value (the default is
              ``"bitransformer"``) selects ``BiTransformerDecoder``.
            * ``encoder_conf``, ``decoder_conf``, ``model_conf``: keyword
              arguments forwarded to the respective constructors.
            * ``predictor``, ``predictor_conf``, ``joint_conf``: transducer
              only.
            * ``cif_predictor_conf``: paraformer only.
            * ``lfmmi_dir`` (optional, default ``""``): forwarded to
              ``ASRModel``.

    Returns:
        The constructed ``Transducer``, ``Paraformer`` or ``ASRModel``.

    Raises:
        KeyError: If a required configuration key is missing.
        AssertionError: If the bidirectional decoder is selected and
            ``model_conf["reverse_weight"]`` is not strictly between 0 and 1
            or ``decoder_conf["r_num_blocks"]`` is not positive.
        NotImplementedError: If ``configs["predictor"]`` is not ``"rnn"``,
            ``"embedding"`` or ``"conv"``.
    """
    # A missing "cmvn_file" key is treated like an explicit None.
    cmvn_file = configs.get("cmvn_file")
    if cmvn_file is not None:
        mean, istd = load_cmvn(cmvn_file, configs["is_json_cmvn"])
        global_cmvn = GlobalCMVN(
            torch.from_numpy(mean).float(), torch.from_numpy(istd).float()
        )
    else:
        global_cmvn = None

    input_dim = configs["input_dim"]
    vocab_size = configs["output_dim"]

    encoder_type = configs.get("encoder", "conformer")
    decoder_type = configs.get("decoder", "bitransformer")

    if encoder_type == "conformer":
        encoder = ConformerEncoder(
            input_dim, global_cmvn=global_cmvn, **configs["encoder_conf"]
        )
    elif encoder_type == "squeezeformer":
        encoder = SqueezeformerEncoder(
            input_dim, global_cmvn=global_cmvn, **configs["encoder_conf"]
        )
    elif encoder_type == "efficientConformer":
        # The nested "efficient_conf" entries are passed as additional keyword
        # arguments next to the regular encoder settings. They are unpacked
        # separately (not merged) so a duplicated key is still reported.
        encoder = EfficientConformerEncoder(
            input_dim,
            global_cmvn=global_cmvn,
            **configs["encoder_conf"],
            **configs["encoder_conf"].get("efficient_conf", {}),
        )
    else:
        # Any unrecognised encoder name falls back to the plain transformer.
        encoder = TransformerEncoder(
            input_dim, global_cmvn=global_cmvn, **configs["encoder_conf"]
        )

    if decoder_type == "transformer":
        decoder = TransformerDecoder(
            vocab_size, encoder.output_size(), **configs["decoder_conf"]
        )
    else:
        # Any other decoder name selects the bidirectional decoder, which
        # needs a reverse weight in (0, 1) and at least one reverse block.
        assert (
            0.0 < configs["model_conf"]["reverse_weight"] < 1.0
        ), "reverse_weight must be strictly between 0 and 1"
        assert (
            configs["decoder_conf"]["r_num_blocks"] > 0
        ), "r_num_blocks must be positive"
        decoder = BiTransformerDecoder(
            vocab_size, encoder.output_size(), **configs["decoder_conf"]
        )
    ctc = CTC(vocab_size, encoder.output_size())

    # Init joint CTC/Attention or Transducer model
    if "predictor" in configs:
        predictor_type = configs["predictor"]
        if predictor_type not in ("rnn", "embedding", "conv"):
            raise NotImplementedError("only rnn, embedding and conv type support now")

        predictor_conf = configs["predictor_conf"]
        if predictor_type == "rnn":
            predictor = RNNPredictor(vocab_size, **predictor_conf)
            pred_output_size = predictor_conf["output_size"]
        elif predictor_type == "embedding":
            predictor = EmbeddingPredictor(vocab_size, **predictor_conf)
            # For this predictor the joint network is sized by "embed_size".
            pred_output_size = predictor_conf["embed_size"]
        else:  # "conv"
            predictor = ConvPredictor(vocab_size, **predictor_conf)
            pred_output_size = predictor_conf["embed_size"]

        # Build the joint-network arguments on a copy: the sizes are derived
        # values, and writing them back into the caller's config would make a
        # second call forward "output_size" to the embedding/conv predictor
        # constructors. The encoder size comes from the encoder itself, just
        # like for the decoder and CTC above, so it also works when
        # "encoder_conf" relies on the encoder's default output size.
        joint_conf = dict(
            configs["joint_conf"],
            enc_output_size=encoder.output_size(),
            pred_output_size=pred_output_size,
        )
        joint = TransducerJoint(vocab_size, **joint_conf)
        model = Transducer(
            vocab_size=vocab_size,
            blank=0,
            predictor=predictor,
            encoder=encoder,
            attention_decoder=decoder,
            joint=joint,
            ctc=ctc,
            **configs["model_conf"],
        )
    elif "paraformer" in configs:
        predictor = Predictor(**configs["cif_predictor_conf"])
        model = Paraformer(
            vocab_size=vocab_size,
            encoder=encoder,
            decoder=decoder,
            ctc=ctc,
            predictor=predictor,
            **configs["model_conf"],
        )
    else:
        model = ASRModel(
            vocab_size=vocab_size,
            encoder=encoder,
            decoder=decoder,
            ctc=ctc,
            lfmmi_dir=configs.get("lfmmi_dir", ""),
            **configs["model_conf"],
        )
    return model