# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import utils
from fairseq.model_parallel.models.pipeline_parallel_transformer.layers import (
    Embedding,
    TransformerDecoderEmbedding,
    TransformerDecoderLayer,
    TransformerDecoderOutputLayer,
    TransformerEncoderEmbedding,
    TransformerEncoderLayer,
    TransformerEncoderLayerNorm,
)
from fairseq.models import (
    BaseFairseqModel,
    FairseqDecoder,
    FairseqEncoder,
    register_model,
    register_model_architecture,
)
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.models.transformer import (
    base_architecture,
    transformer_iwslt_de_en,
    transformer_wmt_en_de_big,
)
from fairseq.modules import SinusoidalPositionalEmbedding


logger = logging.getLogger(__name__)


DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
TORCH_PIPE = False
RPC_INIT = False


def import_pipe():
    global TORCH_PIPE
    global RPC_INIT
    try:
        from torch.distributed.pipeline.sync import Pipe  # noqa

        global Pipe
        from torch.distributed.pipeline.sync.utils import partition_model

        global partition_model
        from torch.distributed import rpc
        import tempfile

        TORCH_PIPE = True
        # Initialize single process RPC agent since TORCH_PIPE requires
        # RRef. RRef depends on RPC being initialized and as a result we initialize
        # RPC with a single node.
        tmpfile = tempfile.NamedTemporaryFile()
        if not RPC_INIT:
            rpc.init_rpc(
                name="worker",
                rank=0,
                world_size=1,
                rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
                    init_method="file://{}".format(tmpfile.name),
                ),
            )
            RPC_INIT = True
        logger.info('Using torch pipe')
    except ImportError:
        try:
            from fairscale.nn import Pipe  # noqa

            logger.info('Using fairscale pipe')
        except ImportError:
            raise ImportError("Please install fairscale with: pip install fairscale")


@register_model("pipeline_parallel_transformer")
class PipelineParallelTransformerModel(BaseFairseqModel):
    def __init__(self, encoder, decoder, balance, devices, chunks, checkpoint):
        import_pipe()
        super().__init__()
        assert isinstance(encoder, FairseqEncoder)
        assert isinstance(decoder, FairseqDecoder)
        encoder_module_list = (
            [encoder.embedding_layer]
            + list(encoder.encoder_layers)
            + [encoder.final_layer_norm]
        )
        self.num_encoder_modules = len(encoder_module_list)
        decoder_module_list = (
            [decoder.embedding_layer]
            + list(decoder.decoder_layers)
            + [decoder.decoder_output_layer]
        )
        self.num_decoder_modules = len(decoder_module_list)
        module_list = encoder_module_list + decoder_module_list
        self.devices = devices
        if TORCH_PIPE:
            self.model = Pipe(
                partition_model(nn.Sequential(*module_list), balance, devices),
                chunks=chunks,
                checkpoint=checkpoint,
            )
        else:
            self.model = Pipe(
                nn.Sequential(*module_list),
                balance=balance,
                devices=devices,
                chunks=chunks,
                checkpoint=checkpoint,
            )
        self.encoder_max_positions = self.max_positions_helper(
            encoder.embedding_layer, "max_source_positions"
        )
        self.decoder_max_positions = self.max_positions_helper(
            decoder.embedding_layer, "max_target_positions"
        )
        self.adaptive_softmax = getattr(decoder, "adaptive_softmax", None)
        # Note: To be populated during inference
        self.encoder = None
        self.decoder = None

    def forward(self, src_tokens, src_lengths, prev_output_tokens):
        if self.training:
            input_lst = [src_tokens, src_lengths, prev_output_tokens]
            input = tuple(i.to(self.devices[0], non_blocking=True) for i in input_lst)
            if TORCH_PIPE:
                return self.model(input).local_value()
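            # torch.distributed's Pipe (above) returns an RRef, which is unwrapped
            # with local_value(); fairscale's Pipe (below) returns the output directly.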
            else:
                return self.model(input)
        else:
            assert self.encoder is not None and self.decoder is not None, (
                "encoder and decoder need to be initialized by "
                + "calling the `prepare_for_inference_()` method"
            )
            encoder_output_tuple = self.encoder(src_tokens, src_lengths)
            return self.decoder(prev_output_tokens, encoder_output_tuple)

    def prepare_for_inference_(self, cfg):
        if self.encoder is not None and self.decoder is not None:
            logger.info("Encoder and Decoder already initialized")
            return
        encoder_module_list = []
        decoder_module_list = []
        module_count = 0
        for partition in self.model.partitions:
            for module in partition:
                if module_count < self.num_encoder_modules:
                    encoder_module_list.append(module)
                else:
                    decoder_module_list.append(module)
                module_count += 1
        self.model = None
        self.encoder = TransformerEncoder(
            cfg.distributed_training, None, None, encoder_module_list
        )
        self.decoder = TransformerDecoder(
            cfg.distributed_training, None, None, decoder_module_list=decoder_module_list
        )

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--activation-fn',
                            choices=utils.get_available_activation_fns(),
                            help='activation function to use')
        parser.add_argument('--dropout', type=float, metavar='D',
                            help='dropout probability')
        parser.add_argument('--attention-dropout', type=float, metavar='D',
                            help='dropout probability for attention weights')
        parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
                            help='dropout probability after activation in FFN.')
        parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained encoder embedding')
        parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension')
        parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension for FFN')
        parser.add_argument('--encoder-layers', type=int, metavar='N',
                            help='num encoder layers')
        parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
                            help='num encoder attention heads')
        parser.add_argument('--encoder-normalize-before', action='store_true',
                            help='apply layernorm before each encoder block')
        parser.add_argument('--encoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the encoder')
        parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained decoder embedding')
        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension')
        parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension for FFN')
        parser.add_argument('--decoder-layers', type=int, metavar='N',
                            help='num decoder layers')
        parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
                            help='num decoder attention heads')
        parser.add_argument('--decoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the decoder')
        parser.add_argument('--decoder-normalize-before', action='store_true',
                            help='apply layernorm before each decoder block')
        parser.add_argument('--share-decoder-input-output-embed', action='store_true',
                            help='share decoder input and output embeddings')
        parser.add_argument('--share-all-embeddings', action='store_true',
                            help='share encoder, decoder and output embeddings'
                                 ' (requires shared dictionary and embed dim)')
        parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
                            help='if set, disables positional embeddings (outside self attention)')
        parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                            help='comma separated list of adaptive softmax cutoff points. '
                                 'Must be used with adaptive_loss criterion')
        parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
                            help='sets adaptive softmax dropout for the tail projections')
        parser.add_argument('--num-embedding-chunks', type=int, metavar='N', default=1,
                            help='Number of embedding layer chunks (enables more even distribution '
                                 'of optimizer states across data parallel nodes '
                                 'when using optimizer state sharding and '
                                 'a big embedding vocabulary)')
        # fmt: on

    @classmethod
    def build_model_base(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present in older models
        base_architecture(args)

        if not hasattr(args, "max_source_positions"):
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if not hasattr(args, "max_target_positions"):
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        def build_embedding(dictionary, embed_dim, path=None, num_embed_chunks=1):
            assert embed_dim % num_embed_chunks == 0, (
                f"Embedding dimension = {embed_dim} should be divisible by the "
                + f"number of embedding chunks = {num_embed_chunks}"
            )
            assert path is None or num_embed_chunks == 1, (
                "Loading embedding from a path with number of embedding chunks > 1"
                + " is not yet supported"
            )
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            # if provided, load from preloaded dictionaries
            if path:
                emb = Embedding(num_embeddings, embed_dim, padding_idx)
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            else:
                embed_chunk_dim = embed_dim // num_embed_chunks
                emb = nn.ModuleList()
                for i in range(num_embed_chunks):
                    emb.append(Embedding(num_embeddings, embed_chunk_dim, padding_idx))
            return emb

        num_embed_chunks = args.num_embedding_chunks
        if args.share_all_embeddings:
            if src_dict != tgt_dict:
                raise ValueError("--share-all-embeddings requires a joined dictionary")
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            encoder_embed_tokens = build_embedding(
                src_dict,
                args.encoder_embed_dim,
                args.encoder_embed_path,
                num_embed_chunks,
            )
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            assert args.share_decoder_input_output_embed or num_embed_chunks == 1, (
                "Not sharing decoder I/O embeddings is not yet supported with number of "
                + "embedding chunks > 1"
            )
            encoder_embed_tokens = build_embedding(
                src_dict,
                args.encoder_embed_dim,
                args.encoder_embed_path,
                num_embed_chunks,
            )
            decoder_embed_tokens = build_embedding(
                tgt_dict,
                args.decoder_embed_dim,
                args.decoder_embed_path,
                num_embed_chunks,
            )

        encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
        decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
        return (encoder, decoder)

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        return TransformerEncoder(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        return TransformerDecoder(args, tgt_dict, embed_tokens)

    @classmethod
    def build_model(cls, args, task):
        encoder, decoder = cls.build_model_base(args, task)
        return PipelineParallelTransformerModel(
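            # The balance/devices/chunks/checkpoint settings are read from the
            # --pipeline-balance, --pipeline-devices, --pipeline-chunks and
            # --pipeline-checkpoint command-line flags; balance and devices are
            # comma-separated lists with one entry per pipeline partition.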
            encoder=encoder,
            decoder=decoder,
            balance=utils.eval_str_list(args.pipeline_balance, type=int),
            devices=utils.eval_str_list(args.pipeline_devices, type=int),
            chunks=args.pipeline_chunks,
            checkpoint=args.pipeline_checkpoint,
        )

    def output_layer(self, features, **kwargs):
        """Project features to the default output size (typically vocabulary size)."""
        return self.decoder.output_layer(features, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model."""
        return (self.encoder_max_positions, self.decoder_max_positions)

    def max_positions_helper(
        self, embedding_layer, max_positions_field="max_source_positions"
    ):
        """Maximum input length supported by the encoder or decoder."""
        if embedding_layer.embed_positions is None:
            return getattr(embedding_layer, max_positions_field)
        return min(
            getattr(embedding_layer, max_positions_field),
            embedding_layer.embed_positions.max_positions,
        )

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Get normalized probabilities (or log probs) from a net's output."""
        if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
            if sample is not None:
                assert "target" in sample
                target = sample["target"]
            else:
                target = None
            out = self.adaptive_softmax.get_log_prob(net_output, target=target)
            return out.exp_() if not log_probs else out

        # A Pipe() module returns a tuple of tensors as the output.
        # In this case, the tuple has one element - the output tensor of logits
        logits = net_output if isinstance(net_output, torch.Tensor) else net_output[0]
        if log_probs:
            return utils.log_softmax(logits, dim=-1, onnx_trace=False)
        else:
            return utils.softmax(logits, dim=-1, onnx_trace=False)

    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return self.decoder_max_positions

    def load_state_dict(self, state_dict, strict=True, model_cfg=None):
        """Copies parameters and buffers from *state_dict* into this module and
        its descendants.

        Overrides the method in :class:`nn.Module`. Compared with that method
        this additionally "upgrades" *state_dicts* from old checkpoints.
""" self.upgrade_state_dict(state_dict) is_regular_transformer = not any("model.partitions" in k for k in state_dict) if is_regular_transformer: state_dict = self.convert_to_pipeline_parallel_state_dict(state_dict) return super().load_state_dict(state_dict, strict) def convert_to_pipeline_parallel_state_dict(self, state_dict): new_state_dict = self.state_dict() encoder_layer_idx = 0 decoder_layer_idx = 0 encoder_key_suffixes = [ "self_attn.k_proj.weight", "self_attn.k_proj.bias", "self_attn.v_proj.weight", "self_attn.v_proj.bias", "self_attn.q_proj.weight", "self_attn.q_proj.bias", "self_attn.out_proj.weight", "self_attn.out_proj.bias", "self_attn_layer_norm.weight", "self_attn_layer_norm.bias", "fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias", "final_layer_norm.weight", "final_layer_norm.bias", ] decoder_key_suffixes = [ "self_attn.k_proj.weight", "self_attn.k_proj.bias", "self_attn.v_proj.weight", "self_attn.v_proj.bias", "self_attn.q_proj.weight", "self_attn.q_proj.bias", "self_attn.out_proj.weight", "self_attn.out_proj.bias", "self_attn_layer_norm.weight", "self_attn_layer_norm.bias", "encoder_attn.k_proj.weight", "encoder_attn.k_proj.bias", "encoder_attn.v_proj.weight", "encoder_attn.v_proj.bias", "encoder_attn.q_proj.weight", "encoder_attn.q_proj.bias", "encoder_attn.out_proj.weight", "encoder_attn.out_proj.bias", "encoder_attn_layer_norm.weight", "encoder_attn_layer_norm.bias", "fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias", "final_layer_norm.weight", "final_layer_norm.bias", ] for pid, partition in enumerate(self.model.partitions): logger.info(f"Begin Partition {pid}") for mid, module in enumerate(partition): # fmt: off if isinstance(module, TransformerEncoderEmbedding): new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['encoder.embed_tokens.weight'] new_state_dict[f'model.partitions.{pid}.{mid}.embed_positions._float_tensor'] = state_dict['encoder.embed_positions._float_tensor'] if isinstance(module, TransformerEncoderLayer): for suffix in encoder_key_suffixes: new_state_dict[f'model.partitions.{pid}.{mid}.{suffix}'] = state_dict[f'encoder.layers.{encoder_layer_idx}.{suffix}'] encoder_layer_idx += 1 if isinstance(module, TransformerDecoderLayer): for suffix in decoder_key_suffixes: new_state_dict[f'model.partitions.{pid}.{mid}.{suffix}'] = state_dict[f'decoder.layers.{decoder_layer_idx}.{suffix}'] decoder_layer_idx += 1 if isinstance(module, TransformerEncoderLayerNorm): if 'encoder.layer_norm.weight' in state_dict: new_state_dict[f'model.partitions.{pid}.{mid}.layer_norm.weight'] = state_dict['encoder.layer_norm.weight'] new_state_dict[f'model.partitions.{pid}.{mid}.layer_norm.bias'] = state_dict['encoder.layer_norm.bias'] if isinstance(module, TransformerDecoderEmbedding): new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['decoder.embed_tokens.weight'] new_state_dict[f'model.partitions.{pid}.{mid}.embed_positions._float_tensor'] = state_dict['decoder.embed_positions._float_tensor'] if isinstance(module, TransformerDecoderOutputLayer): new_state_dict[f'model.partitions.{pid}.{mid}.output_projection.weight'] = state_dict['decoder.output_projection.weight'] # fmt: on return new_state_dict class TransformerEncoder(FairseqEncoder): """ Transformer encoder consisting of *args.encoder_layers* layers. Each layer is a :class:`TransformerEncoderLayer`. 
    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
    """

    def __init__(self, args, dictionary, embed_tokens, encoder_module_list=None):
        super().__init__(dictionary)
        self.register_buffer("version", torch.Tensor([3]))
        import_pipe()
        self.use_pipeline = encoder_module_list is not None
        if not self.use_pipeline:
            self.embedding_layer = TransformerEncoderEmbedding(args, embed_tokens)
            self.encoder_layers = nn.Sequential(
                *[TransformerEncoderLayer(args) for i in range(args.encoder_layers)]
            )
            if isinstance(embed_tokens, nn.ModuleList):
                emb_dim = sum(e.embedding_dim for e in embed_tokens)
            else:
                emb_dim = embed_tokens.embedding_dim
            self.final_layer_norm = TransformerEncoderLayerNorm(args, emb_dim)
        else:
            encoder_balance = utils.eval_str_list(
                args.pipeline_encoder_balance, type=int
            )
            encoder_devices = utils.eval_str_list(
                args.pipeline_encoder_devices, type=int
            )
            assert sum(encoder_balance) == len(encoder_module_list), (
                f"Sum of encoder_balance={encoder_balance} is not equal "
                + f"to num_encoder_modules={len(encoder_module_list)}"
            )
            if TORCH_PIPE:
                self.model = Pipe(
                    module=partition_model(
                        nn.Sequential(*encoder_module_list),
                        encoder_balance,
                        encoder_devices,
                    ),
                    chunks=args.pipeline_chunks,
                    checkpoint=args.pipeline_checkpoint,
                )
            else:
                self.model = Pipe(
                    module=nn.Sequential(*encoder_module_list),
                    balance=encoder_balance,
                    devices=encoder_devices,
                    chunks=args.pipeline_chunks,
                    checkpoint=args.pipeline_checkpoint,
                )

    def forward(self, src_tokens, src_lengths):
        """
        Args:
            input_tuple(
                src_tokens (LongTensor): tokens in the source language of shape
                    `(batch, src_len)`
                src_lengths (torch.LongTensor): lengths of each source sentence
                    of shape `(batch)`
            )

        Returns:
            output_tuple(
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - prev_output_tokens
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
            )
        """
        dummy_prev_output_tokens = torch.zeros(
            1, dtype=src_tokens.dtype, device=src_tokens.device
        )
        input_tuple = (src_tokens, src_lengths, dummy_prev_output_tokens)
        if self.use_pipeline:
            input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple)
            if TORCH_PIPE:
                encoder_out = self.model(input_tuple).local_value()
            else:
                encoder_out = self.model(input_tuple)
        else:
            encoder_embed_output_tuple = self.embedding_layer(input_tuple)
            encoder_layers_output = self.encoder_layers(encoder_embed_output_tuple)
            encoder_out = self.final_layer_norm(encoder_layers_output)
        # first element is the encoder output
        # second element is the encoder padding mask
        # the remaining elements of EncoderOut are not computed by
        # the PipelineParallelTransformer
        return EncoderOut(encoder_out[0], encoder_out[1], None, None, None, None)

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to *new_order*.
        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        if encoder_out.encoder_out is not None:
            encoder_out = encoder_out._replace(
                encoder_out=encoder_out.encoder_out.index_select(1, new_order)
            )
        if encoder_out.encoder_padding_mask is not None:
            encoder_out = encoder_out._replace(
                encoder_padding_mask=encoder_out.encoder_padding_mask.index_select(
                    0, new_order
                )
            )
        if encoder_out.encoder_embedding is not None:
            encoder_out = encoder_out._replace(
                encoder_embedding=encoder_out.encoder_embedding.index_select(
                    0, new_order
                )
            )
        if encoder_out.encoder_states is not None:
            for idx, state in enumerate(encoder_out.encoder_states):
                encoder_out.encoder_states[idx] = state.index_select(1, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.embedding_layer.embed_positions is None:
            return self.embedding_layer.max_source_positions
        return min(
            self.embedding_layer.max_source_positions,
            self.embedding_layer.embed_positions.max_positions,
        )


class TransformerDecoder(FairseqDecoder):
    """
    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self,
        args,
        dictionary,
        embed_tokens,
        no_encoder_attn=False,
        decoder_module_list=None,
    ):
        super().__init__(dictionary)
        self.register_buffer("version", torch.Tensor([3]))
        import_pipe()
        self.use_pipeline = decoder_module_list is not None
        if not self.use_pipeline:
            self.embedding_layer = TransformerDecoderEmbedding(args, embed_tokens)
            self.decoder_layers = nn.Sequential(
                *[
                    TransformerDecoderLayer(args, no_encoder_attn)
                    for _ in range(args.decoder_layers)
                ]
            )
            self.decoder_output_layer = TransformerDecoderOutputLayer(
                args, embed_tokens, dictionary
            )
        else:
            decoder_balance = utils.eval_str_list(
                args.pipeline_decoder_balance, type=int
            )
            decoder_devices = utils.eval_str_list(
                args.pipeline_decoder_devices, type=int
            )
            assert sum(decoder_balance) == len(decoder_module_list), (
                f"Sum of decoder_balance={decoder_balance} is not equal "
                + f"to num_decoder_modules={len(decoder_module_list)}"
            )
            if TORCH_PIPE:
                self.model = Pipe(
                    module=partition_model(
                        nn.Sequential(*decoder_module_list),
                        decoder_balance,
                        decoder_devices,
                    ),
                    chunks=args.pipeline_chunks,
                    checkpoint=args.pipeline_checkpoint,
                )
            else:
                self.model = Pipe(
                    module=nn.Sequential(*decoder_module_list),
                    balance=decoder_balance,
                    devices=decoder_devices,
                    chunks=args.pipeline_chunks,
                    checkpoint=args.pipeline_checkpoint,
                )

    def forward(
        self,
        prev_output_tokens,
        encoder_out=None,
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (optional): output from the encoder, used for
                encoder-side attention
        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        input_tuple = (
            encoder_out.encoder_out,
            encoder_out.encoder_padding_mask,
            prev_output_tokens,
        )
        if self.use_pipeline:
            input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple)
            if TORCH_PIPE:
                return (self.model(input_tuple).local_value(),)
            else:
                return (self.model(input_tuple),)
        else:
            embed_layer_output = self.embedding_layer(input_tuple)
            state = self.decoder_layers(embed_layer_output)
            return (self.decoder_output_layer(state),)

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size."""
        if self.adaptive_softmax is None:
            # project back to size of vocabulary
            if self.share_input_output_embed:
                return F.linear(features, self.embed_tokens.weight)
            else:
                return F.linear(features, self.embed_out)
        else:
            return features

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        if self.embedding_layer.embed_positions is None:
            return self.embedding_layer.max_target_positions
        return min(
            self.embedding_layer.max_target_positions,
            self.embedding_layer.embed_positions.max_positions,
        )

    def buffered_future_mask(self, tensor):
        dim = tensor.size(0)
        if (
            not hasattr(self, "_future_mask")
            or self._future_mask is None
            or self._future_mask.device != tensor.device
            or self._future_mask.size(0) < dim
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
            )
        return self._future_mask[:dim, :dim]

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            weights_key = "{}.embed_positions.weights".format(name)
            if weights_key in state_dict:
                del state_dict[weights_key]
            state_dict[
                "{}.embed_positions._float_tensor".format(name)
            ] = torch.FloatTensor(1)

        for i in range(len(self.layers)):
            # update layer norms
            layer_norm_map = {
                "0": "self_attn_layer_norm",
                "1": "encoder_attn_layer_norm",
                "2": "final_layer_norm",
            }
            for old, new in layer_norm_map.items():
                for m in ("weight", "bias"):
                    k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m)
                    if k in state_dict:
                        state_dict[
                            "{}.layers.{}.{}.{}".format(name, i, new, m)
                        ] = state_dict[k]
                        del state_dict[k]

        version_key = "{}.version".format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])

        return state_dict


@register_model_architecture(
    "pipeline_parallel_transformer", "transformer_iwslt_de_en_pipeline_parallel"
)
def transformer_iwslt_de_en_dist(args):
    transformer_iwslt_de_en(args)


@register_model_architecture(
    "pipeline_parallel_transformer", "transformer_wmt_en_de_big_pipeline_parallel"
)
def transformer_wmt_en_de_big_dist(args):
    transformer_wmt_en_de_big(args)
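

# Example invocation (illustrative sketch only, not part of this module): the
# architecture is selected with --arch and the pipeline layout with the
# --pipeline-* flags consumed above. The balance/device values below are
# placeholders for the default 6+6-layer IWSLT architecture (8 encoder modules
# plus 8 decoder modules = 16 pipeline modules split across 2 GPUs) and must be
# adapted to the actual model and hardware:
#
#   fairseq-train <data-bin> \
#       --arch transformer_iwslt_de_en_pipeline_parallel \
#       --pipeline-model-parallel \
#       --pipeline-balance '[8, 8]' --pipeline-devices '[0, 1]' \
#       --pipeline-chunks 2 --pipeline-checkpoint never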