# Copyright (c) 2022 Binbin Zhang ([email protected])
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch

from wenet.finetune.lora.utils import (inject_lora_to_model,
                                       mark_only_lora_as_trainable)
from wenet.k2.model import K2Model
from wenet.llm_asr.init_llmasr import init_llmasr
from wenet.paraformer.cif import Cif
from wenet.paraformer.layers import SanmDecoder, SanmEncoder
from wenet.paraformer.paraformer import Paraformer, Predictor
from wenet.LLM.causallm_model import CausalLM
from wenet.LLM.decoder import DecoderOnly
from wenet.ssl.init_model import WENET_SSL_MODEL_CLASS
from wenet.transducer.joint import TransducerJoint
from wenet.transducer.predictor import (ConvPredictor, EmbeddingPredictor,
                                        RNNPredictor)
from wenet.transducer.transducer import Transducer
from wenet.transformer.asr_model import ASRModel
from wenet.transformer.cmvn import GlobalCMVN
from wenet.transformer.ctc import CTC
from wenet.transformer.encoder import TransformerEncoder, ConformerEncoder
from wenet.transformer.decoder import BiTransformerDecoder, TransformerDecoder
from wenet.branchformer.encoder import BranchformerEncoder
from wenet.e_branchformer.encoder import EBranchformerEncoder
from wenet.squeezeformer.encoder import SqueezeformerEncoder
from wenet.efficient_conformer.encoder import EfficientConformerEncoder
from wenet.ctl_model.encoder import DualTransformerEncoder, DualConformerEncoder
from wenet.ctl_model.asr_model_ctl import CTLModel
from wenet.whisper.whisper import Whisper
from wenet.utils.cmvn import load_cmvn
from wenet.utils.checkpoint import load_checkpoint, load_trained_modules

WENET_ENCODER_CLASSES = {
    "transformer": TransformerEncoder,
    "conformer": ConformerEncoder,
    "squeezeformer": SqueezeformerEncoder,
    "efficientConformer": EfficientConformerEncoder,
    "branchformer": BranchformerEncoder,
    "e_branchformer": EBranchformerEncoder,
    "dual_transformer": DualTransformerEncoder,
    "dual_conformer": DualConformerEncoder,
    'sanm_encoder': SanmEncoder,
}

WENET_DECODER_CLASSES = {
    "transformer": TransformerDecoder,
    "bitransformer": BiTransformerDecoder,
    "sanm_decoder": SanmDecoder,
}

WENET_CTC_CLASSES = {
    "ctc": CTC,
}

WENET_PREDICTOR_CLASSES = {
    "rnn": RNNPredictor,
    "embedding": EmbeddingPredictor,
    "conv": ConvPredictor,
    "cif_predictor": Cif,
    "paraformer_predictor": Predictor,
}

WENET_JOINT_CLASSES = {
    "transducer_joint": TransducerJoint,
}

WENET_MODEL_CLASSES = {
    "asr_model": ASRModel,
    "ctl_model": CTLModel,
    "whisper": Whisper,
    "k2_model": K2Model,
    "transducer": Transducer,
    'paraformer': Paraformer,
    'causal_llm': CausalLM,
}
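# The registries above map the string identifiers used in a training YAML
# config to their implementation classes; init_speech_model() below looks up
# configs['encoder'], configs['decoder'], configs['ctc'], configs['predictor'],
# configs['joint'] and configs['model'] in these tables. A minimal
# illustrative config sketch (key names follow what this file reads; concrete
# values and the contents of the *_conf sections depend on the recipe):
#
#   input_dim: 80
#   output_dim: 5002
#   encoder: conformer
#   encoder_conf: {...}
#   decoder: bitransformer
#   decoder_conf: {...}
#   ctc_conf:
#     ctc_blank_id: 0
#   model: asr_model
#   model_conf: {...}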
def init_speech_model(args, configs):
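    """Instantiate the speech model described by ``configs``.

    Builds the optional global CMVN layer, the encoder, the attention decoder
    and the CTC head from their ``*_conf`` sections, then assembles them into
    the model class selected by ``configs['model']`` (transducer, paraformer,
    an SSL model, or the default attention/CTC model).

    Returns:
        A ``(model, configs)`` tuple.
    """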
    # TODO(xcsong): Forcefully read the 'cmvn' attribute.
    if configs.get('cmvn', None) == 'global_cmvn':
        mean, istd = load_cmvn(configs['cmvn_conf']['cmvn_file'],
                               configs['cmvn_conf']['is_json_cmvn'])
        global_cmvn = GlobalCMVN(
            torch.from_numpy(mean).float(),
            torch.from_numpy(istd).float())
    else:
        global_cmvn = None

    input_dim = configs['input_dim']
    vocab_size = configs['output_dim']

    encoder_type = configs.get('encoder', 'conformer')
    decoder_type = configs.get('decoder', 'bitransformer')
    ctc_type = configs.get('ctc', 'ctc')

    encoder = WENET_ENCODER_CLASSES[encoder_type](
        input_dim,
        global_cmvn=global_cmvn,
        **configs['encoder_conf'],
        **configs['encoder_conf']['efficient_conf']
        if 'efficient_conf' in configs['encoder_conf'] else {})
    decoder = WENET_DECODER_CLASSES[decoder_type](vocab_size,
                                                  encoder.output_size(),
                                                  **configs['decoder_conf'])
    ctc = WENET_CTC_CLASSES[ctc_type](
        vocab_size,
        encoder.output_size(),
        blank_id=configs['ctc_conf']['ctc_blank_id']
        if 'ctc_conf' in configs else 0)

    model_type = configs.get('model', 'asr_model')
    if model_type == "transducer":
        predictor_type = configs.get('predictor', 'rnn')
        joint_type = configs.get('joint', 'transducer_joint')
        predictor = WENET_PREDICTOR_CLASSES[predictor_type](
            vocab_size, **configs['predictor_conf'])
        joint = WENET_JOINT_CLASSES[joint_type](vocab_size,
                                                **configs['joint_conf'])
        model = WENET_MODEL_CLASSES[model_type](
            vocab_size=vocab_size,
            blank=0,
            predictor=predictor,
            encoder=encoder,
            attention_decoder=decoder,
            joint=joint,
            ctc=ctc,
            special_tokens=configs.get('tokenizer_conf',
                                       {}).get('special_tokens', None),
            **configs['model_conf'])
    elif model_type == 'paraformer':
        predictor_type = configs.get('predictor', 'cif')
        predictor = WENET_PREDICTOR_CLASSES[predictor_type](
            **configs['predictor_conf'])
        model = WENET_MODEL_CLASSES[model_type](
            vocab_size=vocab_size,
            encoder=encoder,
            decoder=decoder,
            predictor=predictor,
            ctc=ctc,
            **configs['model_conf'],
            special_tokens=configs.get('tokenizer_conf',
                                       {}).get('special_tokens', None),
        )
    elif model_type in WENET_SSL_MODEL_CLASS.keys():
        from wenet.ssl.init_model import init_model as init_ssl_model
        model = init_ssl_model(configs, encoder)
    else:
        model = WENET_MODEL_CLASSES[model_type](
            vocab_size=vocab_size,
            encoder=encoder,
            decoder=decoder,
            ctc=ctc,
            special_tokens=configs.get('tokenizer_conf',
                                       {}).get('special_tokens', None),
            **configs['model_conf'])

    return model, configs

def init_causal_llm(configs):
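    """Instantiate a decoder-only causal language model from ``configs``.

    Expects ``configs['decoder'] == 'decoder_only'`` and
    ``configs['model'] == 'causal_lm'``.

    Returns:
        A ``(model, configs)`` tuple.
    """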
    vocab_size = configs['output_dim']
    assert configs['decoder'] == 'decoder_only'
    assert configs['model'] == 'causal_lm'
    decoder_only = DecoderOnly(**configs['decoder_conf'])

    model = CausalLM(
        vocab_size,
        decoder_only,
        **configs['model_conf'],
        special_tokens=configs.get('tokenizer_conf',
                                   {}).get('special_tokens', None),
    )
    return model, configs

def init_model(args, configs):
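    """Top-level model factory.

    Dispatches on ``configs['model']`` to build a causal LM, an LLM-based ASR
    model, or a speech model, then optionally injects LoRA layers, loads
    checkpoints or pre-trained modules, ties weights and marks only LoRA
    parameters as trainable, depending on the attributes present on ``args``.

    Returns:
        A ``(model, configs)`` tuple (except for the ``llmasr`` branch, which
        returns only the model).
    """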
    model_type = configs.get('model', 'asr_model')
    configs['model'] = model_type

    if model_type == 'causal_lm':
        model, configs = init_causal_llm(configs)
    elif model_type == "llmasr":
        model = init_llmasr(args, configs)
        return model
    else:
        model, configs = init_speech_model(args, configs)

    if hasattr(args, 'use_lora') and args.use_lora:
        inject_lora_to_model(model, configs['lora_conf'])

    # If a checkpoint is specified, load some info from it
    if hasattr(args, 'checkpoint') and args.checkpoint is not None:
        infos = load_checkpoint(model, args.checkpoint)
    elif hasattr(args, 'enc_init') and args.enc_init is not None:
        infos = load_trained_modules(model, args)
    else:
        infos = {}
    if configs.get('init_step', False):
        infos = {}
    configs["init_infos"] = infos

    if hasattr(args, 'use_lora') and args.use_lora:
        if hasattr(args, 'lora_ckpt_path') and args.lora_ckpt_path:
            load_checkpoint(model, args.lora_ckpt_path)
    print(configs)

    # Try to tie some weights
    if hasattr(model, 'tie_or_clone_weights'):
        if not hasattr(args, 'jit'):
            args.jit = True  # i.e. export onnx/jit/ipex
        model.tie_or_clone_weights(args.jit)

    if hasattr(args, 'only_optimize_lora') and args.only_optimize_lora:
        mark_only_lora_as_trainable(model, bias='lora_only')

    if int(os.environ.get('RANK', 0)) == 0:
        print(configs)

    return model, configs
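# Typical usage (a sketch, not part of this module's API): ``args`` is the
# argparse.Namespace built by a training/recognition entry point and
# ``configs`` is the parsed YAML dict; the attribute names below
# (``args.config``, etc.) are illustrative.
#
#   import yaml
#   with open(args.config, 'r') as fin:
#       configs = yaml.load(fin, Loader=yaml.FullLoader)
#   model, configs = init_model(args, configs)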