# ECG_Delineation/res/impl/SegFormer.py
# wogh2012's picture
# refactor: add implementations
# aefacda
"""
paper: https://arxiv.org/abs/2105.15203

- ref:
    - encoder:
        - https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/backbones/mix_transformer.py
        - https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/backbones/mit.py
    - decoder:
        - https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/decode_heads/segformer_head.py
        - https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/segformer_head.py
"""
import torch
from torch import nn
from torch.functional import F
import math
from einops import rearrange
class MixFFN(nn.Module):
    """Feed-forward block made of 1-D convolutions.

    Layer order: pointwise conv (``embed_dim`` -> ``channels``), depthwise
    conv with kernel size 3, GELU, dropout, pointwise conv
    (``channels`` -> ``embed_dim``), dropout.

    Args:
        embed_dim: Channel count of the tensor entering and leaving the block.
        channels: Hidden channel count between the two pointwise convolutions.
        dropout: Probability used by both dropout layers.
    """

    def __init__(self, embed_dim, channels, dropout=0.0):
        super().__init__()
        # fc1: pointwise expansion
        expand = nn.Conv1d(embed_dim, channels, kernel_size=1, stride=1)
        # position embed: groups == channels, so each channel gets its own
        # 3-tap filter; padding=1 keeps the length unchanged
        depthwise = nn.Conv1d(
            channels,
            channels,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=channels,
        )
        activation = nn.GELU()
        hidden_drop = nn.Dropout(dropout)
        # fc2: pointwise projection back to the embedding width
        project = nn.Conv1d(channels, embed_dim, kernel_size=1)
        output_drop = nn.Dropout(dropout)
        self.layers = nn.Sequential(
            expand, depthwise, activation, hidden_drop, project, output_drop
        )

    def forward(self, x):
        """Apply the block to ``x`` and return a tensor of the same layout.

        Dims 1 and 2 are swapped before the convolutions and swapped back
        afterwards, so the last dim of ``x`` must be ``embed_dim``.
        """
        channels_first = x.transpose(1, 2)
        return self.layers(channels_first).transpose(1, 2)
