import torch
import torch.nn as nn

from monai.networks.blocks.mlp import MLPBlock

from ..nn.selfattention import SABlock
class TransformerBlock(nn.Module):
    """
    A transformer block, based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
    """
|
    def __init__(
        self, hidden_size: int, mlp_dim: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False
    ) -> None:
        """
        Args:
            hidden_size: dimension of hidden layer.
            mlp_dim: dimension of feedforward layer.
            num_heads: number of attention heads.
            dropout_rate: fraction of the input units to drop.
            qkv_bias: apply bias term for the qkv linear layer.
        """
|
        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")

        if hidden_size % num_heads != 0:
            raise ValueError("hidden_size should be divisible by num_heads.")

        self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias)
        self.norm2 = nn.LayerNorm(hidden_size)
|
    def forward(self, x: torch.Tensor, return_attention: bool = False) -> torch.Tensor:
        # Pre-norm self-attention: SABlock returns the attended tokens and the attention weights.
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            return attn
        x = x + y  # residual connection around the attention block
        x = x + self.mlp(self.norm2(x))  # pre-norm MLP with residual connection
        return x
|
|
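# Minimal usage sketch (illustrative only, not part of the original module). The shapes below are
# assumptions for a ViT-style input; it also assumes this file is run within its package so the
# relative import of SABlock resolves, and that SABlock returns an (output, attention) tuple as the
# forward pass above expects.
if __name__ == "__main__":
    # A batch of 2 sequences of 197 tokens (e.g. 196 image patches + 1 class token), hidden size 384.
    block = TransformerBlock(hidden_size=384, mlp_dim=1536, num_heads=6, dropout_rate=0.1, qkv_bias=True)
    tokens = torch.randn(2, 197, 384)
    out = block(tokens)  # same shape as the input: (2, 197, 384)
    attn = block(tokens, return_attention=True)  # attention weights; shape depends on the SABlock implementation
    print(out.shape, attn.shape)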