# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
from argparse import Namespace
from omegaconf import II
import torch.nn as nn
from dataclasses import dataclass, field
from fairseq import checkpoint_utils, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.models import BaseFairseqModel, FairseqEncoder, register_model
from fairseq.models.hubert.hubert_asr import HubertAsrConfig, Linear
from fairseq.tasks import FairseqTask
@dataclass
class Speech2cAsrConfig(HubertAsrConfig):
# for decoder
decoder_layerdrop: float = field(
default=0.0,
metadata={"help": "probability of dropping a decoder layer in hubert"},
)
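    # mirrored from the task-level flag via OmegaConf interpolation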
add_decoder: bool = II("task.add_decoder")
@dataclass
class Speech2cCtcConfig(Speech2cAsrConfig):
pass
@register_model("speech2c_ctc", dataclass=Speech2cCtcConfig)
class Speech2cCtc(BaseFairseqModel):
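    """CTC fine-tuning model that wraps a pre-trained Speech2C encoder (and optional decoder)."""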
def __init__(self, cfg: Speech2cCtcConfig, w2v_encoder: BaseFairseqModel):
super().__init__()
self.cfg = cfg
self.w2v_encoder = w2v_encoder
def upgrade_state_dict_named(self, state_dict, name):
super().upgrade_state_dict_named(state_dict, name)
return state_dict
@classmethod
def build_model(cls, cfg: Speech2cCtcConfig, task: FairseqTask):
"""Build a new model instance."""
w2v_encoder = Speech2cEncoder(cfg, task.target_dictionary)
return cls(cfg, w2v_encoder)
def get_normalized_probs(self, net_output, log_probs, sample=None):
"""Get normalized probabilities (or log probs) from a net's output."""
if "encoder_out" not in net_output:
return self.w2v_encoder.get_normalized_probs_decoder(net_output, log_probs, sample)
if "encoder_out_for_ctc" in net_output:
logits = net_output["encoder_out_for_ctc"]
else:
logits = net_output["encoder_out"]
if isinstance(logits, list):
logits = logits[0]
if log_probs:
return utils.log_softmax(logits.float(), dim=-1)
else:
return utils.softmax(logits.float(), dim=-1)
    def get_logits(self, net_output):
        logits = net_output["encoder_out"]
        padding = net_output["encoder_padding_mask"]
        if padding is not None and padding.any():
            # on padded frames, keep only the blank label: build the fill vector
            # once and assign in a single indexing step so `logits` is updated
            # in place (chained indexing would only write to a temporary copy)
            fill = logits.new_full((logits.size(-1),), float("-inf"))
            fill[0] = 0
            logits[padding.T] = fill
        return logits
def forward(self, **kwargs):
x = self.w2v_encoder(**kwargs)
return x
@property
def encoder(self):
return self.w2v_encoder
def reorder_encoder_out(self, encoder_out, new_order):
return self.encoder.reorder_encoder_out(encoder_out, new_order)
@property
def decoder(self):
return self.w2v_encoder.w2v_model.decoder
class Speech2cEncoder(FairseqEncoder):
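    """Loads a pre-trained Speech2C model for fine-tuning and projects its encoder output to the target vocabulary for CTC."""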
def __init__(self, cfg: Speech2cAsrConfig, tgt_dict=None):
self.apply_mask = cfg.apply_mask
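        # fine-tuning hyper-parameters that override the values stored in the
        # pre-trained checkpoint's config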
arg_overrides = {
"dropout": cfg.dropout,
"activation_dropout": cfg.activation_dropout,
"dropout_input": cfg.dropout_input,
"attention_dropout": cfg.attention_dropout,
"mask_length": cfg.mask_length,
"mask_prob": cfg.mask_prob,
"mask_selection": cfg.mask_selection,
"mask_other": cfg.mask_other,
"no_mask_overlap": cfg.no_mask_overlap,
"mask_channel_length": cfg.mask_channel_length,
"mask_channel_prob": cfg.mask_channel_prob,
"mask_channel_selection": cfg.mask_channel_selection,
"mask_channel_other": cfg.mask_channel_other,
"no_mask_channel_overlap": cfg.no_mask_channel_overlap,
"encoder_layerdrop": cfg.layerdrop,
"decoder_layerdrop": cfg.decoder_layerdrop,
"feature_grad_mult": cfg.feature_grad_mult,
"decoder_dict_size": len(tgt_dict) if cfg.add_decoder else -1,
}
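        # load the pre-trained checkpoint and its config unless a config was
        # already provided via cfg.w2v_args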
if cfg.w2v_args is None:
state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides)
w2v_args = state.get("cfg", None)
if w2v_args is None:
w2v_args = convert_namespace_to_omegaconf(state["args"])
cfg.w2v_args = w2v_args
else:
state = None
w2v_args = cfg.w2v_args
if isinstance(w2v_args, Namespace):
cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)
        assert cfg.normalize == w2v_args.task.normalize, (
            "Fine-tuning works best when data normalization is the same. "
            "Please check that --normalize is set or unset for both "
            "pre-training and fine-tuning"
        )
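        # point the restored pre-training task at the fine-tuning data and
        # decoder setting before rebuilding the model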
w2v_args.task.data = cfg.data
w2v_args.task.add_decoder = cfg.add_decoder
task = tasks.setup_task(w2v_args.task)
if state is not None and "task_state" in state:
# This will load the stored "dictionaries" object
task.load_state_dict(state["task_state"])
model = task.build_model(w2v_args.model)
if state is not None and not cfg.no_pretrained_weights:
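            # the fine-tuning target dictionary differs from the pre-training
            # one, so drop the decoder embedding and output projection weights
            # before loading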
if "decoder.embed_tokens.weight" in state["model"]:
del state["model"]["decoder.embed_tokens.weight"]
if "decoder.output_projection.weight" in state["model"]:
del state["model"]["decoder.output_projection.weight"]
# set strict=False because we omit some modules
model.load_state_dict(state["model"], strict=False)
model.remove_pretraining_modules()
super().__init__(task.source_dictionary)
        d = model.mask_emb.size(0)  # encoder embedding dimension, taken from the mask embedding
self.w2v_model = model
self.final_dropout = nn.Dropout(cfg.final_dropout)
self.freeze_finetune_updates = cfg.freeze_finetune_updates
self.num_updates = 0
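        # output projection: to the CTC vocabulary when a target dictionary is
        # given, otherwise to decoder_embed_dim if it differs from the encoder
        # dimension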
if tgt_dict is not None:
self.proj = Linear(d, len(tgt_dict))
elif getattr(cfg, "decoder_embed_dim", d) != d:
self.proj = Linear(d, cfg.decoder_embed_dim)
else:
self.proj = None
def set_num_updates(self, num_updates):
"""Set the number of parameters updates."""
super().set_num_updates(num_updates)
self.num_updates = num_updates
def forward(self, source, padding_mask, prev_output_tokens=None, tbc=True, **kwargs):
        ft = self.freeze_finetune_updates <= self.num_updates  # True once the pre-trained weights may be updated
w2v_args = {
"source": source,
"padding_mask": padding_mask,
"mask": self.apply_mask and self.training,
"prev_output_tokens": prev_output_tokens,
"ft": ft,
}
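        # the wrapped model returns encoder features, an updated padding mask,
        # and the decoder output (when a decoder is attached)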
x, padding_mask, decoder_out = self.w2v_model.extract_features(**w2v_args)
if tbc:
# B x T x C -> T x B x C
x = x.transpose(0, 1)
x = self.final_dropout(x)
if self.proj:
x = self.proj(x)
return {
"encoder_out": x, # T x B x C
"encoder_padding_mask": padding_mask, # B x T
"padding_mask": padding_mask,
"decoder_out": decoder_out,
}
    def get_normalized_probs_decoder(self, net_output, log_probs, sample=None):
        # delegate normalization of decoder outputs to the wrapped Speech2C model
        return self.w2v_model.get_normalized_probs(net_output, log_probs, sample)
def reorder_encoder_out(self, encoder_out, new_order):
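        """Reorder encoder (and decoder) outputs to match *new_order*; used during beam search."""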
if encoder_out["encoder_out"] is not None:
if isinstance(encoder_out["encoder_out"], list):
encoder_out["encoder_out"] = (
[] if len(encoder_out["encoder_out"]) == 0
else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
)
else:
encoder_out["encoder_out"] = encoder_out[
"encoder_out"
].index_select(1, new_order)
if encoder_out["encoder_padding_mask"] is not None:
if isinstance(encoder_out["encoder_padding_mask"], list):
encoder_out["encoder_padding_mask"] = (
[] if len(encoder_out["encoder_padding_mask"]) == 0
else [x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]]
)
else:
encoder_out["encoder_padding_mask"] = encoder_out[
"encoder_padding_mask"
].index_select(0, new_order)
if "decoder_out" in encoder_out and encoder_out["decoder_out"] is not None:
if isinstance(encoder_out["decoder_out"], list):
encoder_out["decoder_out"] = (
[] if len(encoder_out["decoder_out"]) == 0
else [x.index_select(0, new_order) for x in encoder_out["decoder_out"]]
)
else:
encoder_out["decoder_out"] = encoder_out[
"decoder_out"
].index_select(0, new_order)
if "encoder_out_for_ctc" in encoder_out and encoder_out["encoder_out_for_ctc"] is not None:
if isinstance(encoder_out["encoder_out_for_ctc"], list):
encoder_out["encoder_out_for_ctc"] = (
[] if len(encoder_out["encoder_out_for_ctc"]) == 0
else [x.index_select(1, new_order) for x in encoder_out["encoder_out_for_ctc"]]
)
else:
encoder_out["encoder_out_for_ctc"] = encoder_out[
"encoder_out_for_ctc"
].index_select(1, new_order)
return encoder_out
def forward_torchscript(self, net_input):
"""A TorchScript-compatible version of forward.
Encoders which use additional arguments may want to override
this method for TorchScript compatibility.
"""
encoder_out = self.w2v_model.forward_torchscript(net_input)
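        # also expose vocabulary-sized CTC logits computed from the encoder states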
assert self.proj is not None
encoder_out['encoder_out_for_ctc'] = [self.proj(encoder_out['encoder_out'][0])]
return encoder_out
def max_positions(self):
"""Maximum input length supported by the encoder."""
return None
def upgrade_state_dict_named(self, state_dict, name):
return state_dict