# OSUM / wenet/utils/init_model.py
# Author: tomxxie — adapted for zeroGPU (适配zeroGPU)
# Commit: 568e264
# Copyright (c) 2022 Binbin Zhang ([email protected])
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from wenet.finetune.lora.utils import (inject_lora_to_model,
mark_only_lora_as_trainable)
from wenet.k2.model import K2Model
from wenet.llm_asr.init_llmasr import init_llmasr
from wenet.paraformer.cif import Cif
from wenet.paraformer.layers import SanmDecoder, SanmEncoder
from wenet.paraformer.paraformer import Paraformer, Predictor
from wenet.LLM.causallm_model import CausalLM
from wenet.LLM.decoder import DecoderOnly
from wenet.ssl.init_model import WENET_SSL_MODEL_CLASS
from wenet.transducer.joint import TransducerJoint
from wenet.transducer.predictor import (ConvPredictor, EmbeddingPredictor,
RNNPredictor)
from wenet.transducer.transducer import Transducer
from wenet.transformer.asr_model import ASRModel
from wenet.transformer.cmvn import GlobalCMVN
from wenet.transformer.ctc import CTC
from wenet.transformer.encoder import TransformerEncoder, ConformerEncoder
from wenet.transformer.decoder import BiTransformerDecoder, TransformerDecoder
from wenet.branchformer.encoder import BranchformerEncoder
from wenet.e_branchformer.encoder import EBranchformerEncoder
from wenet.squeezeformer.encoder import SqueezeformerEncoder
from wenet.efficient_conformer.encoder import EfficientConformerEncoder
from wenet.ctl_model.encoder import DualTransformerEncoder, DualConformerEncoder
from wenet.ctl_model.asr_model_ctl import CTLModel
from wenet.whisper.whisper import Whisper
from wenet.utils.cmvn import load_cmvn
from wenet.utils.checkpoint import load_checkpoint, load_trained_modules
# Registries mapping the type strings used in the YAML/config files to the
# classes that implement them. init_speech_model() resolves the 'encoder',
# 'decoder', 'ctc', 'predictor', 'joint' and 'model' config keys through
# these tables.

WENET_ENCODER_CLASSES = {
    "transformer": TransformerEncoder,
    "conformer": ConformerEncoder,
    "squeezeformer": SqueezeformerEncoder,
    "efficientConformer": EfficientConformerEncoder,
    "branchformer": BranchformerEncoder,
    "e_branchformer": EBranchformerEncoder,
    # Dual encoders are used by the contrastive (CTL) models.
    "dual_transformer": DualTransformerEncoder,
    "dual_conformer": DualConformerEncoder,
    "sanm_encoder": SanmEncoder,
}

WENET_DECODER_CLASSES = {
    "transformer": TransformerDecoder,
    "bitransformer": BiTransformerDecoder,
    "sanm_decoder": SanmDecoder,
}

WENET_CTC_CLASSES = {
    "ctc": CTC,
}

# Predictors are only used by transducer and paraformer model types.
WENET_PREDICTOR_CLASSES = {
    "rnn": RNNPredictor,
    "embedding": EmbeddingPredictor,
    "conv": ConvPredictor,
    "cif_predictor": Cif,
    "paraformer_predictor": Predictor,
}

WENET_JOINT_CLASSES = {
    "transducer_joint": TransducerJoint,
}

WENET_MODEL_CLASSES = {
    "asr_model": ASRModel,
    "ctl_model": CTLModel,
    "whisper": Whisper,
    "k2_model": K2Model,
    "transducer": Transducer,
    "paraformer": Paraformer,
    "causal_llm": CausalLM,
}
def init_speech_model(args, configs):
    """Build a speech model (encoder/decoder/CTC + model wrapper) from configs.

    Args:
        args: parsed command-line arguments (currently unused here, kept for
            interface symmetry with init_model / init_llmasr).
        configs: dict parsed from the training YAML. Must contain
            'input_dim', 'output_dim' and the *_conf sub-dicts for the
            selected component types.

    Returns:
        (model, configs): the constructed ``torch.nn.Module`` and the (possibly
        unmodified) config dict.
    """
    # TODO(xcsong): Forcefully read the 'cmvn' attribute.
    if configs.get('cmvn', None) == 'global_cmvn':
        mean, istd = load_cmvn(configs['cmvn_conf']['cmvn_file'],
                               configs['cmvn_conf']['is_json_cmvn'])
        global_cmvn = GlobalCMVN(
            torch.from_numpy(mean).float(),
            torch.from_numpy(istd).float())
    else:
        global_cmvn = None

    input_dim = configs['input_dim']
    vocab_size = configs['output_dim']

    encoder_type = configs.get('encoder', 'conformer')
    decoder_type = configs.get('decoder', 'bitransformer')
    ctc_type = configs.get('ctc', 'ctc')

    encoder = WENET_ENCODER_CLASSES[encoder_type](
        input_dim,
        global_cmvn=global_cmvn,
        **configs['encoder_conf'],
        **configs['encoder_conf']['efficient_conf']
        if 'efficient_conf' in configs['encoder_conf'] else {})

    decoder = WENET_DECODER_CLASSES[decoder_type](vocab_size,
                                                  encoder.output_size(),
                                                  **configs['decoder_conf'])

    # NOTE: if 'ctc_conf' exists it must contain 'ctc_blank_id'; only a fully
    # absent 'ctc_conf' falls back to blank_id=0.
    ctc = WENET_CTC_CLASSES[ctc_type](
        vocab_size,
        encoder.output_size(),
        blank_id=configs['ctc_conf']['ctc_blank_id']
        if 'ctc_conf' in configs else 0)

    model_type = configs.get('model', 'asr_model')
    if model_type == "transducer":
        predictor_type = configs.get('predictor', 'rnn')
        joint_type = configs.get('joint', 'transducer_joint')
        predictor = WENET_PREDICTOR_CLASSES[predictor_type](
            vocab_size, **configs['predictor_conf'])
        joint = WENET_JOINT_CLASSES[joint_type](vocab_size,
                                                **configs['joint_conf'])
        model = WENET_MODEL_CLASSES[model_type](
            vocab_size=vocab_size,
            blank=0,
            predictor=predictor,
            encoder=encoder,
            attention_decoder=decoder,
            joint=joint,
            ctc=ctc,
            special_tokens=configs.get('tokenizer_conf',
                                       {}).get('special_tokens', None),
            **configs['model_conf'])
    elif model_type == 'paraformer':
        # BUGFIX: the previous default 'cif' is not a key of
        # WENET_PREDICTOR_CLASSES ('cif_predictor' is), so the fallback
        # always raised KeyError. Explicit configs are unaffected.
        predictor_type = configs.get('predictor', 'cif_predictor')
        predictor = WENET_PREDICTOR_CLASSES[predictor_type](
            **configs['predictor_conf'])
        model = WENET_MODEL_CLASSES[model_type](
            vocab_size=vocab_size,
            encoder=encoder,
            decoder=decoder,
            predictor=predictor,
            ctc=ctc,
            **configs['model_conf'],
            special_tokens=configs.get('tokenizer_conf',
                                       {}).get('special_tokens', None),
        )
    elif model_type in WENET_SSL_MODEL_CLASS.keys():
        # SSL models build themselves around the already-constructed encoder.
        from wenet.ssl.init_model import init_model as init_ssl_model
        model = init_ssl_model(configs, encoder)
    else:
        model = WENET_MODEL_CLASSES[model_type](
            vocab_size=vocab_size,
            encoder=encoder,
            decoder=decoder,
            ctc=ctc,
            special_tokens=configs.get('tokenizer_conf',
                                       {}).get('special_tokens', None),
            **configs['model_conf'])
    return model, configs
