# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch.nn as nn

from fairseq.model_parallel.modules import (
    ModelParallelTransformerDecoderLayer,
    ModelParallelTransformerEncoderLayer,
)
from fairseq.models import register_model
from fairseq.models.transformer import (
    TransformerDecoder,
    TransformerEncoder,
    TransformerModel,
)


try:
    from fairseq.model_parallel.megatron.mpu import (
        copy_to_model_parallel_region,
        gather_from_model_parallel_region,
        VocabParallelEmbedding,
    )

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False


logger = logging.getLogger(__name__)


@register_model("model_parallel_transformer")
class ModelParallelTransformerModel(TransformerModel):
    """
    Model parallel Transformer model.
    """

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n  git submodule update --init "
                "fairseq/model_parallel/megatron"
            )

        # Pad the vocabulary so it divides evenly across model-parallel ranks.
        dictionary.pad_to_multiple_(args.model_parallel_size * 8)
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()

        def _vocab_init(tensor, **kwargs):
            nn.init.normal_(tensor, mean=0, std=num_embeddings ** -0.5)
            nn.init.constant_(tensor[1], 0)

        # Embedding table sharded over the model-parallel group.
        emb = VocabParallelEmbedding(
            num_embeddings, embed_dim, padding_idx, init_method=_vocab_init
        )
        # if provided, load from preloaded dictionaries
        if path:
            raise NotImplementedError(
                "Loading of embedding from path is not supported for model parallel"
            )
        return emb

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        return ModelParallelTransformerEncoder(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        return ModelParallelTransformerDecoder(
            args,
            tgt_dict,
            embed_tokens,
            no_encoder_attn=getattr(args, "no_cross_attention", False),
        )


class ModelParallelTransformerEncoder(TransformerEncoder):
    """
    Model parallel Transformer encoder consisting of *args.encoder_layers* layers.
    Each layer is a :class:`ModelParallelTransformerEncoderLayer`.
    """

    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(args, dictionary, embed_tokens)

        if args.no_final_layer_norm:
            self.layer_norm = None

    def build_encoder_layer(self, args):
        return ModelParallelTransformerEncoderLayer(args)


class ModelParallelTransformerDecoder(TransformerDecoder):
    """
    Model Parallel Transformer decoder consisting of *args.decoder_layers* layers.
    Each layer is a :class:`ModelParallelTransformerDecoderLayer`.
    """

    def build_decoder_layer(self, args, no_encoder_attn=False):
        return ModelParallelTransformerDecoderLayer(args, no_encoder_attn)

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size."""
        if not self.share_input_output_embed:
            raise NotImplementedError(
                "Model parallel training currently requires "
                "--share-decoder-input-output-embed"
            )

        features = copy_to_model_parallel_region(features)

        # project back to size of vocabulary
        x = self.output_projection(features)

        # When the criterion consumes vocab-parallel logits directly, there is
        # no need to all-gather the shards back to the full vocabulary.
        if getattr(self.args, "criterion") != "vocab_parallel_cross_entropy":
            x = gather_from_model_parallel_region(x).contiguous()
        return x
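

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original fairseq module: it shows (a)
# the arithmetic behind ``dictionary.pad_to_multiple_(args.model_parallel_size
# * 8)`` in ``build_embedding`` and (b) what the gather in ``output_layer``
# conceptually produces, using plain tensors instead of the Megatron
# primitives. All sizes below are hypothetical.
if __name__ == "__main__":
    import torch

    # (a) Vocabulary padding: every model-parallel rank owns an equal slice of
    # the embedding rows, so the padded vocabulary must divide evenly across
    # ranks; the extra factor of 8 keeps each slice a multiple of 8, which
    # tends to be friendlier to fp16 GEMM kernels.
    model_parallel_size = 8
    vocab_size = 50_000
    multiple = model_parallel_size * 8                              # 64
    padded = ((vocab_size + multiple - 1) // multiple) * multiple   # 50_048
    rows_per_rank = padded // model_parallel_size                   # 6_256
    assert padded % multiple == 0 and padded % model_parallel_size == 0

    # (b) Output projection: each rank computes logits only for its vocabulary
    # shard; gather_from_model_parallel_region concatenates the shards along
    # the last dimension to recover full-vocabulary logits. This step is
    # skipped when the criterion is vocab_parallel_cross_entropy, which
    # consumes the shards directly.
    B, T = 2, 7  # batch size, target length (hypothetical)
    shard_logits = [
        torch.randn(B, T, rows_per_rank) for _ in range(model_parallel_size)
    ]
    full_logits = torch.cat(shard_logits, dim=-1)  # what the gather yields
    assert full_logits.shape == (B, T, padded)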