# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import torch.nn as nn
from fairseq.model_parallel.modules import (
    ModelParallelTransformerDecoderLayer,
    ModelParallelTransformerEncoderLayer,
)
from fairseq.models import register_model
from fairseq.models.transformer import (
    TransformerDecoder,
    TransformerEncoder,
    TransformerModel,
)

try:
    from fairseq.model_parallel.megatron.mpu import (
        copy_to_model_parallel_region,
        gather_from_model_parallel_region,
        VocabParallelEmbedding,
    )

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False
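
# The mpu utilities above come from the optional Megatron-LM submodule:
# VocabParallelEmbedding shards the embedding matrix across model-parallel
# ranks along the vocabulary dimension, and the copy_to/gather_from helpers
# implement the matching forward/backward communication between ranks.
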
logger = logging.getLogger(__name__)
@register_model("model_parallel_transformer")
class ModelParallelTransformerModel(TransformerModel):
"""
Model parallel Transformer model.
"""
@classmethod
def build_embedding(cls, args, dictionary, embed_dim, path=None):
if not has_megatron_submodule:
raise ImportError(
"\n\nPlease install the megatron submodule:"
"\n\n git submodule update --init "
"fairseq/model_parallel/megatron"
)
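        # Pad the vocabulary so it splits evenly across model-parallel ranks
        # (each shard a multiple of 8). For example, with --model-parallel-size 2
        # the dictionary is padded to a multiple of 16, so a 49,997-symbol
        # vocabulary grows to 50,000 entries and each rank owns an equal
        # 25,000-row shard of the embedding matrix.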
        dictionary.pad_to_multiple_(args.model_parallel_size * 8)
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()

        def _vocab_init(tensor, **kwargs):
            nn.init.normal_(tensor, mean=0, std=num_embeddings ** -0.5)
            nn.init.constant_(tensor[1], 0)

        emb = VocabParallelEmbedding(
            num_embeddings, embed_dim, padding_idx, init_method=_vocab_init
        )
        # if provided, load from preloaded dictionaries
        if path:
            raise NotImplementedError(
                "Loading of embedding from path is not supported for model parallel"
            )
        return emb

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        return ModelParallelTransformerEncoder(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        return ModelParallelTransformerDecoder(
            args,
            tgt_dict,
            embed_tokens,
            no_encoder_attn=getattr(args, "no_cross_attention", False),
        )


class ModelParallelTransformerEncoder(TransformerEncoder):
    """
    Model parallel Transformer encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`ModelParallelTransformerEncoderLayer`.
    """

    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(args, dictionary, embed_tokens)
        if args.no_final_layer_norm:
            self.layer_norm = None

    def build_encoder_layer(self, args):
        return ModelParallelTransformerEncoderLayer(args)


class ModelParallelTransformerDecoder(TransformerDecoder):
    """
    Model Parallel Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`ModelParallelTransformerDecoderLayer`.
    """

    def build_decoder_layer(self, args, no_encoder_attn=False):
        return ModelParallelTransformerDecoderLayer(args, no_encoder_attn)

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size."""
        if not self.share_input_output_embed:
            raise NotImplementedError(
                "Model parallel training currently requires --share-decoder-input-output-embed"
            )
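
        # copy_to_model_parallel_region is an identity in the forward pass and
        # all-reduces gradients in the backward pass, so each rank can project
        # onto its own shard of the shared (vocab-parallel) output embedding.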
        features = copy_to_model_parallel_region(features)

        # project back to size of vocabulary
        x = self.output_projection(features)
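
        # Unless the vocab-parallel cross-entropy criterion consumes the sharded
        # logits directly, gather the logit shards from all model-parallel ranks
        # so that x covers the full vocabulary.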
        if getattr(self.args, "criterion") != "vocab_parallel_cross_entropy":
            x = gather_from_model_parallel_region(x).contiguous()
        return x