# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.model_parallel.modules import ModelParallelMultiheadAttention
from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer


try:
    from fairseq.model_parallel.megatron.mpu import (
        ColumnParallelLinear,
        RowParallelLinear,
    )

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False
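
# ``has_megatron_submodule`` lets callers raise a clear, descriptive error when
# Megatron's model-parallel utilities are unavailable (e.g. the submodule has
# not been set up), instead of this module failing at import time.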


class ModelParallelTransformerEncoderLayer(TransformerEncoderLayer):
    """Encoder layer block over multiple gpus.

    See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details.
    """

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        if q_noise > 0:
            raise NotImplementedError
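        # Column-parallel: the weight matrix is partitioned along its output
        # dimension across model-parallel workers; with gather_output=False the
        # partitioned activations are fed directly to the row-parallel fc2.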
        return ColumnParallelLinear(input_dim, output_dim, gather_output=False)

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        if q_noise > 0:
            raise NotImplementedError
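        # Row-parallel: the weight matrix is partitioned along its input
        # dimension; input_is_parallel=True means it consumes fc1's partitioned
        # output and all-reduces the partial results across workers.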
        return RowParallelLinear(input_dim, output_dim, input_is_parallel=True)

    def build_self_attention(self, embed_dim, args, **unused_kwargs):
        return ModelParallelMultiheadAttention(
            embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
        )


class ModelParallelTransformerDecoderLayer(TransformerDecoderLayer):
    """Decoder layer block.

    See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details.
    """

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        if q_noise > 0:
            raise NotImplementedError
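        # Same column-parallel split as in the encoder layer above.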
        return ColumnParallelLinear(input_dim, output_dim, gather_output=False)

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        if q_noise > 0:
            raise NotImplementedError
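        # Same row-parallel reduction as in the encoder layer above.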
        return RowParallelLinear(input_dim, output_dim, input_is_parallel=True)

    def build_self_attention(self, embed_dim, args, **unused_kwargs):
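        # When args.cross_self_attention is set, the self-attention module is
        # also fed the encoder output, so it is no longer pure self-attention.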
        return ModelParallelMultiheadAttention(
            embed_dim=embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=not getattr(args, "cross_self_attention", False),
        )

    def build_encoder_attention(self, embed_dim, args, **unused_kwargs):
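        # Encoder-decoder attention: keys/values come from the encoder, whose
        # embedding dimension may differ from the decoder's, hence kdim/vdim.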
        return ModelParallelMultiheadAttention(
            embed_dim=embed_dim,
            num_heads=args.decoder_attention_heads,
            kdim=getattr(args, "encoder_embed_dim", None),
            vdim=getattr(args, "encoder_embed_dim", None),
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
        )