import logging

import torch.nn as nn

from fairseq.model_parallel.modules import (
    ModelParallelTransformerDecoderLayer,
    ModelParallelTransformerEncoderLayer,
)
from fairseq.models import register_model
from fairseq.models.transformer import (
    TransformerDecoder,
    TransformerEncoder,
    TransformerModel,
)


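# The Megatron model parallel utilities (mpu) are only available once the
# megatron submodule has been checked out; if the import fails we record that
# here and raise a descriptive error later in build_embedding().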
try:
    from fairseq.model_parallel.megatron.mpu import (
        copy_to_model_parallel_region,
        gather_from_model_parallel_region,
        VocabParallelEmbedding,
    )

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False


logger = logging.getLogger(__name__)


@register_model("model_parallel_transformer")
class ModelParallelTransformerModel(TransformerModel):
    """
    Model parallel Transformer model.

    Applies Megatron-style tensor model parallelism to the token embeddings and
    Transformer layers; requires the ``fairseq/model_parallel/megatron``
    submodule to be checked out.
    """

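    # A minimal, hypothetical usage sketch (names below are illustrative and
    # assume the Megatron mpu has been initialized and that `task` is a
    # standard fairseq translation task providing the dictionaries):
    #
    #   model = ModelParallelTransformerModel.build_model(args, task)
    #   decoder_out = model(src_tokens, src_lengths, prev_output_tokens)
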
    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n  git submodule update --init "
                "fairseq/model_parallel/megatron"
            )
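        # Pad the vocabulary so its size is a multiple of model_parallel_size * 8:
        # VocabParallelEmbedding splits the embedding rows evenly across model
        # parallel ranks, so the vocab size must divide cleanly (the extra factor
        # of 8 keeps the per-rank shard sizes GPU-friendly).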
        dictionary.pad_to_multiple_(args.model_parallel_size * 8)
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()

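        # Initializer passed to VocabParallelEmbedding: normal init scaled by
        # the vocabulary size, with the embedding at index 1 (fairseq's default
        # padding index) zeroed out.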
        def _vocab_init(tensor, **kwargs):
            nn.init.normal_(tensor, mean=0, std=num_embeddings ** -0.5)
            nn.init.constant_(tensor[1], 0)

        emb = VocabParallelEmbedding(
            num_embeddings, embed_dim, padding_idx, init_method=_vocab_init
        )

        if path:
            raise NotImplementedError(
                "Loading of embedding from path is not supported for model parallel"
            )
        return emb

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        return ModelParallelTransformerEncoder(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        return ModelParallelTransformerDecoder(
            args,
            tgt_dict,
            embed_tokens,
            no_encoder_attn=getattr(args, "no_cross_attention", False),
        )


class ModelParallelTransformerEncoder(TransformerEncoder):
    """
    Model parallel Transformer encoder consisting of *args.encoder_layers* layers.
    Each layer is a :class:`ModelParallelTransformerEncoderLayer`.
    """

    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(args, dictionary, embed_tokens)

        if args.no_final_layer_norm:
            self.layer_norm = None

    def build_encoder_layer(self, args):
        return ModelParallelTransformerEncoderLayer(args)


class ModelParallelTransformerDecoder(TransformerDecoder):
    """
    Model parallel Transformer decoder consisting of *args.decoder_layers* layers.
    Each layer is a :class:`ModelParallelTransformerDecoderLayer`.
    """

    def build_decoder_layer(self, args, no_encoder_attn=False):
        return ModelParallelTransformerDecoderLayer(args, no_encoder_attn)

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size."""
        if not self.share_input_output_embed:
            raise NotImplementedError(
                "Model parallel training currently requires --share-decoder-input-output-embed"
            )

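        # copy_to_model_parallel_region is an identity in the forward pass and
        # all-reduces gradients across model parallel ranks in the backward
        # pass, since every rank's shard of the output projection consumes the
        # same features.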
        features = copy_to_model_parallel_region(features)

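        # Project back to the vocabulary; each rank produces only its own
        # slice of the logits at this point.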
        x = self.output_projection(features)

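        # vocab_parallel_cross_entropy consumes the partitioned logits directly;
        # for any other criterion, gather the full vocabulary dimension back
        # onto every rank.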
        if getattr(self.args, "criterion") != "vocab_parallel_cross_entropy":
            x = gather_from_model_parallel_region(x).contiguous()
        return x