def init_causal_llm(configs):
    """Build a decoder-only causal language model from *configs*.

    Requires ``configs['decoder'] == 'decoder_only'`` and
    ``configs['model'] == 'causal_lm'``.

    Returns:
        (model, configs): the ``CausalLM`` instance and the config dict.
    """
    num_tokens = configs['output_dim']
    assert configs['decoder'] == 'decoder_only'
    assert configs['model'] == 'causal_lm'

    backbone = DecoderOnly(**configs['decoder_conf'])
    tokenizer_conf = configs.get('tokenizer_conf', {})
    model = CausalLM(
        num_tokens,
        backbone,
        **configs['model_conf'],
        special_tokens=tokenizer_conf.get('special_tokens', None),
    )
    return model, configs
def init_model(args, configs):
    """Top-level model factory: build a model and optionally load weights.

    Dispatches on ``configs['model']`` to the specialised builders, then
    applies LoRA injection, checkpoint loading, weight tying and LoRA-only
    training selection as requested via *args*.

    Args:
        args: argparse-style namespace; optional attributes ``use_lora``,
            ``checkpoint``, ``enc_init``, ``lora_ckpt_path``, ``jit``,
            ``only_optimize_lora`` are honoured when present.
        configs: config dict; ``configs['model']`` and
            ``configs['init_infos']`` are written back into it.

    Returns:
        (model, configs) for every model type EXCEPT 'llmasr', which returns
        the bare model (see NOTE below).
    """
    model_type = configs.get('model', 'asr_model')
    configs['model'] = model_type
    if model_type == 'causal_lm':
        model, configs = init_causal_llm(configs)
    elif model_type == "llmasr":
        model = init_llmasr(args, configs)
        # NOTE(review): this branch returns a bare model instead of
        # (model, configs) like every other path; kept as-is for backward
        # compatibility with existing llmasr callers.
        return model
    else:
        model, configs = init_speech_model(args, configs)

    if hasattr(args, 'use_lora') and args.use_lora:
        inject_lora_to_model(model, configs['lora_conf'])

    # If a checkpoint (or pre-trained encoder modules) is specified,
    # load weights and any stored training infos from it.
    if hasattr(args, 'checkpoint') and args.checkpoint is not None:
        infos = load_checkpoint(model, args.checkpoint)
    elif hasattr(args, 'enc_init') and args.enc_init is not None:
        infos = load_trained_modules(model, args)
    else:
        infos = {}
    if configs.get('init_step', False):
        # Discard restored step/epoch infos so training restarts from step 0.
        infos = {}
    configs["init_infos"] = infos

    # LoRA adapter weights are loaded after the base checkpoint.
    if hasattr(args, 'use_lora') and args.use_lora:
        if hasattr(args, 'lora_ckpt_path') and args.lora_ckpt_path:
            load_checkpoint(model, args.lora_ckpt_path)

    # Try to tie some weights (e.g. input embedding / output projection).
    if hasattr(model, 'tie_or_clone_weights'):
        if not hasattr(args, 'jit'):
            args.jit = True  # i.e. export onnx/jit/ipex
        model.tie_or_clone_weights(args.jit)

    if hasattr(args, 'only_optimize_lora') and args.only_optimize_lora:
        mark_only_lora_as_trainable(model, bias='lora_only')

    # Only rank 0 logs the final config, to avoid duplicated output under
    # distributed training. (A second, unconditional print(configs) that
    # spammed every rank was removed.)
    if int(os.environ.get('RANK', 0)) == 0:
        print(configs)
    return model, configs