Spaces:
Running
Running
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import TransformerEncoder, TransformerModel
@register_model("gru_transformer")
class GRUTransformerModel(TransformerModel):
    """Transformer model whose encoder runs a BiGRU over the input embeddings.

    Identical to ``TransformerModel`` except that ``build_encoder`` returns a
    ``GRUTransformerEncoder``.
    """

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        """Return a ``GRUTransformerEncoder`` in place of the stock encoder.

        Declared as a classmethod to match its ``cls`` first parameter and the
        base-class method it overrides; without the decorator, calling it on
        the class would mis-bind every argument.
        """
        return GRUTransformerEncoder(args, src_dict, embed_tokens)


class GRUTransformerEncoder(TransformerEncoder):
    """``TransformerEncoder`` that contextualizes its embeddings with a BiGRU.

    Token embeddings (plus positional embeddings, when enabled) are passed
    through a single-layer bidirectional GRU before the embedding layer norm
    and dropout.
    """

    def __init__(self, args, dictionary, embed_tokens):
        embed_dim = embed_tokens.embedding_dim
        # Each GRU direction gets embed_dim // 2 units and the two directions
        # are concatenated. An odd embed_dim would yield embed_dim - 1 output
        # channels, which would not match layers sized for embed_dim, so fail
        # early with a clear message instead of a later shape error.
        if embed_dim % 2 != 0:
            raise ValueError(
                "GRUTransformerEncoder requires an even embedding dimension, "
                "got {}".format(embed_dim)
            )
        super().__init__(args, dictionary, embed_tokens)
        self.emb_ctx = nn.GRU(
            input_size=embed_dim,
            hidden_size=embed_dim // 2,
            num_layers=1,
            bidirectional=True,
        )

    def forward_embedding(
        self, src_tokens, token_embedding: Optional[torch.Tensor] = None
    ):
        """Embed ``src_tokens`` and contextualize the result with the BiGRU.

        Args:
            src_tokens: source token ids; presumably (batch, src_len) — TODO
                confirm against the caller.
            token_embedding: optional precomputed token embeddings, used in
                place of ``self.embed_tokens(src_tokens)``. Accepted so this
                override stays call-compatible with base-class code that
                passes it.

        Returns:
            tuple: ``(x, embed)`` where ``x`` is the GRU-contextualized
            embedding after optional layer norm and dropout, and ``embed`` is
            the scaled token embedding (no positions, no GRU).
        """
        # embed tokens and positions
        if token_embedding is None:
            token_embedding = self.embed_tokens(src_tokens)
        x = embed = self.embed_scale * token_embedding
        if self.embed_positions is not None:
            x = embed + self.embed_positions(src_tokens)

        # Contextualize embeddings. nn.GRU defaults to batch_first=False, so
        # the first two axes are swapped around the call.
        # NOTE(review): the GRU also runs over padding positions (nothing here
        # packs by length), so the backward direction sees pad embeddings —
        # confirm this is intended.
        x = x.transpose(0, 1)
        x = self.dropout_module(x)
        x, _ = self.emb_ctx(x)  # call the module (not .forward) so hooks run
        x = x.transpose(0, 1)

        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)
        x = self.dropout_module(x)
        return x, embed


@register_model_architecture("gru_transformer", "gru_transformer")
def gru_transformer_base_architecture(args):
    """Fill in defaults for the base ``gru_transformer`` architecture.

    Mutates ``args`` in place: every attribute below is set to its default
    only if ``args`` does not already define it.
    """
    args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
    args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
    # Decoder sizes default to the (already resolved) encoder sizes.
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(
        args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
    )
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.activation_fn = getattr(args, "activation_fn", "relu")
    args.dropout = getattr(args, "dropout", 0.1)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.no_cross_attention = getattr(args, "no_cross_attention", False)
    args.cross_self_attention = getattr(args, "cross_self_attention", False)
    args.layer_wise_attention = getattr(args, "layer_wise_attention", False)
    # Output/input dims default to the (already resolved) decoder_embed_dim.
    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.layernorm_embedding = getattr(args, "layernorm_embedding", False)


@register_model_architecture("gru_transformer", "gru_transformer_big")
def gru_transformer_big(args):
    """Fill in defaults for the ``gru_transformer_big`` architecture.

    Uses 1024-dim embeddings, 4096-dim FFNs, 16 attention heads and dropout
    0.3. All remaining defaults come from
    ``gru_transformer_base_architecture``.
    """
    # Set the "big" values first: the base function only fills attributes
    # that are still missing, so these take precedence over its defaults.
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
    args.dropout = getattr(args, "dropout", 0.3)
    gru_transformer_base_architecture(args)