import torch
from torch import nn
import torch.nn.functional as F

from typing import Optional, Any, Union, Callable
from torch import Tensor

from .create_act import get_act_layer, get_activation
from timm.models.layers import DropPath
from .layer_norm import LayerNorm
from .pe_encoder import DeepPrompt

class TransformerEncoderLayer(nn.Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify it or
    implement it in a different way for their application.
    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, layer norm is done prior to attention and feedforward
            operations, respectively. Otherwise it's done after. Default: ``False`` (after).
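        drop_path_ratio: drop probability for DropPath (stochastic depth) applied to
            both residual branches; 0 disables it (default=0.1).
        layer_scale: if ``True``, scale each residual branch by a learnable per-channel
            ``gamma`` (LayerScale). Default: ``False``.
        ls_init_values: initial value of the LayerScale parameters (default=1e-3).
        cfg: project config node; this layer reads ``cfg.SOLVER.FUSED_LAYERNORM`` and
            ``cfg.MODEL.PROMPT_EMBED.DEEP_PROMPT``.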
    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)
    Alternatively, when ``batch_first`` is ``True``:
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> src = torch.rand(32, 10, 512)
        >>> out = encoder_layer(src)
    Fast path:
        forward() will use a special optimized implementation if all of the following
        conditions are met:
        - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor
          argument ``requires_grad``
        - training is disabled (using ``.eval()``)
        - batch_first is ``True`` and the input is batched (i.e., ``src.dim() == 3``)
        - norm_first is ``False`` (this restriction may be loosened in the future)
        - activation is one of: ``"relu"``, ``"gelu"``, ``torch.nn.functional.relu``, or ``torch.nn.functional.gelu``
        - at most one of ``src_mask`` and ``src_key_padding_mask`` is passed
        - if src is a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_, neither ``src_mask``
          nor ``src_key_padding_mask`` is passed
        - the two ``LayerNorm`` instances have a consistent ``eps`` value (this will naturally be the case
          unless the caller has manually modified one without modifying the other)
        If the optimized implementation is in use, a
        `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be
        passed for ``src`` to represent padding more efficiently than using a padding
        mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ will be
        returned, and an additional speedup proportional to the fraction of the input that
        is padding can be expected.
    """
    __constants__ = ['batch_first', 'norm_first']  # inherited from PyTorch's TransformerEncoderLayer for TorchScript (JIT) support

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, drop_path_ratio: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, layer_scale: bool = False, ls_init_values: float = 1e-3,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 device=None, dtype=None, cfg: dict = None) -> None:
        super(TransformerEncoderLayer, self).__init__()

        self.cfg = cfg

        # nn.MultiheadAttention only accepts `batch_first` and the device/dtype factory
        # kwargs from torch 1.9.0 onwards.
        _torch_major, _torch_minor = (int(v) for v in torch.__version__.split('.')[:2])
        self._torch_nn_new_interface = (_torch_major, _torch_minor) >= (1, 9)

        if self._torch_nn_new_interface:
            factory_kwargs = {'device': device, 'dtype': dtype}
            self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                                **factory_kwargs)
        else:
            factory_kwargs = {}
            self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout,
                                                **factory_kwargs)

        self.batch_first = batch_first

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
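        # cfg.SOLVER.FUSED_LAYERNORM switches to the package's own LayerNorm from
        # .layer_norm (presumably a fused implementation, going by the flag name);
        # otherwise the standard nn.LayerNorm is used.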
        if self.cfg.SOLVER.FUSED_LAYERNORM:
            self.norm1 = LayerNorm(d_model, eps=layer_norm_eps)
            self.norm2 = LayerNorm(d_model, eps=layer_norm_eps)
        else:
            self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.drop_path1 = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.drop_path2 = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()

        # LayerScale: learnable per-channel scaling of the residual branches.
        self.layer_scale = layer_scale
        if self.layer_scale:
            self.gamma_1 = nn.Parameter(ls_init_values * torch.ones(d_model), requires_grad=True)
            self.gamma_2 = nn.Parameter(ls_init_values * torch.ones(d_model), requires_grad=True)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            activation = get_activation(activation)

        self.activation = activation

        # prompt embedding setup
        self.deep_prompt = self.cfg.MODEL.PROMPT_EMBED.DEEP_PROMPT
        if self.deep_prompt:
            self.deep_prompt_embedding = DeepPrompt(cfg)


    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerEncoderLayer, self).__setstate__(state)

    def forward(self,
                src: Tensor,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                history_states: Optional[Tensor] = None,
                **kwargs) -> Tensor:
        r"""Pass the input through the encoder layer.
        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
        Shape:
            see the docs in Transformer class.
        """

        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf

        if self.batch_first and not self._torch_nn_new_interface:
            x = src.transpose(0,1)
            if history_states is not None:
                history_states = history_states.transpose(0,1)
        else:
            x = src

        if self.norm_first:
            history_states_norm = None if history_states is None else self.norm1(history_states)
            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask, history_states=history_states_norm, **kwargs)
            x = x + self._ff_block(self.norm2(x), **kwargs)
        else:
            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask, history_states=history_states, **kwargs))
            x = self.norm2(x + self._ff_block(x, **kwargs))

        if self.batch_first and not self._torch_nn_new_interface:
            x = x.transpose(0, 1)

        return x

    # self-attention block
    def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], history_states: Optional[Tensor],
                  **kwargs) -> Tensor:

        if history_states is not None:
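            # Attend over cached states from earlier steps as additional keys/values;
            # the concatenation axis follows the layout implied by batch_first.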
            kv = torch.cat(
                [history_states, x],
                dim=1 if (self.batch_first and self._torch_nn_new_interface) else 0
            )
            # TODO: changes for attn_mask and key_padding_mask
        else:
            kv = x

        if self.deep_prompt:
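            # Deep prompt tuning: prepend learned prompt embeddings to the key/value
            # sequence, then zero-pad attn_mask / key_padding_mask so every query can
            # still attend to the prompt tokens.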

            deep_prompt_embedding = self.deep_prompt_embedding(x, batch_first=(self.batch_first and self._torch_nn_new_interface), **kwargs)
            if self.norm_first:
                deep_prompt_embedding = self.norm1(deep_prompt_embedding)
            kv = torch.cat([deep_prompt_embedding, kv], dim=1 if (self.batch_first and self._torch_nn_new_interface) else 0)
            if attn_mask is not None:
                L, S = attn_mask.shape
                # deep_prompt_embedding is (bs, length, hidden) when batch_first,
                # otherwise (length, bs, hidden).
                pe_length = deep_prompt_embedding.shape[1 if (self.batch_first and self._torch_nn_new_interface) else 0]
                attn_mask = torch.cat([torch.zeros((L, pe_length), dtype=attn_mask.dtype, device=attn_mask.device), attn_mask], dim=1)
            if key_padding_mask is not None:
                if self.batch_first and self._torch_nn_new_interface:
                    bs, pe_length = deep_prompt_embedding.shape[:2]
                else:
                    pe_length, bs = deep_prompt_embedding.shape[:2]
                key_padding_mask = torch.cat(
                    [torch.zeros((bs, pe_length), dtype=key_padding_mask.dtype, device=key_padding_mask.device), key_padding_mask], dim=1)


        x = self.self_attn(x, kv, kv,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False)[0]
        x = self.drop_path1(self.dropout1(x))
        if self.layer_scale:
            x = self.gamma_1 * x
        return x


    # feed forward block
    def _ff_block(self, x: Tensor, **kwargs) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        x = self.drop_path2(self.dropout2(x))
        if self.layer_scale:
            x = self.gamma_2 * x
        return x
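

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. Run it with
    # `python -m <package>.<this_module>` so the relative imports above resolve.
    # The cfg object below is a hypothetical stand-in that exposes only the two
    # fields this layer actually reads (SOLVER.FUSED_LAYERNORM and
    # MODEL.PROMPT_EMBED.DEEP_PROMPT); the real project config is richer.
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        SOLVER=SimpleNamespace(FUSED_LAYERNORM=False),
        MODEL=SimpleNamespace(PROMPT_EMBED=SimpleNamespace(DEEP_PROMPT=False)),
    )
    layer = TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True, cfg=cfg).eval()
    src = torch.rand(32, 10, 512)  # (batch, seq, feature) because batch_first=True
    with torch.no_grad():
        out = layer(src)
    print(out.shape)  # expected: torch.Size([32, 10, 512])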