|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
|
|
from typing import Optional, Any, Union, Callable |
|
from torch import Tensor |
|
|
|
from .create_act import get_act_layer, get_activation |
|
from timm.models.layers import DropPath |
|
from .layer_norm import LayerNorm |
|
from .pe_encoder import DeepPrompt |
|
|
|
class TransformerEncoderLayer(nn.Module):
    r"""Transformer encoder layer: self-attention followed by a feedforward network.

    The layer follows the encoder layer of "Attention Is All You Need"
    (Vaswani et al., 2017) and adds a few optional extensions on top of it:

    * stochastic depth (``DropPath``) on both residual branches,
    * learnable per-channel layer scale (``gamma_1`` / ``gamma_2``),
    * ``history_states`` that are concatenated in front of the input and used
      as additional keys/values of the self-attention,
    * a "deep prompt" embedding that is prepended to the keys/values.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multi-head attention (required).
        dim_feedforward: the dimension of the feedforward network (default=2048).
        dropout: the dropout value (default=0.1).
        drop_path_ratio: drop probability of the ``DropPath`` applied to each
            residual branch; ``0`` replaces it with ``nn.Identity`` (default=0.1).
        activation: the activation function of the intermediate layer, either a
            string that is resolved through ``get_activation`` or a unary
            callable. Default: ``F.relu``.
        layer_scale: if ``True``, the output of each residual branch is
            multiplied by a learnable vector of size ``d_model``. Default: ``False``.
        ls_init_values: initial value of the layer-scale vectors (default=1e-3).
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: if ``True``, the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, layer norm is done prior to attention and feedforward
            operations, respectively. Otherwise it is done after. Default: ``False`` (after).
        device: forwarded to the sub-modules when the installed PyTorch
            supports factory keyword arguments (>= 1.9).
        dtype: forwarded to the sub-modules when the installed PyTorch
            supports factory keyword arguments (>= 1.9).
        cfg: configuration object. ``cfg.SOLVER.FUSED_LAYERNORM`` selects the
            project ``LayerNorm`` instead of ``nn.LayerNorm`` and
            ``cfg.MODEL.PROMPT_EMBED.DEEP_PROMPT`` enables the deep prompt.
            When ``None``, ``nn.LayerNorm`` is used and the deep prompt is disabled.

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8, cfg=cfg)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)

    Alternatively, when ``batch_first`` is ``True``:
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True, cfg=cfg)
        >>> src = torch.rand(32, 10, 512)
        >>> out = encoder_layer(src)
    """

    __constants__ = ['batch_first', 'norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 drop_path_ratio: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_scale: bool = False, ls_init_values: float = 1e-3,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 device=None, dtype=None, cfg: Optional[Any] = None) -> None:
        super(TransformerEncoderLayer, self).__init__()

        self.cfg = cfg

        # ``batch_first`` and the ``device``/``dtype`` factory kwargs are only
        # accepted by the torch.nn constructors from PyTorch 1.9 on. Compare
        # (major, minor) as a tuple: checking the two numbers independently
        # wrongly classifies e.g. 2.0 as "older than 1.9".
        torch_version = tuple(int(part) for part in torch.__version__.split('.')[:2])
        self._torch_nn_new_interface = torch_version >= (1, 9)

        if self._torch_nn_new_interface:
            factory_kwargs = {'device': device, 'dtype': dtype}
            self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                                   **factory_kwargs)
        else:
            # Old interface: attention is always sequence-first; forward()
            # transposes around it when batch_first is requested.
            factory_kwargs = {}
            self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.batch_first = batch_first

        # Feedforward network.
        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        # Without a cfg there is nothing to request the fused implementation.
        if cfg is not None and cfg.SOLVER.FUSED_LAYERNORM:
            self.norm1 = LayerNorm(d_model, eps=layer_norm_eps)
            self.norm2 = LayerNorm(d_model, eps=layer_norm_eps)
        else:
            self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.drop_path1 = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.drop_path2 = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()

        self.layer_scale = layer_scale
        if self.layer_scale:
            # Created with the same device/dtype as the other parameters.
            self.gamma_1 = nn.Parameter(ls_init_values * torch.ones(d_model, **factory_kwargs),
                                        requires_grad=True)
            self.gamma_2 = nn.Parameter(ls_init_values * torch.ones(d_model, **factory_kwargs),
                                        requires_grad=True)

        if isinstance(activation, str):
            activation = get_activation(activation)
        self.activation = activation

        self.deep_prompt = False if cfg is None else cfg.MODEL.PROMPT_EMBED.DEEP_PROMPT
        if self.deep_prompt:
            self.deep_prompt_embedding = DeepPrompt(cfg)

    def __setstate__(self, state):
        # States saved without an activation fall back to ReLU.
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerEncoderLayer, self).__setstate__(state)

    def forward(self,
                src: Tensor,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                history_states: Optional[Tensor] = None,
                **kwargs) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            history_states: states concatenated in front of ``src`` to form the
                keys/values of the self-attention (optional). Uses the same
                layout as ``src``.
            **kwargs: forwarded to the attention and feedforward blocks; the
                attention block hands them to the deep prompt embedding.

        Returns:
            Tensor with the same layout and shape as ``src``.

        Shape:
            see the docs in Transformer class.
        """
        # On the old interface the attention module is sequence-first only, so
        # batch-first inputs are transposed on the way in and out.
        manual_transpose = self.batch_first and not self._torch_nn_new_interface

        if manual_transpose:
            x = src.transpose(0, 1)
            if history_states is not None:
                history_states = history_states.transpose(0, 1)
        else:
            x = src

        if self.norm_first:
            # Pre-norm: the history is normalized with the same norm as x.
            history_states_norm = None if history_states is None else self.norm1(history_states)
            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask,
                                   history_states=history_states_norm, **kwargs)
            x = x + self._ff_block(self.norm2(x), **kwargs)
        else:
            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask,
                                              history_states=history_states, **kwargs))
            # kwargs belong to the feedforward block, not to the norm layer.
            x = self.norm2(x + self._ff_block(x, **kwargs))

        if manual_transpose:
            x = x.transpose(0, 1)

        return x

    def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor],
                  history_states: Optional[Tensor], **kwargs) -> Tensor:
        """Self-attention branch: attention, dropout, drop path and layer scale.

        ``x`` is the query. Keys/values are ``x`` preceded by ``history_states``
        (when given) and by the deep prompt embedding (when enabled). The masks
        are extended with zeros for the prepended prompt positions.
        """
        batch_major = self.batch_first and self._torch_nn_new_interface
        seq_dim = 1 if batch_major else 0

        if history_states is not None:
            kv = torch.cat([history_states, x], dim=seq_dim)
        else:
            kv = x

        if self.deep_prompt:
            deep_prompt_embedding = self.deep_prompt_embedding(x, batch_first=batch_major, **kwargs)
            if self.norm_first:
                deep_prompt_embedding = self.norm1(deep_prompt_embedding)
            kv = torch.cat([deep_prompt_embedding, kv], dim=seq_dim)

            pe_length = deep_prompt_embedding.shape[seq_dim]
            if attn_mask is not None:
                # Pad along the key dimension (last one) so both 2-D and 3-D
                # attention masks are handled.
                prompt_cols = attn_mask.new_zeros((*attn_mask.shape[:-1], pe_length))
                attn_mask = torch.cat([prompt_cols, attn_mask], dim=-1)
            if key_padding_mask is not None:
                bs = deep_prompt_embedding.shape[0 if batch_major else 1]
                prompt_cols = key_padding_mask.new_zeros((bs, pe_length))
                key_padding_mask = torch.cat([prompt_cols, key_padding_mask], dim=1)

        x = self.self_attn(x, kv, kv,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False)[0]
        x = self.drop_path1(self.dropout1(x))
        if self.layer_scale:
            x = self.gamma_1 * x
        return x

    def _ff_block(self, x: Tensor, **kwargs) -> Tensor:
        """Feedforward branch: two linear layers, dropout, drop path and layer scale.

        ``kwargs`` are accepted for call-site symmetry with ``_sa_block`` and
        are not used.
        """
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        x = self.drop_path2(self.dropout2(x))
        if self.layer_scale:
            x = self.gamma_2 * x
        return x
|
|