# --------------------------------------------------------
# The YiTrans End-to-End Speech Translation System for IWSLT 2022 Offline Shared Task (https://arxiv.org/abs/2206.05777)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/YiTrans
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
import logging
import contextlib
import pickle
from argparse import Namespace
from typing import Any, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass, field
from fairseq import checkpoint_utils, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.models import BaseFairseqModel, FairseqEncoder, register_model
from fairseq.models.hubert.hubert_asr import HubertCtcConfig, HubertAsrConfig
from fairseq.tasks import FairseqTask
from fairseq.data.data_utils import lengths_to_padding_mask
from omegaconf import II, open_dict
logger = logging.getLogger(__name__)
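# Fine-tuning model for the YiTrans speech-to-text system: wraps a jointly
# pre-trained speech/text Hubert model (optionally with a decoder) so it can be
# fine-tuned for speech-to-text with CTC and/or cross-entropy losses.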
@dataclass
class HubertSt2tCtcConfig(HubertCtcConfig):
load_speech_only: bool = II("task.load_speech_only")
## for decoder overrides
decoder_layerdrop: float = field(
default=0.1,
metadata={"help": "probability of dropping a decoder layer in hubert"},
)
add_decoder: bool = field(
default=False,
metadata={"help": "whether to add decoder for CE Loss on code"},
)
reuse_text_emb: bool = field(
default=False,
metadata={"help": "reuse text token embeddings instead of initialize randomly"},
)
freeze_decoder_updates: int = field(
default=0,
metadata={"help": "dont finetune hubert for this many updates"},
)
# share_enc_dec_embeddings: bool = field(
# default=False,
# metadata={"help": "share embeddings of (text encoder, text decoder)"},
# )
share_s2t_t2t_embeddings: bool = field(
default=False,
metadata={"help": "share embeddings of (speech2text(code), text2text)"},
)
share_ctc_decoder_embed: bool = field(
default=False,
metadata={"help": "share ctc and decoder embedding (only when share_decoder_input_output_embed is true)"},
)
enc_grad_mult: float = field(
default=1.0,
metadata={"help": "reset feature grad mult in hubert to this (only for st2t)"},
)
retain_dict_path: Optional[str] = field(
default=None,
metadata={"help": "delete embeddings according to this path"},
)
load_step2_model_from: Optional[str] = field(
default=None,
metadata={
"help": "load step2 model from"
},
)
# for other overrides
adaptor_stride: int = field(
default=2,
metadata={"help": "adaptor stride"},
)
@register_model("hubert_st2t", dataclass=HubertSt2tCtcConfig)
class HubertST2T(BaseFairseqModel):
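    """Speech-to-text fine-tuning model registered as "hubert_st2t": a thin
    wrapper around HubertEncoder, which holds the jointly pre-trained
    speech/text model and its optional decoder."""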
def __init__(self, cfg: HubertSt2tCtcConfig, w2v_encoder: BaseFairseqModel):
super().__init__()
self.cfg = cfg
self.w2v_encoder = w2v_encoder
self.num_updates = 0
### in case we need load hubert_step2 model
if cfg.load_step2_model_from:
logger.info(f"Loading hubert_step2 pretrained model for finetuning: {cfg.load_step2_model_from}")
hubert_step2_states = self.w2v_encoder.w2v_model.load_checkpoint(cfg.load_step2_model_from)["model"]
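            # if a retain-dict is given and the checkpoint still carries the full
            # vocabulary, prune its embedding / output-projection rows before loading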
if cfg.retain_dict_path is not None:
with open(cfg.retain_dict_path, "rb") as fp:
overlap_idxs = pickle.load(fp)
if hubert_step2_states['w2v_encoder.w2v_model.decoder.output_projection.0.weight'].size(0) != len(overlap_idxs):
                    assert self.w2v_encoder.w2v_model.add_text_modality, "Must have text modality if retain_dict_path is set"
logger.info("Cut embedding to a smaller size according to retain dict")
hubert_step2_states['w2v_encoder.w2v_model.decoder.output_projection.0.weight'] = hubert_step2_states['w2v_encoder.w2v_model.decoder.output_projection.0.weight'][overlap_idxs]
hubert_step2_states["w2v_encoder.w2v_model.decoder.embed_tokens_list.0.weight"] = hubert_step2_states["w2v_encoder.w2v_model.decoder.embed_tokens_list.0.weight"][overlap_idxs]
if hubert_step2_states.get("w2v_encoder.w2v_model.text_encoder.embed_tokens.weight") is not None:
hubert_step2_states["w2v_encoder.w2v_model.text_encoder.embed_tokens.weight"] = hubert_step2_states["w2v_encoder.w2v_model.text_encoder.embed_tokens.weight"][overlap_idxs]
else:
logger.info(f"cfg.load_step2_model_from matches the cut embedding dims {len(overlap_idxs)}, no cutting needs to do")
if not self.cfg.load_speech_only and hubert_step2_states.get("w2v_encoder.w2v_model.text_encoder.embed_tokens.weight", None) is None:
hubert_step2_states["w2v_encoder.w2v_model.text_encoder.embed_tokens.weight"] = hubert_step2_states["w2v_encoder.w2v_model.decoder.embed_tokens_list.0.weight"]
try:
self.load_state_dict(hubert_step2_states, strict=True)
except Exception as e:
                logger.warning(e)
self.load_state_dict(hubert_step2_states, strict=False)
def upgrade_state_dict_named(self, state_dict, name):
super().upgrade_state_dict_named(state_dict, name)
return state_dict
@classmethod
def build_model(cls, cfg: HubertSt2tCtcConfig, task: FairseqTask):
"""Build a new model instance."""
w2v_encoder = HubertEncoder(cfg, task.target_dictionary)
return cls(cfg, w2v_encoder)
def get_normalized_probs(self, net_output, log_probs, sample=None):
"""Get normalized probabilities (or log probs) from a net's output."""
if "encoder_out" not in net_output:
return self.w2v_encoder.get_normalized_probs_decoder(net_output, log_probs, sample)
if "encoder_out_for_ctc" in net_output:
logits = net_output["encoder_out_for_ctc"]
else:
logits = net_output["encoder_out"]
if isinstance(logits, list):
logits = logits[0]
if log_probs:
return utils.log_softmax(logits.float(), dim=-1)
else:
return utils.softmax(logits.float(), dim=-1)
    def get_logits(self, net_output):
        logits = net_output["encoder_out"]
        padding = net_output["encoder_padding_mask"]
        if padding is not None and padding.any():
            # encoder_out is T x B x C while the padding mask is B x T, so transpose the
            # mask; use single indexing assignments so the writes happen in place
            # (chained boolean indexing as in `logits[padding][..., 0]` modifies a copy)
            padding = padding.T
            logits[padding, 0] = 0
            logits[padding, 1:] = float("-inf")
        return logits
def forward(self, **kwargs):
x = self.w2v_encoder(**kwargs)
return x
@property
def encoder(self):
return self.w2v_encoder
def reorder_encoder_out(self, encoder_out, new_order):
return self.encoder.reorder_encoder_out(encoder_out, new_order)
@property
def decoder(self):
return self.w2v_encoder.w2v_model.decoder
def set_num_updates(self, num_updates):
"""Set the number of parameters updates."""
super().set_num_updates(num_updates)
self.num_updates = num_updates
class HubertEncoder(FairseqEncoder):
def __init__(self, cfg: HubertAsrConfig, tgt_dict=None):
self.apply_mask = cfg.apply_mask
logger.info(f"self.apply_mask: {self.apply_mask}")
arg_overrides = {
"dropout": cfg.dropout,
"activation_dropout": cfg.activation_dropout,
"dropout_input": cfg.dropout_input,
"attention_dropout": cfg.attention_dropout,
"mask_length": cfg.mask_length,
"mask_prob": cfg.mask_prob,
"mask_selection": cfg.mask_selection,
"mask_other": cfg.mask_other,
"no_mask_overlap": cfg.no_mask_overlap,
"mask_channel_length": cfg.mask_channel_length,
"mask_channel_prob": cfg.mask_channel_prob,
"mask_channel_selection": cfg.mask_channel_selection,
"mask_channel_other": cfg.mask_channel_other,
"no_mask_channel_overlap": cfg.no_mask_channel_overlap,
"encoder_layerdrop": cfg.layerdrop,
"decoder_layerdrop": cfg.decoder_layerdrop,
"feature_grad_mult": cfg.feature_grad_mult,
"decoder_dict_size": len(tgt_dict) if cfg.add_decoder else -1,
"share_decoder_input_output_embed": cfg.share_decoder_input_output_embed,
"load_pretrained_w2v_from": cfg.load_pretrained_w2v_from,
"load_pretrained_mbart_from": None,
"adaptor_stride": cfg.adaptor_stride,
"share_speech_text_embeddings": cfg.share_speech_text_embeddings,
}
if cfg.no_pretrained_weights:
arg_overrides["use_rel_pos_enc"] = cfg.use_rel_pos_enc
arg_overrides["encoder_layers"] = cfg.encoder_layers
arg_overrides["add_text_encoder"] = cfg.add_text_encoder
arg_overrides["share_all_embeddings"] = cfg.share_all_embeddings
arg_overrides["add_adaptor"] = cfg.add_adaptor
if cfg.w2v_args is None:
state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides)
w2v_args = state.get("cfg", None)
if w2v_args is None:
w2v_args = convert_namespace_to_omegaconf(state["args"])
cfg.w2v_args = w2v_args
else:
state = None
w2v_args = cfg.w2v_args
if isinstance(w2v_args, Namespace):
cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)
        ## in speech_text_joint_to_text, data is loaded by soundfile, which returns un-normalized waveforms
        self.need_preprocess = w2v_args.task.normalize
        if self.need_preprocess:
            logger.warning("Input waveforms from src_tokens will be layer-normalized before the forward pass")
        if cfg.normalize != w2v_args.task.normalize:
            logger.warning(
                "Fine-tuning works best when data normalization is the same. "
                "Please check that --normalize is set or unset for "
                "both pre-training and here"
            )
if not "share_speech_text_embeddings" in w2v_args.model:
with open_dict(w2v_args.model):
w2v_args.model.share_speech_text_embedding = cfg.share_speech_text_embeddings
logger.info(f"share_speech_text_embeddings: {(getattr(w2v_args.model, 'share_speech_text_embeddings', False))}")
w2v_args.task.data = cfg.data
w2v_args.task.add_decoder = cfg.add_decoder
assert w2v_args.model._name == "hubert"
task = tasks.setup_task(w2v_args.task)
if state is not None and "task_state" in state:
# This will load the stored "dictionaries" object
task.load_state_dict(state["task_state"])
model = task.build_model(w2v_args.model)
### modify the embed_tokens and output_projection of decoder
if state is not None and not cfg.no_pretrained_weights:
model_states = self.modify_states(state['model'], cfg.retain_dict_path, cfg.reuse_text_emb)
try:
model.load_state_dict(model_states, strict=True)
except Exception as e:
                logger.warning(e)
model.load_state_dict(model_states, strict=False)
### in case we need load mbart
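        # when no pre-trained speech weights are used, the text encoder and decoder can
        # still be initialized from an mBART checkpoint (pruned to the retained dictionary)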
if cfg.no_pretrained_weights and cfg.load_pretrained_mbart_from:
logger.info("Loading mbart ...")
mbart_state = model.load_checkpoint(cfg.load_pretrained_mbart_from)
mbart_state["model"] = self.modify_states(mbart_state["model"], cfg.retain_dict_path, cfg.reuse_text_emb, is_mbart=True)
model.text_encoder = model.load_pretrained_component_from_model(
component=model.text_encoder, state=mbart_state
)
model.decoder = model.load_pretrained_component_from_model(
component=model.decoder, state=mbart_state
)
model.remove_pretraining_modules(step2=(not cfg.load_speech_only))
# model.remove_pretraining_modules()
super().__init__(task.source_dictionary)
d = w2v_args.model.encoder_embed_dim
self.w2v_model = model
self.final_dropout = nn.Dropout(cfg.final_dropout)
self.freeze_finetune_updates = cfg.freeze_finetune_updates
self.freeze_decoder_updates = cfg.freeze_decoder_updates
self.num_updates = 0
self.enc_grad_mult = cfg.enc_grad_mult
def modify_states(self, model_states, retain_dict_path=None, reuse_text_emb=False, is_mbart=False):
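        """Prune decoder / text-encoder embeddings to the retained vocabulary
        (retain_dict_path) and, if reuse_text_emb is set, copy the text (index 1)
        embeddings into the speech-to-text (index 0) decoder slots."""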
if retain_dict_path is not None:
logger.info("Cut embedding to a smaller size according to retain dict")
with open(retain_dict_path, "rb") as fp:
overlap_idxs = pickle.load(fp)
if is_mbart:
model_states["decoder.embed_tokens_list.1.weight"] = model_states["decoder.embed_tokens.weight"][overlap_idxs]
model_states["decoder.output_projection.1.weight"] = model_states["decoder.output_projection.weight"][overlap_idxs]
model_states["decoder.embed_tokens.weight"] = model_states["decoder.embed_tokens.weight"][overlap_idxs]
model_states["decoder.output_projection.weight"] = model_states["decoder.output_projection.weight"][overlap_idxs]
model_states["encoder.embed_tokens.weight"] = model_states["encoder.embed_tokens.weight"][overlap_idxs]
else:
model_states['decoder.output_projection.1.weight'] = model_states['decoder.output_projection.1.weight'][overlap_idxs]
model_states["decoder.embed_tokens_list.1.weight"] = model_states["decoder.embed_tokens_list.1.weight"][overlap_idxs]
model_states["text_encoder.embed_tokens.weight"] = model_states["text_encoder.embed_tokens.weight"][overlap_idxs]
if reuse_text_emb:
logger.info("Loading decoder.embed_tokens_list.0 <-- decoder.embed_tokens_list.1")
model_states["decoder.embed_tokens_list.0.weight"] = model_states["decoder.embed_tokens_list.1.weight"]
model_states["decoder.output_projection.0.weight"] = model_states["decoder.output_projection.1.weight"]
del model_states["decoder.embed_tokens_list.1.weight"]
del model_states["decoder.output_projection.1.weight"]
return model_states
def set_num_updates(self, num_updates):
"""Set the number of parameters updates."""
super().set_num_updates(num_updates)
self.num_updates = num_updates
def forward(self, src_tokens=None, src_lengths=None, src_txt_tokens=None, src_txt_lengths=None, prev_output_tokens=None, tbc=True, **kwargs):
padding_mask = lengths_to_padding_mask(src_lengths)
if self.need_preprocess:
src_tokens = torch.stack([F.layer_norm(wav, wav.shape) for wav in src_tokens])
src_tokens[padding_mask] = 0.0
ft = self.freeze_finetune_updates <= self.num_updates
w2v_args = {
"source": src_tokens,
"padding_mask": padding_mask,
"mask": self.apply_mask and self.training,
"prev_output_tokens": prev_output_tokens,
"ft": ft,
"enc_grad_mult": self.enc_grad_mult,
}
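        # enable the decoder only after freeze_decoder_updates steps have elapsed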
if self.freeze_decoder_updates <= self.num_updates:
self.w2v_model.add_decoder = True
else:
self.w2v_model.add_decoder = False
x, padding_mask, decoder_out = self.w2v_model.extract_features(**w2v_args)
if tbc:
# B x T x C -> T x B x C
x = x.transpose(0, 1)
x = self.final_dropout(x)
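        # if paired text is provided, also run the text branch through the shared model and
        # concatenate its decoder logits with the speech decoder logits along the batch dim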
if src_txt_tokens is not None:
w2v_args_text = {
"src_tokens": src_txt_tokens,
"src_lengths": src_txt_lengths,
"prev_output_tokens": prev_output_tokens,
}
decoder_output_text = self.w2v_model(**w2v_args_text)
decoder_out = (torch.cat([decoder_out[0], decoder_output_text['decoder_out'][0]], dim=0), {"attn_cost": None})
return decoder_out
def get_normalized_probs_decoder(self, net_output, log_probs, sample=None):
# net_output['encoder_out'] is a (B, T, D) tensor
return self.w2v_model.get_normalized_probs(net_output, log_probs, sample)
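    # used by fairseq sequence generators: reorder cached outputs to follow beam reordering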
def reorder_encoder_out(self, encoder_out, new_order):
if encoder_out["encoder_out"] is not None:
if isinstance(encoder_out["encoder_out"], list):
encoder_out["encoder_out"] = (
[] if len(encoder_out["encoder_out"]) == 0
else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
)
else:
encoder_out["encoder_out"] = encoder_out[
"encoder_out"
].index_select(1, new_order)
if encoder_out["encoder_padding_mask"] is not None:
if isinstance(encoder_out["encoder_padding_mask"], list):
encoder_out["encoder_padding_mask"] = (
[] if len(encoder_out["encoder_padding_mask"]) == 0
else [x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]]
)
else:
encoder_out["encoder_padding_mask"] = encoder_out[
"encoder_padding_mask"
].index_select(0, new_order)
if "decoder_out" in encoder_out and encoder_out["decoder_out"] is not None:
if isinstance(encoder_out["decoder_out"], list):
encoder_out["decoder_out"] = (
[] if len(encoder_out["decoder_out"]) == 0
else [x.index_select(0, new_order) for x in encoder_out["decoder_out"]]
)
else:
encoder_out["decoder_out"] = encoder_out[
"decoder_out"
].index_select(0, new_order)
return encoder_out
def forward_torchscript(self, net_input):
"""A TorchScript-compatible version of forward.
Encoders which use additional arguments may want to override
this method for TorchScript compatibility.
"""
padding_mask = lengths_to_padding_mask(net_input["src_lengths"])
src_tokens = net_input["src_tokens"]
if self.need_preprocess:
src_tokens = torch.stack([F.layer_norm(wav, wav.shape) for wav in src_tokens])
src_tokens[padding_mask] = 0.0
_net_input = {
"source": src_tokens,
"padding_mask": padding_mask,
}
encoder_out = self.w2v_model.forward_torchscript(_net_input)
return encoder_out
def max_positions(self):
"""Maximum input length supported by the encoder."""
return None
def upgrade_state_dict_named(self, state_dict, name):
return state_dict
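# Fairseq-style parameter-initialization helpers; they appear to be unused in this file.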
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
nn.init.constant_(m.weight[padding_idx], 0)
return m
def Linear(in_features, out_features, bias=True):
m = nn.Linear(in_features, out_features, bias)
nn.init.xavier_uniform_(m.weight)
if bias:
nn.init.constant_(m.bias, 0.0)
return m