from typing import Optional, Tuple

import torch
import torch.nn as nn
from fairseq.modules import (
    FairseqDropout,
    LayerDropModuleList,
    LayerNorm,
    MultiheadAttention,
    PositionalEmbedding,
    TransformerSentenceEncoderLayer,
)
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_


def init_bert_params(module):
    """
    Initialize the weights specific to the BERT Model.
    This overrides the default initializations depending on the specified arguments.

    1. If normal_init_linear_weights is set then the weights of the linear
       layer will be initialized using the normal distribution and the
       bias will be set to zero.
    2. If normal_init_embed_weights is set then the weights of the embedding
       layer will be initialized using the normal distribution.
    3. If normal_init_proj_weights is set then the q/k/v input projection
       weights of MultiheadAttention will be initialized using the
       normal distribution.
    """

    def normal_(data):
        # sample on CPU and copy back so the initialization is identical
        # regardless of which device the parameter lives on
        data.copy_(
            data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
        )

    if isinstance(module, nn.Linear):
        normal_(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        normal_(module.weight.data)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    if isinstance(module, MultiheadAttention):
        normal_(module.q_proj.weight.data)
        normal_(module.k_proj.weight.data)
        normal_(module.v_proj.weight.data)
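
# Usage sketch (illustrative, not an additional fairseq API): the function is
# meant to be passed to ``nn.Module.apply`` so that it recurses over all
# submodules, e.g.
#
#     model = nn.Linear(768, 768)
#     model.apply(init_bert_params)
#
# TransformerSentenceEncoder below does this automatically when constructed
# with ``apply_bert_init=True``.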
|
|
|
|
|
class TransformerSentenceEncoder(nn.Module):
    """
    Implementation of a bi-directional Transformer-based sentence encoder as
    used in BERT/XLM style pre-trained models.

    The encoder first computes token embeddings from the token embedding
    matrix and adds position embeddings (if specified) and segment embeddings
    (if specified). After applying the specified number of
    TransformerEncoderLayers, it outputs all the internal states of the
    encoder as well as the final representation associated with the first
    input token (usually the [CLS] token).

    Input:
        - tokens: B x T matrix representing sentences
        - segment_labels: B x T matrix representing segment label for tokens

    Output:
        - a tuple of the following:
            - a list of internal model states used to compute the
              predictions, where each tensor has shape T x B x C
            - sentence representation associated with the first input token,
              with shape B x C
    """
|
    def __init__(
        self,
        padding_idx: int,
        vocab_size: int,
        num_encoder_layers: int = 6,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        layerdrop: float = 0.0,
        max_seq_len: int = 256,
        num_segments: int = 2,
        use_position_embeddings: bool = True,
        offset_positions_by_padding: bool = True,
        encoder_normalize_before: bool = False,
        apply_bert_init: bool = False,
        activation_fn: str = "relu",
        learned_pos_embedding: bool = True,
        embed_scale: Optional[float] = None,
        freeze_embeddings: bool = False,
        n_trans_layers_to_freeze: int = 0,
        export: bool = False,
        traceable: bool = False,
        q_noise: float = 0.0,
        qn_block_size: int = 8,
    ) -> None:

        super().__init__()
        self.padding_idx = padding_idx
        self.vocab_size = vocab_size
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        self.layerdrop = layerdrop
        self.max_seq_len = max_seq_len
        self.embedding_dim = embedding_dim
        self.num_segments = num_segments
        self.use_position_embeddings = use_position_embeddings
        self.apply_bert_init = apply_bert_init
        self.learned_pos_embedding = learned_pos_embedding
        self.traceable = traceable

        self.embed_tokens = self.build_embedding(
            self.vocab_size, self.embedding_dim, self.padding_idx
        )
        self.embed_scale = embed_scale

        if q_noise > 0:
            self.quant_noise = apply_quant_noise_(
                nn.Linear(self.embedding_dim, self.embedding_dim, bias=False),
                q_noise,
                qn_block_size,
            )
        else:
            self.quant_noise = None

        self.segment_embeddings = (
            nn.Embedding(self.num_segments, self.embedding_dim, padding_idx=None)
            if self.num_segments > 0
            else None
        )

        self.embed_positions = (
            PositionalEmbedding(
                self.max_seq_len,
                self.embedding_dim,
                padding_idx=(self.padding_idx if offset_positions_by_padding else None),
                learned=self.learned_pos_embedding,
            )
            if self.use_position_embeddings
            else None
        )

        if encoder_normalize_before:
            self.emb_layer_norm = LayerNorm(self.embedding_dim, export=export)
        else:
            self.emb_layer_norm = None

        if self.layerdrop > 0.0:
            self.layers = LayerDropModuleList(p=self.layerdrop)
        else:
            self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                self.build_transformer_sentence_encoder_layer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=ffn_embedding_dim,
                    num_attention_heads=num_attention_heads,
                    dropout=self.dropout_module.p,
                    attention_dropout=attention_dropout,
                    activation_dropout=activation_dropout,
                    activation_fn=activation_fn,
                    export=export,
                    q_noise=q_noise,
                    qn_block_size=qn_block_size,
                )
                for _ in range(num_encoder_layers)
            ]
        )
|
        if self.apply_bert_init:
            self.apply(init_bert_params)

        def freeze_module_params(m):
            if m is not None:
                for p in m.parameters():
                    p.requires_grad = False

        if freeze_embeddings:
            freeze_module_params(self.embed_tokens)
            freeze_module_params(self.segment_embeddings)
            freeze_module_params(self.embed_positions)
            freeze_module_params(self.emb_layer_norm)

        for layer in range(n_trans_layers_to_freeze):
            freeze_module_params(self.layers[layer])
|
    def build_embedding(self, vocab_size, embedding_dim, padding_idx):
        return nn.Embedding(vocab_size, embedding_dim, padding_idx)
|
    def build_transformer_sentence_encoder_layer(
        self,
        embedding_dim,
        ffn_embedding_dim,
        num_attention_heads,
        dropout,
        attention_dropout,
        activation_dropout,
        activation_fn,
        export,
        q_noise,
        qn_block_size,
    ):
        return TransformerSentenceEncoderLayer(
            embedding_dim=embedding_dim,
            ffn_embedding_dim=ffn_embedding_dim,
            num_attention_heads=num_attention_heads,
            dropout=dropout,
            attention_dropout=attention_dropout,
            activation_dropout=activation_dropout,
            activation_fn=activation_fn,
            export=export,
            q_noise=q_noise,
            qn_block_size=qn_block_size,
        )
|
    def forward(
        self,
        tokens: torch.Tensor,
        segment_labels: Optional[torch.Tensor] = None,
        last_state_only: bool = False,
        positions: Optional[torch.Tensor] = None,
        token_embeddings: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        is_tpu = tokens.device.type == "xla"

        # compute padding mask. This is needed for multi-head attention
        padding_mask = tokens.eq(self.padding_idx)
        if not self.traceable and not is_tpu and not padding_mask.any():
            padding_mask = None

        if token_embeddings is not None:
            x = token_embeddings
        else:
            x = self.embed_tokens(tokens)

        if self.embed_scale is not None:
            x = x * self.embed_scale

        if self.embed_positions is not None:
            x = x + self.embed_positions(tokens, positions=positions)

        if self.segment_embeddings is not None and segment_labels is not None:
            x = x + self.segment_embeddings(segment_labels)

        if self.quant_noise is not None:
            x = self.quant_noise(x)

        if self.emb_layer_norm is not None:
            x = self.emb_layer_norm(x)

        x = self.dropout_module(x)

        # account for padding while computing the representation
        if padding_mask is not None:
            x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        inner_states = []
        if not last_state_only:
            inner_states.append(x)

        for layer in self.layers:
            x, _ = layer(
                x, self_attn_padding_mask=padding_mask, self_attn_mask=attn_mask
            )
            if not last_state_only:
                inner_states.append(x)

        # sentence representation: final state of the first (e.g. [CLS]) token
        sentence_rep = x[0, :, :]

        if last_state_only:
            inner_states = [x]

        if self.traceable:
            return torch.stack(inner_states), sentence_rep
        else:
            return inner_states, sentence_rep
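

# Quick usage sketch: the vocabulary size, padding index, batch size and
# sequence length below are made-up illustrative values, not fairseq defaults
# or recommended settings.
if __name__ == "__main__":
    encoder = TransformerSentenceEncoder(padding_idx=1, vocab_size=1000)
    tokens = torch.randint(2, 1000, (4, 16))  # B x T batch of token ids, no padding
    inner_states, sentence_rep = encoder(tokens)
    # each inner state is T x B x C; the sentence representation is B x C
    print(len(inner_states), inner_states[-1].shape, sentence_rep.shape)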
|
|