from typing import Callable, Optional

import torch
import torch.nn as nn

from fairseq import utils
from fairseq.modules import LayerNorm, MultiheadAttention
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise


class TransformerSentenceEncoderLayer(nn.Module):
    """
    Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
    models.
    """
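
    # Editor's note: following fairseq conventions, ``forward`` expects inputs of
    # shape (seq_len, batch, embedding_dim); see the usage sketch at the bottom
    # of this file.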
    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        export: bool = False,
        q_noise: float = 0.0,
        qn_block_size: int = 8,
        init_fn: Optional[Callable] = None,
    ) -> None:
        super().__init__()

        if init_fn is not None:
            init_fn()

        self.embedding_dim = embedding_dim
        self.num_attention_heads = num_attention_heads
        self.attention_dropout = attention_dropout
        self.q_noise = q_noise
        self.qn_block_size = qn_block_size

        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        self.activation_dropout_module = FairseqDropout(
            activation_dropout, module_name=self.__class__.__name__
        )

        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.self_attn = self.build_self_attention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            self_attention=True,
            q_noise=q_noise,
            qn_block_size=qn_block_size,
        )

        # LayerNorm applied after the self-attention block (post-LN).
        self.self_attn_layer_norm = LayerNorm(self.embedding_dim, export=export)

        # Position-wise feed-forward network.
        self.fc1 = self.build_fc1(
            self.embedding_dim,
            ffn_embedding_dim,
            q_noise=q_noise,
            qn_block_size=qn_block_size,
        )
        self.fc2 = self.build_fc2(
            ffn_embedding_dim,
            self.embedding_dim,
            q_noise=q_noise,
            qn_block_size=qn_block_size,
        )

        # LayerNorm applied after the feed-forward block (post-LN).
        self.final_layer_norm = LayerNorm(self.embedding_dim, export=export)
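
    # Editor's note: the build_* hooks below let subclasses swap in custom
    # linear / attention modules. ``quant_noise`` is a no-op when q_noise == 0.0;
    # otherwise it applies quantization noise to blocks of ``qn_block_size``
    # weights during training ("Training with Quantization Noise for Extreme
    # Model Compression", Fan et al., 2020).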
    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)

    def build_self_attention(
        self,
        embed_dim,
        num_attention_heads,
        dropout,
        self_attention,
        q_noise,
        qn_block_size,
    ):
        return MultiheadAttention(
            embed_dim,
            num_attention_heads,
            dropout=dropout,
            self_attention=self_attention,
            q_noise=q_noise,
            qn_block_size=qn_block_size,
        )

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
    ):
        """
        LayerNorm is applied after the self-attention and feed-forward
        sub-blocks (post-LN), following the original Transformer implementation.
        """
        # Self-attention block: residual connection followed by LayerNorm.
        residual = x
        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = self.dropout_module(x)
        x = residual + x
        x = self.self_attn_layer_norm(x)

        # Feed-forward block: residual connection followed by LayerNorm.
        residual = x
        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        x = residual + x
        x = self.final_layer_norm(x)
        return x, attn
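

# Minimal usage sketch (editor's addition, not part of the original module).
# Assumes fairseq's conventions: inputs are (seq_len, batch, embedding_dim) and
# the key padding mask is (batch, seq_len) with True marking padded positions.
if __name__ == "__main__":
    layer = TransformerSentenceEncoderLayer(
        embedding_dim=768,
        ffn_embedding_dim=3072,
        num_attention_heads=8,
    )
    seq_len, batch = 16, 2
    x = torch.randn(seq_len, batch, 768)
    padding_mask = torch.zeros(batch, seq_len, dtype=torch.bool)  # no padded tokens
    out, attn = layer(x, self_attn_padding_mask=padding_mask)
    assert out.shape == (seq_len, batch, 768)
    # attn is None here because forward calls self_attn with need_weights=False.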