HuBERT / fairseq /models /nat /nonautoregressive_transformer.py
aliabd
full working demo
d5175d3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.iterative_refinement_generator import DecoderOut
from fairseq.models import register_model, register_model_architecture
from fairseq.models.nat import FairseqNATDecoder, FairseqNATModel, ensemble_decoder
from fairseq.models.transformer import Embedding
from fairseq.modules.transformer_sentence_encoder import init_bert_params
def _mean_pooling(enc_feats, src_masks):
# enc_feats: T x B x C
# src_masks: B x T or None
if src_masks is None:
enc_feats = enc_feats.mean(0)
else:
src_masks = (~src_masks).transpose(0, 1).type_as(enc_feats)
enc_feats = (
(enc_feats / src_masks.sum(0)[None, :, None]) * src_masks[:, :, None]
).sum(0)
return enc_feats
def _argmax(x, dim):
return (x == x.max(dim, keepdim=True)[0]).type_as(x)
def _uniform_assignment(src_lens, trg_lens):
max_trg_len = trg_lens.max()
steps = (src_lens.float() - 1) / (trg_lens.float() - 1) # step-size
# max_trg_len
index_t = utils.new_arange(trg_lens, max_trg_len).float()
index_t = steps[:, None] * index_t[None, :] # batch_size X max_trg_len
index_t = torch.round(index_t).long().detach()
return index_t
@register_model("nonautoregressive_transformer")
class NATransformerModel(FairseqNATModel):
@property
def allow_length_beam(self):
return True
@staticmethod
def add_args(parser):
FairseqNATModel.add_args(parser)
# length prediction
parser.add_argument(
"--src-embedding-copy",
action="store_true",
help="copy encoder word embeddings as the initial input of the decoder",
)
parser.add_argument(
"--pred-length-offset",
action="store_true",
help="predicting the length difference between the target and source sentences",
)
parser.add_argument(
"--sg-length-pred",
action="store_true",
help="stop the gradients back-propagated from the length predictor",
)
parser.add_argument(
"--length-loss-factor",
type=float,
help="weights on the length prediction loss",
)
@classmethod
def build_decoder(cls, args, tgt_dict, embed_tokens):
decoder = NATransformerDecoder(args, tgt_dict, embed_tokens)
if getattr(args, "apply_bert_init", False):
decoder.apply(init_bert_params)
return decoder
def forward(
self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs
):
# encoding
encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
# length prediction
length_out = self.decoder.forward_length(
normalize=False, encoder_out=encoder_out
)
length_tgt = self.decoder.forward_length_prediction(
length_out, encoder_out, tgt_tokens
)
# decoding
word_ins_out = self.decoder(
normalize=False,
prev_output_tokens=prev_output_tokens,
encoder_out=encoder_out,
)
return {
"word_ins": {
"out": word_ins_out,
"tgt": tgt_tokens,
"mask": tgt_tokens.ne(self.pad),
"ls": self.args.label_smoothing,
"nll_loss": True,
},
"length": {
"out": length_out,
"tgt": length_tgt,
"factor": self.decoder.length_loss_factor,
},
}
def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs):
step = decoder_out.step
output_tokens = decoder_out.output_tokens
output_scores = decoder_out.output_scores
history = decoder_out.history
# execute the decoder
output_masks = output_tokens.ne(self.pad)
_scores, _tokens = self.decoder(
normalize=True,
prev_output_tokens=output_tokens,
encoder_out=encoder_out,
step=step,
).max(-1)
output_tokens.masked_scatter_(output_masks, _tokens[output_masks])
output_scores.masked_scatter_(output_masks, _scores[output_masks])
if history is not None:
history.append(output_tokens.clone())
return decoder_out._replace(
output_tokens=output_tokens,
output_scores=output_scores,
attn=None,
history=history,
)
def initialize_output_tokens(self, encoder_out, src_tokens):
# length prediction
length_tgt = self.decoder.forward_length_prediction(
self.decoder.forward_length(normalize=True, encoder_out=encoder_out),
encoder_out=encoder_out,
)
max_length = length_tgt.clamp_(min=2).max()
idx_length = utils.new_arange(src_tokens, max_length)
initial_output_tokens = src_tokens.new_zeros(
src_tokens.size(0), max_length
).fill_(self.pad)
initial_output_tokens.masked_fill_(
idx_length[None, :] < length_tgt[:, None], self.unk
)
initial_output_tokens[:, 0] = self.bos
initial_output_tokens.scatter_(1, length_tgt[:, None] - 1, self.eos)
initial_output_scores = initial_output_tokens.new_zeros(
*initial_output_tokens.size()
).type_as(encoder_out["encoder_out"][0])
return DecoderOut(
output_tokens=initial_output_tokens,
output_scores=initial_output_scores,
attn=None,
step=0,
max_step=0,
history=None,
)
def regenerate_length_beam(self, decoder_out, beam_size):
output_tokens = decoder_out.output_tokens
length_tgt = output_tokens.ne(self.pad).sum(1)
length_tgt = (
length_tgt[:, None]
+ utils.new_arange(length_tgt, 1, beam_size)
- beam_size // 2
)
length_tgt = length_tgt.view(-1).clamp_(min=2)
max_length = length_tgt.max()
idx_length = utils.new_arange(length_tgt, max_length)
initial_output_tokens = output_tokens.new_zeros(
length_tgt.size(0), max_length
).fill_(self.pad)
initial_output_tokens.masked_fill_(
idx_length[None, :] < length_tgt[:, None], self.unk
)
initial_output_tokens[:, 0] = self.bos
initial_output_tokens.scatter_(1, length_tgt[:, None] - 1, self.eos)
initial_output_scores = initial_output_tokens.new_zeros(
*initial_output_tokens.size()
).type_as(decoder_out.output_scores)
return decoder_out._replace(
output_tokens=initial_output_tokens, output_scores=initial_output_scores
)
class NATransformerDecoder(FairseqNATDecoder):
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
super().__init__(
args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn
)
self.dictionary = dictionary
self.bos = dictionary.bos()
self.unk = dictionary.unk()
self.eos = dictionary.eos()
self.encoder_embed_dim = args.encoder_embed_dim
self.sg_length_pred = getattr(args, "sg_length_pred", False)
self.pred_length_offset = getattr(args, "pred_length_offset", False)
self.length_loss_factor = getattr(args, "length_loss_factor", 0.1)
self.src_embedding_copy = getattr(args, "src_embedding_copy", False)
self.embed_length = Embedding(256, self.encoder_embed_dim, None)
@ensemble_decoder
def forward(self, normalize, encoder_out, prev_output_tokens, step=0, **unused):
features, _ = self.extract_features(
prev_output_tokens,
encoder_out=encoder_out,
embedding_copy=(step == 0) & self.src_embedding_copy,
)
decoder_out = self.output_layer(features)
return F.log_softmax(decoder_out, -1) if normalize else decoder_out
@ensemble_decoder
def forward_length(self, normalize, encoder_out):
enc_feats = encoder_out["encoder_out"][0] # T x B x C
if len(encoder_out["encoder_padding_mask"]) > 0:
src_masks = encoder_out["encoder_padding_mask"][0] # B x T
else:
src_masks = None
enc_feats = _mean_pooling(enc_feats, src_masks)
if self.sg_length_pred:
enc_feats = enc_feats.detach()
length_out = F.linear(enc_feats, self.embed_length.weight)
return F.log_softmax(length_out, -1) if normalize else length_out
def extract_features(
self,
prev_output_tokens,
encoder_out=None,
early_exit=None,
embedding_copy=False,
**unused
):
"""
Similar to *forward* but only return features.
Inputs:
prev_output_tokens: Tensor(B, T)
encoder_out: a dictionary of hidden states and masks
Returns:
tuple:
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- a dictionary with any model-specific outputs
the LevenshteinTransformer decoder has full-attention to all generated tokens
"""
# embedding
if embedding_copy:
src_embd = encoder_out["encoder_embedding"][0]
if len(encoder_out["encoder_padding_mask"]) > 0:
src_mask = encoder_out["encoder_padding_mask"][0]
else:
src_mask = None
src_mask = (
~src_mask
if src_mask is not None
else prev_output_tokens.new_ones(*src_embd.size()[:2]).bool()
)
x, decoder_padding_mask = self.forward_embedding(
prev_output_tokens,
self.forward_copying_source(
src_embd, src_mask, prev_output_tokens.ne(self.padding_idx)
),
)
else:
x, decoder_padding_mask = self.forward_embedding(prev_output_tokens)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
attn = None
inner_states = [x]
# decoder layers
for i, layer in enumerate(self.layers):
# early exit from the decoder.
if (early_exit is not None) and (i >= early_exit):
break
x, attn, _ = layer(
x,
encoder_out["encoder_out"][0]
if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0)
else None,
encoder_out["encoder_padding_mask"][0]
if (
encoder_out is not None
and len(encoder_out["encoder_padding_mask"]) > 0
)
else None,
self_attn_mask=None,
self_attn_padding_mask=decoder_padding_mask,
)
inner_states.append(x)
if self.layer_norm:
x = self.layer_norm(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
if self.project_out_dim is not None:
x = self.project_out_dim(x)
return x, {"attn": attn, "inner_states": inner_states}
def forward_embedding(self, prev_output_tokens, states=None):
# embed positions
positions = (
self.embed_positions(prev_output_tokens)
if self.embed_positions is not None
else None
)
# embed tokens and positions
if states is None:
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
if self.project_in_dim is not None:
x = self.project_in_dim(x)
else:
x = states
if positions is not None:
x += positions
x = self.dropout_module(x)
decoder_padding_mask = prev_output_tokens.eq(self.padding_idx)
return x, decoder_padding_mask
def forward_copying_source(self, src_embeds, src_masks, tgt_masks):
length_sources = src_masks.sum(1)
length_targets = tgt_masks.sum(1)
mapped_inputs = _uniform_assignment(length_sources, length_targets).masked_fill(
~tgt_masks, 0
)
copied_embedding = torch.gather(
src_embeds,
1,
mapped_inputs.unsqueeze(-1).expand(
*mapped_inputs.size(), src_embeds.size(-1)
),
)
return copied_embedding
def forward_length_prediction(self, length_out, encoder_out, tgt_tokens=None):
enc_feats = encoder_out["encoder_out"][0] # T x B x C
if len(encoder_out["encoder_padding_mask"]) > 0:
src_masks = encoder_out["encoder_padding_mask"][0] # B x T
else:
src_masks = None
if self.pred_length_offset:
if src_masks is None:
src_lengs = enc_feats.new_ones(enc_feats.size(1)).fill_(
enc_feats.size(0)
)
else:
src_lengs = (~src_masks).transpose(0, 1).type_as(enc_feats).sum(0)
src_lengs = src_lengs.long()
if tgt_tokens is not None:
# obtain the length target
tgt_lengs = tgt_tokens.ne(self.padding_idx).sum(1).long()
if self.pred_length_offset:
length_tgt = tgt_lengs - src_lengs + 128
else:
length_tgt = tgt_lengs
length_tgt = length_tgt.clamp(min=0, max=255)
else:
# predict the length target (greedy for now)
# TODO: implementing length-beam
pred_lengs = length_out.max(-1)[1]
if self.pred_length_offset:
length_tgt = pred_lengs - 128 + src_lengs
else:
length_tgt = pred_lengs
return length_tgt
@register_model_architecture(
"nonautoregressive_transformer", "nonautoregressive_transformer"
)
def base_architecture(args):
args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_layers = getattr(args, "encoder_layers", 6)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
args.decoder_ffn_embed_dim = getattr(
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
)
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
args.attention_dropout = getattr(args, "attention_dropout", 0.0)
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
args.activation_fn = getattr(args, "activation_fn", "relu")
args.dropout = getattr(args, "dropout", 0.1)
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.adaptive_input = getattr(args, "adaptive_input", False)
args.apply_bert_init = getattr(args, "apply_bert_init", False)
args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim
)
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
# --- special arguments ---
args.sg_length_pred = getattr(args, "sg_length_pred", False)
args.pred_length_offset = getattr(args, "pred_length_offset", False)
args.length_loss_factor = getattr(args, "length_loss_factor", 0.1)
args.src_embedding_copy = getattr(args, "src_embedding_copy", False)
@register_model_architecture(
"nonautoregressive_transformer", "nonautoregressive_transformer_wmt_en_de"
)
def nonautoregressive_transformer_wmt_en_de(args):
base_architecture(args)