import torch
import torch.nn.functional as F

from fairseq import utils
from fairseq.iterative_refinement_generator import DecoderOut
from fairseq.models import register_model, register_model_architecture
from fairseq.models.nat import FairseqNATDecoder, FairseqNATModel, ensemble_decoder
from fairseq.models.transformer import Embedding
from fairseq.modules.transformer_sentence_encoder import init_bert_params


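# Helper: mean-pool encoder states over the time dimension, ignoring padded
# positions. Assumes `enc_feats` is T x B x C and `src_masks` is B x T with
# True at padding, which is how forward_length() below calls it.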
def _mean_pooling(enc_feats, src_masks):
    if src_masks is None:
        enc_feats = enc_feats.mean(0)
    else:
        src_masks = (~src_masks).transpose(0, 1).type_as(enc_feats)
        enc_feats = (
            (enc_feats / src_masks.sum(0)[None, :, None]) * src_masks[:, :, None]
        ).sum(0)
    return enc_feats


def _argmax(x, dim):
    return (x == x.max(dim, keepdim=True)[0]).type_as(x)


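# "Uniform copy" assignment: target position j in a sentence of length T is
# mapped onto source index round(j * (S - 1) / (T - 1)), spreading the S source
# positions uniformly over the T target positions. Used by
# NATransformerDecoder.forward_copying_source() below.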
def _uniform_assignment(src_lens, trg_lens):
    max_trg_len = trg_lens.max()
    steps = (src_lens.float() - 1) / (trg_lens.float() - 1)

    index_t = utils.new_arange(trg_lens, max_trg_len).float()
    index_t = steps[:, None] * index_t[None, :]
    index_t = torch.round(index_t).long().detach()
    return index_t


@register_model("nonautoregressive_transformer")
class NATransformerModel(FairseqNATModel):
    @property
    def allow_length_beam(self):
        return True

    @staticmethod
    def add_args(parser):
        FairseqNATModel.add_args(parser)

        # length prediction
        parser.add_argument(
            "--src-embedding-copy",
            action="store_true",
            help="copy encoder word embeddings as the initial input of the decoder",
        )
        parser.add_argument(
            "--pred-length-offset",
            action="store_true",
            help="predict the length difference between the target and source sentences",
        )
        parser.add_argument(
            "--sg-length-pred",
            action="store_true",
            help="stop the gradients back-propagated from the length predictor",
        )
        parser.add_argument(
            "--length-loss-factor",
            type=float,
            help="weight on the length prediction loss",
        )

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        decoder = NATransformerDecoder(args, tgt_dict, embed_tokens)
        if getattr(args, "apply_bert_init", False):
            decoder.apply(init_bert_params)
        return decoder

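    # Training forward pass: run the encoder once, then score all target
    # positions in parallel. Returns a dict of outputs/targets/masks that the
    # downstream NAT criterion turns into the word-level and length-level losses.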
    def forward(
        self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs
    ):
        # encoding
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)

        # length prediction
        length_out = self.decoder.forward_length(
            normalize=False, encoder_out=encoder_out
        )
        length_tgt = self.decoder.forward_length_prediction(
            length_out, encoder_out, tgt_tokens
        )

        # decoding
        word_ins_out = self.decoder(
            normalize=False,
            prev_output_tokens=prev_output_tokens,
            encoder_out=encoder_out,
        )

        return {
            "word_ins": {
                "out": word_ins_out,
                "tgt": tgt_tokens,
                "mask": tgt_tokens.ne(self.pad),
                "ls": self.args.label_smoothing,
                "nll_loss": True,
            },
            "length": {
                "out": length_out,
                "tgt": length_tgt,
                "factor": self.decoder.length_loss_factor,
            },
        }

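    # A single refinement step at inference time: re-run the decoder over the
    # current output canvas and greedily re-predict every non-padding position
    # in parallel.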
    def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs):
        step = decoder_out.step
        output_tokens = decoder_out.output_tokens
        output_scores = decoder_out.output_scores
        history = decoder_out.history

        output_masks = output_tokens.ne(self.pad)
        _scores, _tokens = self.decoder(
            normalize=True,
            prev_output_tokens=output_tokens,
            encoder_out=encoder_out,
            step=step,
        ).max(-1)

        output_tokens.masked_scatter_(output_masks, _tokens[output_masks])
        output_scores.masked_scatter_(output_masks, _scores[output_masks])
        if history is not None:
            history.append(output_tokens.clone())

        return decoder_out._replace(
            output_tokens=output_tokens,
            output_scores=output_scores,
            attn=None,
            history=history,
        )

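    # Build the initial decoding canvas: predict a target length per sentence,
    # fill those positions with <unk> placeholders, and set <bos>/<eos> at the
    # boundaries. Scores start at zero.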
    def initialize_output_tokens(self, encoder_out, src_tokens):
        length_tgt = self.decoder.forward_length_prediction(
            self.decoder.forward_length(normalize=True, encoder_out=encoder_out),
            encoder_out=encoder_out,
        )

        max_length = length_tgt.clamp_(min=2).max()
        idx_length = utils.new_arange(src_tokens, max_length)

        initial_output_tokens = src_tokens.new_zeros(
            src_tokens.size(0), max_length
        ).fill_(self.pad)
        initial_output_tokens.masked_fill_(
            idx_length[None, :] < length_tgt[:, None], self.unk
        )
        initial_output_tokens[:, 0] = self.bos
        initial_output_tokens.scatter_(1, length_tgt[:, None] - 1, self.eos)

        initial_output_scores = initial_output_tokens.new_zeros(
            *initial_output_tokens.size()
        ).type_as(encoder_out["encoder_out"][0])

        return DecoderOut(
            output_tokens=initial_output_tokens,
            output_scores=initial_output_scores,
            attn=None,
            step=0,
            max_step=0,
            history=None,
        )

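    # Length-beam decoding: expand each sentence into `beam_size` candidates
    # whose lengths are centred on the originally predicted length, then rebuild
    # the <bos> ... <unk> ... <eos> canvas for every candidate.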
    def regenerate_length_beam(self, decoder_out, beam_size):
        output_tokens = decoder_out.output_tokens
        length_tgt = output_tokens.ne(self.pad).sum(1)
        length_tgt = (
            length_tgt[:, None]
            + utils.new_arange(length_tgt, 1, beam_size)
            - beam_size // 2
        )
        length_tgt = length_tgt.view(-1).clamp_(min=2)
        max_length = length_tgt.max()
        idx_length = utils.new_arange(length_tgt, max_length)

        initial_output_tokens = output_tokens.new_zeros(
            length_tgt.size(0), max_length
        ).fill_(self.pad)
        initial_output_tokens.masked_fill_(
            idx_length[None, :] < length_tgt[:, None], self.unk
        )
        initial_output_tokens[:, 0] = self.bos
        initial_output_tokens.scatter_(1, length_tgt[:, None] - 1, self.eos)

        initial_output_scores = initial_output_tokens.new_zeros(
            *initial_output_tokens.size()
        ).type_as(decoder_out.output_scores)

        return decoder_out._replace(
            output_tokens=initial_output_tokens, output_scores=initial_output_scores
        )


class NATransformerDecoder(FairseqNATDecoder):
    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
        super().__init__(
            args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn
        )
        self.dictionary = dictionary
        self.bos = dictionary.bos()
        self.unk = dictionary.unk()
        self.eos = dictionary.eos()

        self.encoder_embed_dim = args.encoder_embed_dim
        self.sg_length_pred = getattr(args, "sg_length_pred", False)
        self.pred_length_offset = getattr(args, "pred_length_offset", False)
        self.length_loss_factor = getattr(args, "length_loss_factor", 0.1)
        self.src_embedding_copy = getattr(args, "src_embedding_copy", False)
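        # Target length is treated as a 256-way classification problem; this
        # embedding's weight doubles as the output projection that produces the
        # length logits in forward_length() below.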
        self.embed_length = Embedding(256, self.encoder_embed_dim, None)

    @ensemble_decoder
    def forward(self, normalize, encoder_out, prev_output_tokens, step=0, **unused):
        features, _ = self.extract_features(
            prev_output_tokens,
            encoder_out=encoder_out,
            embedding_copy=(step == 0) & self.src_embedding_copy,
        )
        decoder_out = self.output_layer(features)
        return F.log_softmax(decoder_out, -1) if normalize else decoder_out

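    # Length prediction head: mean-pool the encoder states (ignoring padding)
    # and project them onto the 256 length classes with embed_length's weight.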
    @ensemble_decoder
    def forward_length(self, normalize, encoder_out):
        enc_feats = encoder_out["encoder_out"][0]  # T x B x C
        if len(encoder_out["encoder_padding_mask"]) > 0:
            src_masks = encoder_out["encoder_padding_mask"][0]  # B x T
        else:
            src_masks = None
        enc_feats = _mean_pooling(enc_feats, src_masks)
        if self.sg_length_pred:
            enc_feats = enc_feats.detach()
        length_out = F.linear(enc_feats, self.embed_length.weight)
        return F.log_softmax(length_out, -1) if normalize else length_out

    def extract_features(
        self,
        prev_output_tokens,
        encoder_out=None,
        early_exit=None,
        embedding_copy=False,
        **unused
    ):
        """
        Similar to *forward* but only return features.

        Inputs:
            prev_output_tokens: Tensor(B, T)
            encoder_out: a dictionary of hidden states and masks

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs

        Note: the non-autoregressive decoder applies no causal mask, so every
        position attends to all generated tokens.
        """
        # embedding
        if embedding_copy:
            src_embd = encoder_out["encoder_embedding"][0]
            if len(encoder_out["encoder_padding_mask"]) > 0:
                src_mask = encoder_out["encoder_padding_mask"][0]
            else:
                src_mask = None
            src_mask = (
                ~src_mask
                if src_mask is not None
                else prev_output_tokens.new_ones(*src_embd.size()[:2]).bool()
            )

            x, decoder_padding_mask = self.forward_embedding(
                prev_output_tokens,
                self.forward_copying_source(
                    src_embd, src_mask, prev_output_tokens.ne(self.padding_idx)
                ),
            )
        else:
            x, decoder_padding_mask = self.forward_embedding(prev_output_tokens)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        attn = None
        inner_states = [x]

        # decoder layers
        for i, layer in enumerate(self.layers):
            # early exit from the decoder
            if (early_exit is not None) and (i >= early_exit):
                break

            x, attn, _ = layer(
                x,
                encoder_out["encoder_out"][0]
                if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0)
                else None,
                encoder_out["encoder_padding_mask"][0]
                if (
                    encoder_out is not None
                    and len(encoder_out["encoder_padding_mask"]) > 0
                )
                else None,
                self_attn_mask=None,
                self_attn_padding_mask=decoder_padding_mask,
            )
            inner_states.append(x)

        if self.layer_norm:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        return x, {"attn": attn, "inner_states": inner_states}

    def forward_embedding(self, prev_output_tokens, states=None):
        # embed positions
        positions = (
            self.embed_positions(prev_output_tokens)
            if self.embed_positions is not None
            else None
        )

        # embed tokens (or reuse copied source embeddings) and add positions
        if states is None:
            x = self.embed_scale * self.embed_tokens(prev_output_tokens)
            if self.project_in_dim is not None:
                x = self.project_in_dim(x)
        else:
            x = states

        if positions is not None:
            x += positions
        x = self.dropout_module(x)
        decoder_padding_mask = prev_output_tokens.eq(self.padding_idx)
        return x, decoder_padding_mask

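    # Copy source embeddings as decoder inputs: every target position gathers
    # the source embedding at its uniformly assigned index (_uniform_assignment),
    # rather than embedding the placeholder tokens in prev_output_tokens.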
    def forward_copying_source(self, src_embeds, src_masks, tgt_masks):
        length_sources = src_masks.sum(1)
        length_targets = tgt_masks.sum(1)
        mapped_inputs = _uniform_assignment(length_sources, length_targets).masked_fill(
            ~tgt_masks, 0
        )
        copied_embedding = torch.gather(
            src_embeds,
            1,
            mapped_inputs.unsqueeze(-1).expand(
                *mapped_inputs.size(), src_embeds.size(-1)
            ),
        )
        return copied_embedding

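    # Turn length logits into a length target (training) or a predicted length
    # (inference). With --pred-length-offset the 256 classes encode the offset
    # tgt_len - src_len + 128, otherwise the absolute target length; training
    # targets are clamped to [0, 255].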
    def forward_length_prediction(self, length_out, encoder_out, tgt_tokens=None):
        enc_feats = encoder_out["encoder_out"][0]  # T x B x C
        if len(encoder_out["encoder_padding_mask"]) > 0:
            src_masks = encoder_out["encoder_padding_mask"][0]  # B x T
        else:
            src_masks = None
        if self.pred_length_offset:
            if src_masks is None:
                src_lengs = enc_feats.new_ones(enc_feats.size(1)).fill_(
                    enc_feats.size(0)
                )
            else:
                src_lengs = (~src_masks).transpose(0, 1).type_as(enc_feats).sum(0)
            src_lengs = src_lengs.long()

        if tgt_tokens is not None:
            # obtain the length target used for training
            tgt_lengs = tgt_tokens.ne(self.padding_idx).sum(1).long()
            if self.pred_length_offset:
                length_tgt = tgt_lengs - src_lengs + 128
            else:
                length_tgt = tgt_lengs
            length_tgt = length_tgt.clamp(min=0, max=255)
        else:
            # predict the length target at inference time
            pred_lengs = length_out.max(-1)[1]
            if self.pred_length_offset:
                length_tgt = pred_lengs - 128 + src_lengs
            else:
                length_tgt = pred_lengs

        return length_tgt


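# Default hyperparameters: a transformer-base sized model (512-dim embeddings,
# 2048-dim FFN, 6 layers, 8 attention heads) plus the NAT-specific
# length-prediction and embedding-copy defaults.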
@register_model_architecture(
    "nonautoregressive_transformer", "nonautoregressive_transformer"
)
def base_architecture(args):
    args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
    args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(
        args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
    )
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.activation_fn = getattr(args, "activation_fn", "relu")
    args.dropout = getattr(args, "dropout", 0.1)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.apply_bert_init = getattr(args, "apply_bert_init", False)

    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)

    args.sg_length_pred = getattr(args, "sg_length_pred", False)
    args.pred_length_offset = getattr(args, "pred_length_offset", False)
    args.length_loss_factor = getattr(args, "length_loss_factor", 0.1)
    args.src_embedding_copy = getattr(args, "src_embedding_copy", False)


@register_model_architecture(
    "nonautoregressive_transformer", "nonautoregressive_transformer_wmt_en_de"
)
def nonautoregressive_transformer_wmt_en_de(args):
    base_architecture(args)