# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import options, utils
from fairseq.modules import (
    AdaptiveSoftmax,
    LayerNorm,
    MultiheadAttention,
    PositionalEmbedding,
)


EncoderOut = namedtuple(
    "TransformerEncoderOut",
    [
        "encoder_out",  # T x B x C
        "encoder_padding_mask",  # B x T
        "encoder_embedding",  # B x T x C
        "encoder_states",  # List[T x B x C]
    ],
)


class TransformerEncoderEmbedding(nn.Module):
    """ Encoder Embedding + Positional Embedding """

    def __init__(self, args, embed_tokens):
        super().__init__()
        self.dropout = args.dropout
        self.max_source_positions = args.max_source_positions
        self.embed_tokens = embed_tokens
        if isinstance(embed_tokens, nn.ModuleList):
            self.padding_idx = embed_tokens[0].padding_idx
            embed_dim = sum(e.embedding_dim for e in embed_tokens)
        else:
            self.padding_idx = embed_tokens.padding_idx
            embed_dim = embed_tokens.embedding_dim
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = (
            PositionalEmbedding(
                args.max_source_positions,
                embed_dim,
                self.padding_idx,
                learned=args.encoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )
        if getattr(args, "layernorm_embedding", False):
            self.layernorm_embedding = LayerNorm(embed_dim)
        else:
            self.layernorm_embedding = None

    def forward(self, input):
        # embed tokens and positions
        src_tokens = input[0]
        prev_output_tokens = input[2]
        if isinstance(self.embed_tokens, nn.ModuleList):
            x_embed_list = []
            for embed_tokens_part in self.embed_tokens:
                x_embed_list.append(embed_tokens_part(src_tokens))

            embedded = torch.cat(x_embed_list, dim=-1)
        else:
            embedded = self.embed_tokens(src_tokens)
        x = embed = self.embed_scale * embedded
        if self.embed_positions is not None:
            x = embed + self.embed_positions(src_tokens)
        if self.layernorm_embedding:
            x = self.layernorm_embedding(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        return (x, encoder_padding_mask, prev_output_tokens)


class TransformerEncoderLayerNorm(nn.Module):
    """
    Layer norm at the end of all encoder layers if
    args.encoder_normalize_before = True
    """

    def __init__(self, args, embed_dim):
        super().__init__()
        if args.encoder_normalize_before:
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def forward(self, input):
        x = input[0]
        encoder_padding_mask = input[1]
        prev_output_tokens = input[2]
        if self.layer_norm:
            x = self.layer_norm(x)
        # keeping track of the incremental_state is not supported yet
        return (x, encoder_padding_mask, prev_output_tokens)


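# Illustrative sketch (not part of the original fairseq module; the function name
# is hypothetical): a tiny shape check for TransformerEncoderEmbedding. The
# pipeline convention is that each module maps a tuple to a tuple; the input here
# is (src_tokens, <second element>, prev_output_tokens). The second element
# (presumably src_lengths in the full model) is never read by this module, so
# None stands in for it. The args namespace only carries the fields that
# __init__ actually reads.
def _example_encoder_embedding_shapes():
    import argparse

    args = argparse.Namespace(
        dropout=0.0,
        max_source_positions=128,
        encoder_learned_pos=False,
        no_token_positional_embeddings=False,
        layernorm_embedding=False,
    )
    embed_tokens = nn.Embedding(100, 16, padding_idx=1)
    module = TransformerEncoderEmbedding(args, embed_tokens)
    src_tokens = torch.tensor([[5, 6, 7, 1], [8, 9, 1, 1]])  # B x T, padding_idx = 1
    x, encoder_padding_mask, _ = module((src_tokens, None, src_tokens))
    assert x.shape == (4, 2, 16)  # T x B x C
    assert encoder_padding_mask.shape == (2, 4)  # B x T

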
class TransformerDecoderEmbedding(nn.Module):
    """ Decoder Embedding + Positional Embedding """

    def __init__(self, args, embed_tokens):
        super().__init__()
        self.dropout = args.dropout
        self.share_input_output_embed = args.share_decoder_input_output_embed
        input_embed_dim = (
            sum(e.embedding_dim for e in embed_tokens)
            if isinstance(embed_tokens, nn.ModuleList)
            else embed_tokens.embedding_dim
        )
        embed_dim = args.decoder_embed_dim
        self.output_embed_dim = args.decoder_output_dim
        padding_idx = (
            embed_tokens[0].padding_idx
            if isinstance(embed_tokens, nn.ModuleList)
            else embed_tokens.padding_idx
        )
        self.max_target_positions = args.max_target_positions
        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)  # todo: try with input_embed_dim

        self.project_in_dim = (
            Linear(input_embed_dim, embed_dim, bias=False)
            if embed_dim != input_embed_dim
            else None
        )

        self.embed_positions = (
            PositionalEmbedding(
                args.max_target_positions,
                embed_dim,
                padding_idx,
                learned=args.decoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )

    def forward(self, input):
        mt_task = False
        if isinstance(input, tuple):
            if len(input) == 3:
                encoder_out = input[0]
                encoder_padding_mask = input[1]
                prev_output_tokens = input[2]
                incremental_state = None  # Hardcoding to avoid passing of None objects
                mt_task = True
            else:
                # HACK for now, need to fix (TODO sidgoyal)
                prev_output_tokens = input[0]  # discard "src_lengths"
                encoder_out = None
                encoder_padding_mask = None
                incremental_state = None
        else:
            prev_output_tokens = input
            encoder_out = None
            encoder_padding_mask = None
            incremental_state = None

        positions = (
            self.embed_positions(
                prev_output_tokens,
                incremental_state=incremental_state,
            )
            if self.embed_positions is not None
            else None
        )

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        if isinstance(self.embed_tokens, nn.ModuleList):
            x_embed_list = []
            for embed_tokens_part in self.embed_tokens:
                x_embed_list.append(embed_tokens_part(prev_output_tokens))

            x = self.embed_scale * torch.cat(x_embed_list, dim=-1)
        else:
            x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        if mt_task:
            return (x, encoder_out, encoder_padding_mask)
        return x


class TransformerDecoderOutputLayer(nn.Module):
    def __init__(self, args, embed_tokens, dictionary):
        super().__init__()
        self.share_input_output_embed = args.share_decoder_input_output_embed
        self.embed_tokens = embed_tokens
        self.output_embed_dim = args.decoder_output_dim
        embed_dim = args.decoder_embed_dim

        self.project_out_dim = (
            Linear(embed_dim, self.output_embed_dim, bias=False)
            if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights
            else None
        )
        self.adaptive_softmax = None
        if args.adaptive_softmax_cutoff is not None:
            assert not isinstance(embed_tokens, nn.ModuleList)
            self.adaptive_softmax = AdaptiveSoftmax(
                len(dictionary),
                self.output_embed_dim,
                options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
                dropout=args.adaptive_softmax_dropout,
                adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
                factor=args.adaptive_softmax_factor,
                tie_proj=args.tie_adaptive_proj,
            )
        elif not self.share_input_output_embed:
            self.embed_tokens = nn.Parameter(
                torch.Tensor(len(dictionary), self.output_embed_dim)
            )
            nn.init.normal_(
                self.embed_tokens, mean=0, std=self.output_embed_dim ** -0.5
            )

        if args.decoder_normalize_before and not getattr(
            args, "no_decoder_final_norm", False
        ):
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def forward(self, input, apply_final_proj=True):
        if isinstance(input, tuple):
            x = input[0]
        else:
            x = input

        if self.layer_norm:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)
        if apply_final_proj:
            x = self.output_layer(x)
        return x

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size."""
        if self.adaptive_softmax is None:
            # project back to size of vocabulary
            if self.share_input_output_embed:
                if isinstance(self.embed_tokens, nn.ModuleList):
                    output = None
                    for i, emb in enumerate(self.embed_tokens):
                        sidx = i * emb.embedding_dim
                        eidx = (i + 1) * emb.embedding_dim
                        if output is None:
                            output = F.linear(features[:, :, sidx:eidx], emb.weight)
                        else:
                            output += F.linear(features[:, :, sidx:eidx], emb.weight)

                    return output
                else:
                    return F.linear(features, self.embed_tokens.weight)
            else:
                return F.linear(features, self.embed_tokens)
        else:
            return features


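# Illustrative sketch (not part of the original fairseq module; the function name
# is hypothetical): with --share-decoder-input-output-embed set,
# TransformerDecoderOutputLayer reuses the decoder input embedding matrix as the
# vocabulary projection, so logits come from a plain F.linear against
# embed_tokens.weight. The args namespace only carries the fields read by
# __init__; `dictionary` is unused on this configuration path, so None is passed.
def _example_shared_output_projection():
    import argparse

    args = argparse.Namespace(
        share_decoder_input_output_embed=True,
        decoder_embed_dim=16,
        decoder_output_dim=16,
        tie_adaptive_weights=False,
        adaptive_softmax_cutoff=None,
        decoder_normalize_before=False,
    )
    embed_tokens = nn.Embedding(50, 16, padding_idx=1)
    out_module = TransformerDecoderOutputLayer(args, embed_tokens, dictionary=None)
    features = torch.zeros(7, 2, 16)  # T x B x C, as produced by the decoder layers
    logits = out_module(features)  # transposed internally to B x T x C, then projected
    assert logits.shape == (2, 7, 50)  # B x T x V

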
class TransformerEncoderLayer(nn.Module):
    """Encoder layer block.

    In the original paper each operation (multi-head attention or FFN) is
    postprocessed with: `dropout -> add residual -> layernorm`. In the
    tensor2tensor code they suggest that learning is more robust when
    preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.encoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
    """

    def __init__(self, args):
        super().__init__()
        self.embed_dim = args.encoder_embed_dim
        self.self_attn = MultiheadAttention(
            self.embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
        )
        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu")
        )
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.encoder_normalize_before
        self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
        self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
        self.final_layer_norm = LayerNorm(self.embed_dim)

    def upgrade_state_dict_named(self, state_dict, name):
        """
        Rename layer norm states from `...layer_norms.0.weight` to
        `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
        `...final_layer_norm.weight`
        """
        layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"}
        for old, new in layer_norm_map.items():
            for m in ("weight", "bias"):
                k = "{}.layer_norms.{}.{}".format(name, old, m)
                if k in state_dict:
                    state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k]
                    del state_dict[k]

    def forward(self, input):
        """
        Args:
            input (Tuple):
                input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
                input[1] (ByteTensor/FloatTensor): encoder padding mask -
                    binary ByteTensor of shape `(batch, src_len)` where padding elements
                    are indicated by ``1``.
                input[2] (LongTensor): previous decoder outputs of shape
                    `(batch, tgt_len)`, for teacher forcing

        Returns:
            output (Tuple):
                output[0] (Tensor): encoded output of shape `(seq_len, batch, embed_dim)`
                output[1] (ByteTensor/FloatTensor): encoder padding mask
                output[2] (LongTensor): previous decoder outputs
        """
        x = input[0]
        encoder_padding_mask = input[1]
        prev_output_tokens = input[2]
        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        x, _ = self.self_attn(
            query=x, key=x, value=x, key_padding_mask=encoder_padding_mask
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        return (x, encoder_padding_mask, prev_output_tokens)

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x


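# Illustrative sketch (not part of the original fairseq module; the function name
# is hypothetical): because TransformerEncoderLayer consumes and returns the same
# (x, encoder_padding_mask, prev_output_tokens) tuple, a stack of layers can be
# chained with nn.Sequential, a form convenient for pipeline-style partitioning.
# The args namespace only carries the fields read by __init__.
def _example_encoder_layer_stack():
    import argparse

    args = argparse.Namespace(
        encoder_embed_dim=16,
        encoder_attention_heads=4,
        encoder_ffn_embed_dim=32,
        attention_dropout=0.0,
        dropout=0.0,
        encoder_normalize_before=False,
    )
    layers = nn.Sequential(TransformerEncoderLayer(args), TransformerEncoderLayer(args))
    x = torch.zeros(5, 2, 16)  # T x B x C
    encoder_padding_mask = torch.zeros(2, 5, dtype=torch.bool)  # B x T, no padding
    prev_output_tokens = torch.zeros(2, 3, dtype=torch.long)
    out, out_mask, _ = layers((x, encoder_padding_mask, prev_output_tokens))
    assert out.shape == x.shape  # tuple contract preserved layer to layer
    assert torch.equal(out_mask, encoder_padding_mask)

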
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=True,
        )
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu")
        )
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.decoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, "encoder_embed_dim", None),
                vdim=getattr(args, "encoder_embed_dim", None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(self, input):
        """
        Args:
            input (Tuple):
                input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
                input[1] (Tensor): encoder output of shape `(src_len, batch, embed_dim)`
                input[2] (ByteTensor/FloatTensor): encoder padding mask -
                    binary ByteTensor of shape `(batch, src_len)` where padding elements
                    are indicated by ``1``.

        Returns:
            output (Tuple):
                output[0] (Tensor): decoded output of shape `(seq_len, batch, embed_dim)`
                output[1] (Tensor): encoder output
                output[2] (ByteTensor/FloatTensor): encoder padding mask
        """
        # Note: incremental state is not yet supported
        mt_task = False
        if isinstance(input, tuple):
            x = input[0]
            encoder_out = input[1]
            encoder_padding_mask = input[2]
            incremental_state = None
            mt_task = True
        else:
            x = input
            encoder_out = None
            encoder_padding_mask = None
            incremental_state = None

        if incremental_state is None:
            self_attn_mask = self.buffered_future_mask(x)
        else:
            self_attn_mask = None

        # TODO: add back prev_self_attn_state, prev_attn_state,
        # self_attn_padding_mask
        prev_self_attn_state = None
        prev_attn_state = None
        self_attn_padding_mask = None

        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)

        if mt_task:
            return (x, encoder_out, encoder_padding_mask)
        return x

    def buffered_future_mask(self, tensor):
        dim = tensor.size(0)
        if (
            not hasattr(self, "_future_mask")
            or self._future_mask is None
            or self._future_mask.device != tensor.device
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
            )
        if self._future_mask.size(0) < dim:
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1
            )
        return self._future_mask[:dim, :dim]

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn


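# Illustrative sketch (not part of the original fairseq module; the function name
# is hypothetical): a look at the causal mask produced by
# TransformerDecoderLayer.buffered_future_mask. Entries strictly above the
# diagonal are -inf, so each target position can attend only to itself and to
# earlier positions. The args namespace only carries the fields read by
# __init__; no_encoder_attn=True keeps the layer self-attention only.
def _example_decoder_future_mask():
    import argparse

    args = argparse.Namespace(
        decoder_embed_dim=16,
        decoder_attention_heads=4,
        decoder_ffn_embed_dim=32,
        attention_dropout=0.0,
        dropout=0.0,
        decoder_normalize_before=False,
    )
    layer = TransformerDecoderLayer(args, no_encoder_attn=True)
    dummy = torch.zeros(4, 2, 16)  # T x B x C
    mask = layer.buffered_future_mask(dummy)  # T x T
    assert mask.shape == (4, 4)
    assert mask[0, 1] == float("-inf")  # future position: blocked
    assert mask[1, 0] == 0  # past position: visible

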
def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
    nn.init.constant_(m.weight[padding_idx], 0)
    return m


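# Illustrative check (not part of the original fairseq module; the function name
# is hypothetical): the Embedding helper re-initializes weights with
# std = embedding_dim ** -0.5 and zeroes the padding_idx row, so padded positions
# contribute nothing before masking.
def _example_embedding_helper():
    emb = Embedding(num_embeddings=10, embedding_dim=8, padding_idx=1)
    assert bool(torch.all(emb.weight[1] == 0))

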
def Linear(in_features, out_features, bias=True):
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    if bias:
        nn.init.constant_(m.bias, 0.0)
    return m


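# Illustrative end-to-end sketch (not part of the original fairseq module; the
# function name is hypothetical): because every decoder-side module above maps a
# tuple to a tuple, a small decoder can be expressed as one nn.Sequential over
# (encoder_out, encoder_padding_mask, prev_output_tokens), which is the form
# convenient for pipeline-style partitioning. The args namespace only carries the
# fields read by the constructors; `dictionary` is unused on this path.
def _example_decoder_pipeline():
    import argparse

    args = argparse.Namespace(
        dropout=0.0,
        attention_dropout=0.0,
        share_decoder_input_output_embed=True,
        decoder_embed_dim=16,
        decoder_output_dim=16,
        decoder_attention_heads=4,
        decoder_ffn_embed_dim=32,
        decoder_learned_pos=False,
        decoder_normalize_before=False,
        no_token_positional_embeddings=False,
        max_target_positions=128,
        tie_adaptive_weights=False,
        adaptive_softmax_cutoff=None,
        encoder_embed_dim=16,
    )
    embed_tokens = nn.Embedding(50, 16, padding_idx=1)
    decoder = nn.Sequential(
        TransformerDecoderEmbedding(args, embed_tokens),
        TransformerDecoderLayer(args),
        TransformerDecoderOutputLayer(args, embed_tokens, dictionary=None),
    )
    encoder_out = torch.zeros(5, 2, 16)  # T_src x B x C
    encoder_padding_mask = torch.zeros(2, 5, dtype=torch.bool)  # B x T_src
    prev_output_tokens = torch.tensor([[2, 5, 6], [2, 8, 9]])  # B x T_tgt
    logits = decoder((encoder_out, encoder_padding_mask, prev_output_tokens))
    assert logits.shape == (2, 3, 50)  # B x T_tgt x V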