import torch.nn as nn

from fairseq.modules import TransformerSentenceEncoder
from fairseq.modules.sparse_transformer_sentence_encoder_layer import (
    SparseTransformerSentenceEncoderLayer,
)


class SparseTransformerSentenceEncoder(TransformerSentenceEncoder):
    """
    Sparse implementation of the TransformerSentenceEncoder: each encoder
    layer is a SparseTransformerSentenceEncoderLayer, which uses
    SparseMultiheadAttention instead of dense multi-head attention
    (see SparseMultiheadAttention). A usage sketch is included at the
    bottom of this module.
    """

    def __init__(
        self,
        padding_idx: int,
        vocab_size: int,
        num_encoder_layers: int = 6,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        max_seq_len: int = 256,
        num_segments: int = 2,
        use_position_embeddings: bool = True,
        offset_positions_by_padding: bool = True,
        encoder_normalize_before: bool = False,
        apply_bert_init: bool = False,
        activation_fn: str = "relu",
        learned_pos_embedding: bool = True,
        embed_scale: float = None,
        freeze_embeddings: bool = False,
        n_trans_layers_to_freeze: int = 0,
        export: bool = False,
        is_bidirectional: bool = True,
        stride: int = 32,
        expressivity: int = 8,
    ) -> None:

        # Build the standard (dense) encoder via the parent class; the dense
        # layers it creates are replaced with sparse ones below.
        super().__init__(
            padding_idx,
            vocab_size,
            num_encoder_layers,
            embedding_dim,
            ffn_embedding_dim,
            num_attention_heads,
            dropout,
            attention_dropout,
            activation_dropout,
            max_seq_len,
            num_segments,
            use_position_embeddings,
            offset_positions_by_padding,
            encoder_normalize_before,
            apply_bert_init,
            activation_fn,
            learned_pos_embedding,
            embed_scale,
            freeze_embeddings,
            n_trans_layers_to_freeze,
            export,
        )

        # Replace the dense encoder layers created by the parent class with
        # sparse ones (each uses SparseMultiheadAttention internally).
        self.layers = nn.ModuleList(
            [
                SparseTransformerSentenceEncoderLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=ffn_embedding_dim,
                    num_attention_heads=num_attention_heads,
                    dropout=dropout,
                    attention_dropout=attention_dropout,
                    activation_dropout=activation_dropout,
                    activation_fn=activation_fn,
                    export=export,
                    is_bidirectional=is_bidirectional,
                    stride=stride,
                    expressivity=expressivity,
                )
                for _ in range(num_encoder_layers)
            ]
        )

        def freeze_module_params(m):
            if m is not None:
                for p in m.parameters():
                    p.requires_grad = False

        # The parent __init__ froze the old dense layers; since those layers
        # were just replaced, re-freeze the first n_trans_layers_to_freeze
        # sparse layers here.
        for layer in range(n_trans_layers_to_freeze):
            freeze_module_params(self.layers[layer])
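

# A minimal usage sketch (not part of the original fairseq module; added for
# illustration). It assumes the forward signature inherited from
# TransformerSentenceEncoder, i.e. forward(tokens, segment_labels=None, ...)
# returning (inner_states, sentence_rep), and uses small, arbitrary
# hyperparameters purely for demonstration.
if __name__ == "__main__":
    import torch

    encoder = SparseTransformerSentenceEncoder(
        padding_idx=1,
        vocab_size=1000,
        num_encoder_layers=2,
        embedding_dim=64,
        ffn_embedding_dim=128,
        num_attention_heads=4,
        max_seq_len=64,
        stride=4,        # stride of the sparse attention pattern
        expressivity=2,  # expressivity of the pattern (must be <= stride)
    )

    # Dummy batch of token ids, shape (batch, seq_len); ids start at 2 so
    # none of them collide with padding_idx.
    tokens = torch.randint(2, 1000, (4, 32))
    inner_states, sentence_rep = encoder(tokens)
    print(inner_states[-1].shape, sentence_rep.shape)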