class EfficientMultiheadAttention(nn.Module):
    """Multi-head self-attention with optional reduction of the key/value length.

    Borrows the Spatial-Reduction Attention used in PVT (Pyramid Vision
    Transformer). In the attribute names, ``sr`` is short for
    Spatial-Reduction.

    Args:
        embed_dim: Width of the input tokens; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
        sr_ratio: When greater than 1, keys and values are computed from a
            copy of the input shortened by a strided convolution
            (kernel == stride == ``sr_ratio``) followed by LayerNorm.
    """

    def __init__(
        self, embed_dim, num_heads=8, attn_drop=0.0, proj_drop=0.0, sr_ratio=1
    ):
        super().__init__()
        assert (
            embed_dim % num_heads == 0
        ), f"dim {embed_dim} should be divided by num_heads {num_heads}."

        self.num_heads = num_heads
        self.sr_ratio = sr_ratio
        # 1 / sqrt(head_dim)
        self.scale = (embed_dim // num_heads) ** -0.5

        self.q = nn.Linear(embed_dim, embed_dim, bias=False)
        # keys and values come from one projection and are split in forward()
        self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=False)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_drop)

        if sr_ratio > 1:
            self.sr = nn.Conv1d(
                embed_dim, embed_dim, kernel_size=sr_ratio, stride=sr_ratio
            )
            self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Attend over a rank-3 tensor ``x`` of shape ``(B, N, C)``.

        Returns a tensor of shape ``(B, N, C)``.
        """
        batch, seq_len, dim = x.shape
        heads = self.num_heads

        # (B, N, heads * c) -> (B, heads, N, c)
        q = self.q(x).reshape(batch, seq_len, heads, -1).permute(0, 2, 1, 3)

        kv_source = x
        if self.sr_ratio > 1:
            # Conv1d wants channels at dim 1, so swap in and back out
            shortened = self.sr(x.transpose(1, 2)).transpose(1, 2)
            kv_source = self.norm(shortened)

        # (B, M, 2 * heads * c) -> (2, B, heads, M, c)
        kv = self.kv(kv_source)
        kv = kv.reshape(batch, kv.shape[1], 2, heads, -1).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        # (B, heads, N, c) -> (B, N, heads, c) -> (B, N, C)
        merged = (weights @ v).transpose(1, 2).reshape(batch, seq_len, dim)
        return self.proj_drop(self.proj(merged))
class TransformerBlock(nn.Module):
    """Pre-norm encoder block: attention and Mix-FFN, each with a residual.

    Args:
        embed_dim: Token width used by both sub-layers and their LayerNorms.
        num_heads: Number of attention heads.
        ffn_channels: Hidden channel count of the Mix-FFN.
        dropout: Probability passed to the attention (both ``attn_drop`` and
            ``proj_drop``) and to the Mix-FFN.
        sr_ratio: Key/value reduction ratio passed to the attention.
    """

    def __init__(self, embed_dim, num_heads, ffn_channels, dropout=0.2, sr_ratio=1):
        super().__init__()

        attention = EfficientMultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            attn_drop=dropout,
            proj_drop=dropout,
            sr_ratio=sr_ratio,
        )
        # index 0 = norm, index 1 = sub-layer
        self.attn = nn.Sequential(nn.LayerNorm(embed_dim), attention)

        feed_forward = MixFFN(
            embed_dim=embed_dim, channels=ffn_channels, dropout=dropout
        )
        self.ffn = nn.Sequential(nn.LayerNorm(embed_dim), feed_forward)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        after_attention = x + self.attn(x)
        return after_attention + self.ffn(after_attention)
class PatchEmbed(nn.Module):
    """Strided 1-D convolution that turns a signal into a token sequence.

    Args:
        in_channels: Channels of the incoming tensor.
        embed_dim: Channels produced by the convolution (the token width).
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Zero padding added to both ends.
        bias: Whether the convolution has a bias term.
    """

    def __init__(
        self,
        in_channels=1,
        embed_dim=1024,
        kernel_size=7,
        stride=4,
        padding=3,
        bias=False,
    ):
        super().__init__()
        conv_options = dict(
            kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
        )
        self.projection = nn.Conv1d(in_channels, embed_dim, **conv_options)

    def forward(self, x: torch.Tensor):
        """Convolve ``x`` and swap dims 1 and 2 so channels come last."""
        embedded = self.projection(x)
        return embedded.transpose(1, 2)
class MiT(nn.Module):
    """MixVisionTransformer encoder built from 1-D stages.

    Every stage is ``[PatchEmbed, ModuleList of TransformerBlock, LayerNorm]``.
    The first patch embedding uses kernel 7 / stride 4 / padding 3; all later
    ones use kernel 3 / stride 2 / padding 1.

    Args:
        embed_dim: Base width. Stage ``i`` has width
            ``embed_dim * heads[i]``, where ``heads`` is the expanded
            ``num_heads`` schedule.
        num_blocks: Number of transformer blocks per stage; its length is the
            number of stages.
        num_heads: ``(start, factor, rounding)``. Stage ``i`` gets
            ``rounding(start * factor ** i)`` heads, where ``rounding`` is the
            name of a ``math`` function, ``"ceil"`` or ``"floor"``.
        sr_ratios: Same triple format as ``num_heads``. The expanded list is
            reversed, so the first stage receives the largest ratio.
        mlp_ratio: The FFN hidden width is ``mlp_ratio`` times the stage width.
        dropout: Dropout probability passed to every transformer block.
        in_channels: Channels of the tensor fed to the first patch embedding.
            Defaults to 1, the value that used to be hard-coded.

    Attributes:
        embed_dims: List with the width of each stage.
        stages: ``nn.ModuleList`` with one ``[patch_embed, blocks, norm]``
            ``nn.ModuleList`` per stage.
    """

    def __init__(
        self,
        embed_dim=512,
        num_blocks=(2, 2, 6, 2),
        num_heads=(1, 2, "ceil"),
        sr_ratios=(1, 2, "ceil"),
        mlp_ratio=4,
        dropout=0.2,
        in_channels=1,
    ):
        super().__init__()
        num_stages = len(num_blocks)

        heads_per_stage = self._expand_schedule(num_heads, num_stages)
        sr_per_stage = self._expand_schedule(sr_ratios, num_stages)
        # early stages see the longest sequences, so they get the largest ratio
        sr_per_stage.reverse()

        self.embed_dims = [embed_dim * heads for heads in heads_per_stage]

        self.stages = nn.ModuleList()
        stage_in_channels = in_channels
        for i, num_block in enumerate(num_blocks):
            width = self.embed_dims[i]
            is_first = i == 0
            patch_embed = PatchEmbed(
                in_channels=stage_in_channels,
                embed_dim=width,
                kernel_size=7 if is_first else 3,
                stride=4 if is_first else 2,
                padding=3 if is_first else 1,
            )
            blocks = nn.ModuleList(
                [
                    TransformerBlock(
                        embed_dim=width,
                        num_heads=heads_per_stage[i],
                        ffn_channels=mlp_ratio * width,
                        dropout=dropout,
                        sr_ratio=sr_per_stage[i],
                    )
                    for _ in range(num_block)
                ]
            )
            norm = nn.LayerNorm(width)
            # keep the [patch_embed, blocks, norm] layout: state_dict keys
            # depend on these indices
            self.stages.append(nn.ModuleList([patch_embed, blocks, norm]))
            # the next stage consumes this stage's output
            stage_in_channels = width

    @staticmethod
    def _expand_schedule(spec, num_stages):
        """Expand ``(start, factor, rounding)`` into one integer per stage.

        Entry ``i`` is ``rounding(start * factor ** i)``; ``rounding`` is
        looked up by name on the ``math`` module (``math.ceil`` or
        ``math.floor``).
        """
        start, factor, rounding = spec[0], spec[1], spec[2]
        round_func = getattr(math, rounding)  # math.ceil or math.floor
        return [
            round_func(start * math.pow(factor, stage)) for stage in range(num_stages)
        ]

    def forward(self, x):
        """Run every stage and collect the per-stage feature maps.

        Args:
            x: Tensor accepted by the first patch embedding's ``Conv1d``,
                i.e. with ``in_channels`` at dim 1.

        Returns:
            A list with one tensor per stage, each with the stage width at
            dim 1. Each stage's output is also the input of the next stage.
        """
        outs = []
        for patch_embed, blocks, norm in self.stages:
            x = patch_embed(x)
            for block in blocks:
                x = block(x)
            x = norm(x)
            # back to channels-first for the next Conv1d and for the decoder
            x = x.transpose(1, 2)
            outs.append(x)
        return outs
class SegFormer(nn.Module):
    """MiT encoder followed by an all-1x1-convolution decoder.

    ``config`` must expose these attributes: ``embed_dim``, ``num_blocks``,
    ``num_heads``, ``sr_ratios``, ``mlp_ratio``, ``dropout``,
    ``decoder_channels``, ``interpolate_mode`` and ``output_size``.
    ``num_heads`` and ``sr_ratios`` must each have three entries, the third
    being ``"floor"`` or ``"ceil"``.
    """

    def __init__(self, config):
        super().__init__()
        embed_dim = int(config.embed_dim)
        num_blocks = config.num_blocks

        num_heads = config.num_heads
        assert len(num_heads) == 3 and num_heads[2] in ["floor", "ceil"]
        sr_ratios = config.sr_ratios
        assert len(sr_ratios) == 3 and sr_ratios[2] in ["floor", "ceil"]

        mlp_ratio = int(config.mlp_ratio)
        dropout = float(config.dropout)
        decoder_channels = int(config.decoder_channels)
        self.interpolate_mode = str(config.interpolate_mode)
        output_size = int(config.output_size)

        self.MiT = MiT(
            embed_dim=embed_dim,
            num_blocks=num_blocks,
            num_heads=num_heads,
            sr_ratios=sr_ratios,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
        )

        # one 1x1 projection per encoder stage, all to the same width
        self.decode_mlps = nn.ModuleList(
            nn.Conv1d(stage_dim, decoder_channels, 1, bias=False)
            for stage_dim in self.MiT.embed_dims
        )
        # mixes the concatenated stage features back to decoder_channels
        self.decode_fusion = nn.Conv1d(
            decoder_channels * len(num_blocks), decoder_channels, 1, bias=False
        )
        self.cls = nn.Conv1d(decoder_channels, output_size, 1, bias=False)

    def forward(self, input: torch.Tensor, y=None):
        """Return per-position outputs at the same length as ``input``.

        Args:
            input: Tensor passed to the encoder; its dim 2 sets the output
                length.
            y: Accepted but not used by this method.

        Returns:
            Tensor with ``output_size`` channels at dim 1, resized along
            dim 2 to ``input.shape[2]``.
        """
        stage_features = self.MiT(input)

        projected = []
        for stage_idx, (feature, mlp) in enumerate(
            zip(stage_features, self.decode_mlps)
        ):
            feature = mlp(feature)
            if stage_idx != 0:
                # bring deeper stages to the length of the first stage
                feature = F.interpolate(
                    feature, size=projected[0].shape[2], mode=self.interpolate_mode
                )
            projected.append(feature)

        fused = self.decode_fusion(torch.cat(projected, dim=1))
        scores = self.cls(fused)
        return F.interpolate(scores, size=input.shape[2], mode=self.interpolate_mode)