# --------------------------------------------------------
# The YiTrans End-to-End Speech Translation System for IWSLT 2022 Offline Shared Task (https://arxiv.org/abs/2206.05777)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/YiTrans
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------

"""
Modified from
https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/transformer_layer.py
https://github.com/microsoft/SpeechT5/blob/main/Speech2C/speech2c/models/modules/transformer_decoder_layer.py
"""

from typing import Dict, List, Optional

import torch
from torch import Tensor
from fairseq.modules.transformer_layer import TransformerDecoderLayerBase as FairseqTransformerDecoderLayerBase
from fairseq.modules import LayerNorm

from yitrans_iwslt22.modules.multihead_attention import MultiheadAttention


class TransformerDecoderLayerBase(FairseqTransformerDecoderLayerBase):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *cfg.decoder.normalize_before* to ``True``.

    Compared with the fairseq base layer, this variant builds its
    self-attention from the project's own ``MultiheadAttention`` and accepts an
    optional ``pos_bias`` in :meth:`forward`, which is layer-normalized and
    handed to the self-attention as ``position_bias``.

    Args:
        cfg: model configuration; this class reads
            ``cfg.decoder.attention_heads``, ``cfg.attention_dropout`` and
            ``cfg.cross_self_attention`` (the base class reads more).
        no_encoder_attn (bool, optional): forwarded unchanged to the fairseq
            base layer (default: False).
            NOTE(review): presumably ``True`` disables the encoder-attention
            sub-layer -- confirm against the fairseq base class.
        add_bias_kv (bool, optional): forwarded to the fairseq base layer
            (default: False).
        add_zero_attn (bool, optional): forwarded to the fairseq base layer
            (default: False).
        has_relative_attention_bias (bool, optional): when ``True`` a
            ``LayerNorm`` of size ``embed_dim // cfg.decoder.attention_heads``
            is created to normalize ``pos_bias`` in :meth:`forward`
            (default: False).
    """

    def __init__(
        self, cfg, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False, has_relative_attention_bias=False
    ):
        super().__init__(
            cfg,
            no_encoder_attn,
            add_bias_kv,
            add_zero_attn,
        )

        # Always define ``norm_k`` so that ``forward`` can test for it
        # explicitly instead of failing with an AttributeError when a
        # ``pos_bias`` reaches a layer built without relative attention bias.
        if has_relative_attention_bias:
            self.norm_k = LayerNorm(self.embed_dim // cfg.decoder.attention_heads)
        else:
            self.norm_k = None

    def build_self_attention(
        self, embed_dim, cfg, add_bias_kv=False, add_zero_attn=False
    ):
        """Build the self-attention module using the project's ``MultiheadAttention``.

        Args:
            embed_dim: embedding dimension passed to the attention module.
            cfg: configuration providing ``decoder.attention_heads``,
                ``attention_dropout`` and ``cross_self_attention``.
            add_bias_kv (bool, optional): forwarded to the attention module.
            add_zero_attn (bool, optional): forwarded to the attention module.

        Returns:
            A ``MultiheadAttention`` instance; it is flagged as pure
            self-attention only when ``cfg.cross_self_attention`` is false.
        """
        return MultiheadAttention(
            embed_dim,
            cfg.decoder.attention_heads,
            dropout=cfg.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not cfg.cross_self_attention,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
        )

    def forward(
        self,
        x,
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
        pos_bias=None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_out (Tensor, optional): encoder output used as key/value
                for encoder attention (and prepended to the self-attention
                keys when cross-self-attention is active).
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            incremental_state (dict, optional): incremental decoding state
                passed on to both attention modules.
            prev_self_attn_state (list, optional): ``[prev_key, prev_value]``
                and optionally a third ``prev_key_padding_mask`` entry, written
                into the self-attention input buffer before attending.
            prev_attn_state (list, optional): same layout as
                ``prev_self_attn_state`` but for the encoder attention.
            self_attn_mask (Tensor, optional): attention mask for
                self-attention.
            self_attn_padding_mask (Tensor, optional): key padding mask for
                self-attention.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).
            pos_bias (Tensor, optional): passed through ``self.norm_k`` and
                then to the self-attention as ``position_bias``.
                NOTE(review): presumably relative-position embeddings whose
                last dimension is the per-head size -- confirm with the caller.

        Returns:
            a tuple ``(x, attn, self_attn_state)`` where ``x`` is the encoded
            output of shape `(seq_len, batch, embed_dim)`, ``attn`` is the
            second value returned by the last attention module that ran, and
            ``self_attn_state`` is the cached self-attention state when
            ``self.onnx_trace`` is set and ``incremental_state`` is given,
            otherwise ``None``.

        Raises:
            ValueError: if ``pos_bias`` is given but the layer was constructed
                with ``has_relative_attention_bias=False``.
        """
        # Per-head weights can only be returned if weights are computed at all.
        if need_head_weights:
            need_attn = True

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if pos_bias is not None:
            if self.norm_k is None:
                raise ValueError(
                    "pos_bias was provided, but this decoder layer was built with "
                    "has_relative_attention_bias=False and has no norm_k module"
                )
            pos_bias = self.norm_k(pos_bias)
        if prev_self_attn_state is not None:
            # Seed the self-attention cache with externally supplied state.
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
        # With cross-self-attention the encoder output is prepended to the
        # self-attention keys/values, unless keys are already cached in the
        # incremental state.
        if self.cross_self_attention and not (
            incremental_state is not None
            and _self_attn_input_buffer is not None
            and "prev_key" in _self_attn_input_buffer
        ):
            if self_attn_mask is not None:
                assert encoder_out is not None
                # Zero (i.e. unmasked) columns for the prepended encoder positions.
                self_attn_mask = torch.cat(
                    (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
                )
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    assert encoder_out is not None
                    # No encoder padding information: treat nothing as padding.
                    encoder_padding_mask = self_attn_padding_mask.new_zeros(
                        encoder_out.size(1), encoder_out.size(0)
                    )
                self_attn_padding_mask = torch.cat(
                    (encoder_padding_mask, self_attn_padding_mask), dim=1
                )
            assert encoder_out is not None
            y = torch.cat((encoder_out, x), dim=0)
        else:
            y = x

        x, attn = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
            position_bias=pos_bias,
        )
        if self.c_attn is not None:
            # Scale every attention head by its own entry of ``c_attn``.
            tgt_len, bsz = x.size(0), x.size(1)
            x = x.view(tgt_len, bsz, self.nh, self.head_dim)
            x = torch.einsum("tbhd,h->tbhd", x, self.c_attn)
            x = x.reshape(tgt_len, bsz, self.embed_dim)
        if self.attn_ln is not None:
            x = self.attn_ln(x)
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        if self.encoder_attn is not None and encoder_out is not None:
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)
            if prev_attn_state is not None:
                # Seed the encoder-attention cache with externally supplied state.
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = self.dropout_module(x)
            x = self.residual_connection(x, residual)
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

        # Position-wise feed-forward sub-layer.
        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)

        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        if self.ffn_layernorm is not None:
            x = self.ffn_layernorm(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        if self.w_resid is not None:
            # Optional scaling of the residual branch.
            residual = torch.mul(self.w_resid, residual)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        if self.onnx_trace and incremental_state is not None:
            # When tracing for ONNX, hand the cached self-attention state back
            # to the caller explicitly.
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            assert saved_state is not None
            if self_attn_padding_mask is not None:
                self_attn_state = [
                    saved_state["prev_key"],
                    saved_state["prev_value"],
                    saved_state["prev_key_padding_mask"],
                ]
            else:
                self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
            return x, attn, self_attn_state
        return x, attn, None

    def make_generation_fast_(self, need_attn: bool = False, **kwargs):
        """Record whether attention weights should be produced at inference time.

        The stored flag is read in :meth:`forward`, where encoder-attention
        weights are requested when the module is not training and
        ``self.need_attn`` is true. Extra keyword arguments are ignored.
        """
        self.need_attn = need_attn