File size: 2,847 Bytes
d5175d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.model_parallel.modules import ModelParallelMultiheadAttention
from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer
try:
from fairseq.model_parallel.megatron.mpu import (
ColumnParallelLinear,
RowParallelLinear,
)
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
class ModelParallelTransformerEncoderLayer(TransformerEncoderLayer):
"""Encoder layer block over multiple gpus.
See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details.
"""
def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
if q_noise > 0:
raise NotImplementedError
return ColumnParallelLinear(input_dim, output_dim, gather_output=False)
def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
if q_noise > 0:
raise NotImplementedError
return RowParallelLinear(input_dim, output_dim, input_is_parallel=True)
def build_self_attention(self, embed_dim, args, **unused_kwargs):
return ModelParallelMultiheadAttention(
embed_dim,
args.encoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
)
class ModelParallelTransformerDecoderLayer(TransformerDecoderLayer):
"""Decoder layer block.
See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details.
"""
def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
if q_noise > 0:
raise NotImplementedError
return ColumnParallelLinear(input_dim, output_dim, gather_output=False)
def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
if q_noise > 0:
raise NotImplementedError
return RowParallelLinear(input_dim, output_dim, input_is_parallel=True)
def build_self_attention(self, embed_dim, args, **unused_kwargs):
return ModelParallelMultiheadAttention(
embed_dim=embed_dim,
num_heads=args.decoder_attention_heads,
dropout=args.attention_dropout,
self_attention=not getattr(args, "cross_self_attention", False),
)
def build_encoder_attention(self, embed_dim, args, **unused_kwargs):
return ModelParallelMultiheadAttention(
embed_dim=embed_dim,
num_heads=args.decoder_attention_heads,
kdim=getattr(args, "encoder_embed_dim", None),
vdim=getattr(args, "encoder_embed_dim", None),
dropout=args.attention_dropout,
encoder_decoder_attention=True,
)
|