# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------

from argparse import Namespace
from dataclasses import dataclass, field

import torch.nn as nn
from omegaconf import II

from fairseq import checkpoint_utils, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.models import BaseFairseqModel, FairseqEncoder, register_model
from fairseq.models.hubert.hubert_asr import HubertAsrConfig, Linear
from fairseq.tasks import FairseqTask


@dataclass
class Speech2cAsrConfig(HubertAsrConfig):
    # for decoder
    decoder_layerdrop: float = field(
        default=0.0,
        metadata={"help": "probability of dropping a decoder layer in hubert"},
    )
    add_decoder: bool = II("task.add_decoder")


@dataclass
class Speech2cCtcConfig(Speech2cAsrConfig):
    pass


@register_model("speech2c_ctc", dataclass=Speech2cCtcConfig)
class Speech2cCtc(BaseFairseqModel):
    def __init__(self, cfg: Speech2cCtcConfig, w2v_encoder: BaseFairseqModel):
        super().__init__()
        self.cfg = cfg
        self.w2v_encoder = w2v_encoder

    def upgrade_state_dict_named(self, state_dict, name):
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: Speech2cCtcConfig, task: FairseqTask):
        """Build a new model instance."""
        w2v_encoder = Speech2cEncoder(cfg, task.target_dictionary)
        return cls(cfg, w2v_encoder)

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Get normalized probabilities (or log probs) from a net's output."""
        if "encoder_out" not in net_output:
            # decoder output: delegate to the underlying Speech2C model
            return self.w2v_encoder.get_normalized_probs_decoder(
                net_output, log_probs, sample
            )

        if "encoder_out_for_ctc" in net_output:
            logits = net_output["encoder_out_for_ctc"]
        else:
            logits = net_output["encoder_out"]
        if isinstance(logits, list):
            logits = logits[0]

        if log_probs:
            return utils.log_softmax(logits.float(), dim=-1)
        else:
            return utils.softmax(logits.float(), dim=-1)

    def get_logits(self, net_output):
        logits = net_output["encoder_out"]
        padding = net_output["encoder_padding_mask"]
        if padding is not None and padding.any():
            # padded frames may only emit index 0 (the CTC blank); index in a
            # single assignment so we write into `logits` in place rather than
            # into the copy produced by chained boolean indexing
            padding = padding.T
            logits[padding, 0] = 0
            logits[padding, 1:] = float("-inf")

        return logits

    def forward(self, **kwargs):
        x = self.w2v_encoder(**kwargs)
        return x

    @property
    def encoder(self):
        return self.w2v_encoder

    def reorder_encoder_out(self, encoder_out, new_order):
        return self.encoder.reorder_encoder_out(encoder_out, new_order)

    @property
    def decoder(self):
        return self.w2v_encoder.w2v_model.decoder


class Speech2cEncoder(FairseqEncoder):
    def __init__(self, cfg: Speech2cAsrConfig, tgt_dict=None):
        self.apply_mask = cfg.apply_mask

        arg_overrides = {
            "dropout": cfg.dropout,
            "activation_dropout": cfg.activation_dropout,
            "dropout_input": cfg.dropout_input,
            "attention_dropout": cfg.attention_dropout,
            "mask_length": cfg.mask_length,
            "mask_prob": cfg.mask_prob,
            "mask_selection": cfg.mask_selection,
            "mask_other": cfg.mask_other,
            "no_mask_overlap": cfg.no_mask_overlap,
            "mask_channel_length": cfg.mask_channel_length,
            "mask_channel_prob": cfg.mask_channel_prob,
            "mask_channel_selection": cfg.mask_channel_selection,
            "mask_channel_other": cfg.mask_channel_other,
            "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
            "encoder_layerdrop": cfg.layerdrop,
            "decoder_layerdrop": cfg.decoder_layerdrop,
            "feature_grad_mult": cfg.feature_grad_mult,
            "decoder_dict_size": len(tgt_dict) if cfg.add_decoder else -1,
        }
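        # Resolve the pre-training config: either load it together with the model
        # weights from the checkpoint at cfg.w2v_path (applying the overrides
        # above), or reuse cfg.w2v_args if it was already provided.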
        if cfg.w2v_args is None:
            state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides)
            w2v_args = state.get("cfg", None)
            if w2v_args is None:
                w2v_args = convert_namespace_to_omegaconf(state["args"])
            cfg.w2v_args = w2v_args
        else:
            state = None
            w2v_args = cfg.w2v_args
            if isinstance(w2v_args, Namespace):
                cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)

        assert cfg.normalize == w2v_args.task.normalize, (
            "Fine-tuning works best when data normalization is the same. "
            "Please check that --normalize is set or unset for "
            "both pre-training and here"
        )

        w2v_args.task.data = cfg.data
        w2v_args.task.add_decoder = cfg.add_decoder
        task = tasks.setup_task(w2v_args.task)
        if state is not None and "task_state" in state:
            # This will load the stored "dictionaries" object
            task.load_state_dict(state["task_state"])
        model = task.build_model(w2v_args.model)

        if state is not None and not cfg.no_pretrained_weights:
            # the fine-tuning dictionary may differ from the pre-training one,
            # so drop the pre-trained decoder embedding / output projection
            if "decoder.embed_tokens.weight" in state["model"]:
                del state["model"]["decoder.embed_tokens.weight"]
            if "decoder.output_projection.weight" in state["model"]:
                del state["model"]["decoder.output_projection.weight"]
            # set strict=False because we omit some modules
            model.load_state_dict(state["model"], strict=False)

        model.remove_pretraining_modules()

        super().__init__(task.source_dictionary)

        d = model.mask_emb.size(0)

        self.w2v_model = model

        self.final_dropout = nn.Dropout(cfg.final_dropout)
        self.freeze_finetune_updates = cfg.freeze_finetune_updates
        self.num_updates = 0

        if tgt_dict is not None:
            self.proj = Linear(d, len(tgt_dict))
        elif getattr(cfg, "decoder_embed_dim", d) != d:
            self.proj = Linear(d, cfg.decoder_embed_dim)
        else:
            self.proj = None

    def set_num_updates(self, num_updates):
        """Set the number of parameter updates."""
        super().set_num_updates(num_updates)
        self.num_updates = num_updates

    def forward(self, source, padding_mask, prev_output_tokens=None, tbc=True, **kwargs):
        ft = self.freeze_finetune_updates <= self.num_updates
        w2v_args = {
            "source": source,
            "padding_mask": padding_mask,
            "mask": self.apply_mask and self.training,
            "prev_output_tokens": prev_output_tokens,
            "ft": ft,
        }

        x, padding_mask, decoder_out = self.w2v_model.extract_features(**w2v_args)

        if tbc:
            # B x T x C -> T x B x C
            x = x.transpose(0, 1)

        x = self.final_dropout(x)

        if self.proj:
            x = self.proj(x)

        return {
            "encoder_out": x,  # T x B x C
            "encoder_padding_mask": padding_mask,  # B x T
            "padding_mask": padding_mask,
            "decoder_out": decoder_out,
        }

    def get_normalized_probs_decoder(self, net_output, log_probs, sample=None):
        # net_output['encoder_out'] is a (B, T, D) tensor
        return self.w2v_model.get_normalized_probs(net_output, log_probs, sample)

    def reorder_encoder_out(self, encoder_out, new_order):
        if encoder_out["encoder_out"] is not None:
            if isinstance(encoder_out["encoder_out"], list):
                encoder_out["encoder_out"] = (
                    []
                    if len(encoder_out["encoder_out"]) == 0
                    else [
                        x.index_select(1, new_order)
                        for x in encoder_out["encoder_out"]
                    ]
                )
            else:
                encoder_out["encoder_out"] = encoder_out[
                    "encoder_out"
                ].index_select(1, new_order)
        if encoder_out["encoder_padding_mask"] is not None:
            if isinstance(encoder_out["encoder_padding_mask"], list):
                encoder_out["encoder_padding_mask"] = (
                    []
                    if len(encoder_out["encoder_padding_mask"]) == 0
                    else [
                        x.index_select(0, new_order)
                        for x in encoder_out["encoder_padding_mask"]
                    ]
                )
            else:
                encoder_out["encoder_padding_mask"] = encoder_out[
                    "encoder_padding_mask"
                ].index_select(0, new_order)
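        # "decoder_out" is batch-first, whereas "encoder_out_for_ctc" follows the
        # T x B x C layout of "encoder_out", hence the different index_select dims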
        if "decoder_out" in encoder_out and encoder_out["decoder_out"] is not None:
            if isinstance(encoder_out["decoder_out"], list):
                encoder_out["decoder_out"] = (
                    []
                    if len(encoder_out["decoder_out"]) == 0
                    else [
                        x.index_select(0, new_order)
                        for x in encoder_out["decoder_out"]
                    ]
                )
            else:
                encoder_out["decoder_out"] = encoder_out[
                    "decoder_out"
                ].index_select(0, new_order)
        if "encoder_out_for_ctc" in encoder_out and encoder_out["encoder_out_for_ctc"] is not None:
            if isinstance(encoder_out["encoder_out_for_ctc"], list):
                encoder_out["encoder_out_for_ctc"] = (
                    []
                    if len(encoder_out["encoder_out_for_ctc"]) == 0
                    else [
                        x.index_select(1, new_order)
                        for x in encoder_out["encoder_out_for_ctc"]
                    ]
                )
            else:
                encoder_out["encoder_out_for_ctc"] = encoder_out[
                    "encoder_out_for_ctc"
                ].index_select(1, new_order)

        return encoder_out

    def forward_torchscript(self, net_input):
        """A TorchScript-compatible version of forward.

        Encoders which use additional arguments may want to override
        this method for TorchScript compatibility.
        """
        encoder_out = self.w2v_model.forward_torchscript(net_input)

        assert self.proj is not None
        encoder_out["encoder_out_for_ctc"] = [self.proj(encoder_out["encoder_out"][0])]

        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return None

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict
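
# Minimal usage sketch (assuming a fine-tuned "speech2c_ctc" checkpoint and that
# the Speech2C task/model registrations are importable, e.g. via fairseq's
# --user-dir); the checkpoint path is illustrative only:
#
#   from fairseq import checkpoint_utils
#
#   models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
#       ["/path/to/speech2c_ctc_finetuned.pt"]
#   )
#   model = models[0]  # Speech2cCtc instance; model.encoder wraps the Speech2C backbone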