import torch
import torch.nn as nn

from monai.utils import optional_import

Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")


class SABlock(nn.Module):
    """
    A self-attention block, based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
    """

    def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False) -> None:
        """
        Args:
            hidden_size: dimension of hidden layer.
            num_heads: number of attention heads.
            dropout_rate: fraction of the input units to drop.
            qkv_bias: bias term for the qkv linear layer.

        """

        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")

        if hidden_size % num_heads != 0:
            raise ValueError("hidden size should be divisible by num_heads.")

        self.num_heads = num_heads
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
        # Split the fused qkv projection into (qkv, batch, heads, tokens, head_dim).
        self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
        # Merge the attention heads back into a single hidden dimension.
        self.out_rearrange = Rearrange("b h l d -> b l (h d)")
        self.drop_output = nn.Dropout(dropout_rate)
        self.drop_weights = nn.Dropout(dropout_rate)
        self.head_dim = hidden_size // num_heads
        self.scale = self.head_dim**-0.5

    def forward(self, x):
        # Project to a fused qkv tensor and split it into per-head q, k, v.
        output = self.input_rearrange(self.qkv(x))
        q, k, v = output[0], output[1], output[2]
        # Scaled dot-product attention weights, softmax over the key dimension.
        att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)
        att_mat = self.drop_weights(att_mat)
        # Attention-weighted sum of values, then merge heads and project back to hidden_size.
        x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
        x = self.out_rearrange(x)
        x = self.out_proj(x)
        x = self.drop_output(x)
        return x, att_mat
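

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: it shows the expected
    # input/output shapes, assuming einops is installed so the optional Rearrange
    # import above resolved. All sizes below are illustrative.
    block = SABlock(hidden_size=96, num_heads=8, dropout_rate=0.1)
    tokens = torch.randn(2, 16, 96)  # (batch, sequence length, hidden_size)
    out, attention = block(tokens)
    print(out.shape)        # torch.Size([2, 16, 96])
    print(attention.shape)  # torch.Size([2, 8, 16, 16])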