# unit_test/uniperceiver/modeling/layers/transformer_encoder_layer.py
# Uploaded by herrius ("Upload 259 files"), revision 32b542e.
import torch
from torch import nn
import torch.nn.functional as F
from typing import Optional, Any, Union, Callable
from torch import Tensor
from .create_act import get_act_layer, get_activation
from timm.models.layers import DropPath
from .layer_norm import LayerNorm
from .pe_encoder import DeepPrompt
class TransformerEncoderLayer(nn.Module):
    r"""Transformer encoder layer: self-attention followed by a feed-forward network.

    Based on the standard encoder layer from "Attention Is All You Need"
    (Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
    Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. In Advances in
    Neural Information Processing Systems, pages 6000-6010), extended with:

    * stochastic depth (``DropPath``) on both residual branches,
    * optional LayerScale: each branch output is multiplied by a learnable
      per-channel vector,
    * optional ``history_states`` that are prepended to the keys/values of the
      self-attention,
    * optional deep prompt embeddings (enabled by
      ``cfg.MODEL.PROMPT_EMBED.DEEP_PROMPT``) that are prepended to the
      keys/values of the self-attention.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multi-head attention (required).
        dim_feedforward: the hidden dimension of the feed-forward network (default=2048).
        dropout: the dropout value (default=0.1).
        drop_path_ratio: stochastic-depth rate for each residual branch; ``0``
            replaces DropPath with an identity (default=0.1).
        activation: the activation of the feed-forward hidden layer. It is either a
            string resolved through ``get_activation`` or a unary callable.
            Default: ``F.relu``.
        layer_scale: if ``True``, scale each branch output by a learnable
            per-channel vector (default=False).
        ls_init_values: initial value of the LayerScale vectors (default=1e-3).
        layer_norm_eps: the eps value of the layer-norm modules (default=1e-5).
        batch_first: if ``True``, input and output tensors are
            (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, layer norm is applied before the attention and
            feed-forward blocks respectively (pre-norm); otherwise it is applied
            after them (post-norm). Default: ``False``.
        device, dtype: factory arguments for the parameters. They are only forwarded
            when torch >= 1.9 and are not passed to the fused ``LayerNorm``.
        cfg: config object read by attribute access
            (``cfg.SOLVER.FUSED_LAYERNORM``, ``cfg.MODEL.PROMPT_EMBED.DEEP_PROMPT``).
            If ``None``, ``nn.LayerNorm`` is used and deep prompts are disabled.

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8, cfg=cfg)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)

        Alternatively, when ``batch_first`` is ``True``:
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True, cfg=cfg)
        >>> src = torch.rand(32, 10, 512)
        >>> out = encoder_layer(src)

    Note:
        This class does not implement the optimized inference "fast path" of
        ``torch.nn.TransformerEncoderLayer``; ``forward`` always runs the
        explicit attention / feed-forward blocks below.
    """
    __constants__ = ['batch_first', 'norm_first']  # inherited from PyTorch's implementation, for TorchScript

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 drop_path_ratio: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, layer_scale: bool = False,
                 ls_init_values: float = 1e-3,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 device=None, dtype=None, cfg: Optional[Any] = None) -> None:
        super(TransformerEncoderLayer, self).__init__()
        self.cfg = cfg

        # nn.MultiheadAttention gained `batch_first` and device/dtype factory kwargs in torch 1.9.
        self._torch_nn_new_interface = self._has_new_mha_interface(torch.__version__)
        if self._torch_nn_new_interface:
            factory_kwargs = {'device': device, 'dtype': dtype}
            self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                                   **factory_kwargs)
        else:
            # Legacy torch: attention always works on (seq, batch, feature);
            # forward() transposes batch-first inputs itself.
            factory_kwargs = {}
            self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout,
                                                   **factory_kwargs)
        self.batch_first = batch_first

        # Implementation of the feed-forward model.
        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        use_fused_layernorm = cfg is not None and cfg.SOLVER.FUSED_LAYERNORM
        if use_fused_layernorm:
            self.norm1 = LayerNorm(d_model, eps=layer_norm_eps)
            self.norm2 = LayerNorm(d_model, eps=layer_norm_eps)
        else:
            self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.drop_path1 = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.drop_path2 = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()

        self.layer_scale = layer_scale
        if self.layer_scale:
            self.gamma_1 = nn.Parameter(ls_init_values * torch.ones(d_model, **factory_kwargs), requires_grad=True)
            self.gamma_2 = nn.Parameter(ls_init_values * torch.ones(d_model, **factory_kwargs), requires_grad=True)

        # Legacy string support for the activation function.
        if isinstance(activation, str):
            activation = get_activation(activation)
        self.activation = activation

        # Prompt embedding setup; disabled when no config is supplied.
        self.deep_prompt = cfg is not None and cfg.MODEL.PROMPT_EMBED.DEEP_PROMPT
        if self.deep_prompt:
            self.deep_prompt_embedding = DeepPrompt(cfg)

    @staticmethod
    def _has_new_mha_interface(version: str) -> bool:
        """Return True when a torch version string is >= 1.9.

        The check compares (major, minor) as a tuple, so "2.0.0" counts as newer
        than "1.9.0". Local suffixes such as "+cu118" are ignored.
        """
        import re
        match = re.match(r'(\d+)\.(\d+)', str(version))
        if match is None:
            # Unparseable version: use the legacy path, which does not rely on
            # the newer MultiheadAttention arguments.
            return False
        return (int(match.group(1)), int(match.group(2))) >= (1, 9)

    def __setstate__(self, state):
        # Objects pickled before `activation` was an attribute default to ReLU.
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerEncoderLayer, self).__setstate__(state)

    def _sequence_dim(self) -> int:
        """Index of the sequence dimension of tensors seen by the attention blocks."""
        return 1 if (self.batch_first and self._torch_nn_new_interface) else 0

    def forward(self,
                src: Tensor,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                history_states: Optional[Tensor] = None,
                **kwargs) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required). The layout is
                (batch, seq, feature) if ``batch_first`` else (seq, batch, feature).
            src_mask: attention mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            history_states: extra states in the same layout as ``src``. They are
                prepended to the keys/values of the self-attention (optional).
                Note: the masks are not extended to cover them.
            **kwargs: forwarded to the deep-prompt module (when enabled) and to
                the feed-forward block.

        Returns:
            Tensor with the same layout as ``src``.
        """
        # Legacy MultiheadAttention has no batch_first: work in (seq, batch,
        # feature) here and transpose back before returning.
        manual_transpose = self.batch_first and not self._torch_nn_new_interface
        x = src
        if manual_transpose:
            x = x.transpose(0, 1)
            if history_states is not None:
                history_states = history_states.transpose(0, 1)

        # See Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf (pre-norm vs post-norm).
        if self.norm_first:
            history_states_norm = None if history_states is None else self.norm1(history_states)
            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask,
                                   history_states=history_states_norm, **kwargs)
            x = x + self._ff_block(self.norm2(x), **kwargs)
        else:
            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask,
                                              history_states=history_states, **kwargs))
            # kwargs belong to the feed-forward block; LayerNorm.forward takes none.
            x = self.norm2(x + self._ff_block(x, **kwargs))

        if manual_transpose:
            x = x.transpose(0, 1)
        return x

    # self-attention block
    def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor],
                  history_states: Optional[Tensor], **kwargs) -> Tensor:
        """Self-attention of ``x`` over [deep prompts?, history_states?, x], followed by dropout/DropPath/LayerScale."""
        seq_dim = self._sequence_dim()

        if history_states is not None:
            kv = torch.cat([history_states, x], dim=seq_dim)
            # TODO: changes for attn_mask and key_padding_mask
        else:
            kv = x

        if self.deep_prompt:
            deep_prompt_embedding = self.deep_prompt_embedding(x, batch_first=(seq_dim == 1), **kwargs)
            if self.norm_first:
                deep_prompt_embedding = self.norm1(deep_prompt_embedding)
            kv = torch.cat([deep_prompt_embedding, kv], dim=seq_dim)
            pe_length = deep_prompt_embedding.shape[seq_dim]
            if attn_mask is not None:
                # Prepend zero columns so every query may attend to the prompt keys.
                # Handles both 2-D (L, S) and 3-D (N * num_heads, L, S) masks.
                prompt_cols = attn_mask.new_zeros((*attn_mask.shape[:-1], pe_length))
                attn_mask = torch.cat([prompt_cols, attn_mask], dim=-1)
            if key_padding_mask is not None:
                # Prompt keys are never padding: prepend zero (unpadded) columns.
                bs = deep_prompt_embedding.shape[1 - seq_dim]
                key_padding_mask = torch.cat(
                    [key_padding_mask.new_zeros((bs, pe_length)), key_padding_mask], dim=1)

        out = self.self_attn(x, kv, kv,
                             attn_mask=attn_mask,
                             key_padding_mask=key_padding_mask,
                             need_weights=False)[0]
        out = self.drop_path1(self.dropout1(out))
        if self.layer_scale:
            out = self.gamma_1 * out
        return out

    # feed forward block
    def _ff_block(self, x: Tensor, **kwargs) -> Tensor:
        """linear1 -> activation -> dropout -> linear2, then dropout/DropPath/LayerScale; kwargs are unused."""
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        x = self.drop_path2(self.dropout2(x))
        if self.layer_scale:
            x = self.gamma_2 * x
        return x