""" | |
paper: https://arxiv.org/abs/2105.15203 | |
- ref: | |
- encoder: | |
- https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/backbones/mix_transformer.py | |
- https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/backbones/mit.py | |
- decoder: | |
- https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/decode_heads/segformer_head.py | |
- https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/segformer_head.py | |
""" | |
import math

import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange


class MixFFN(nn.Module):
    def __init__(self, embed_dim, channels, dropout=0.0):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv1d(  # fc1
                in_channels=embed_dim, out_channels=channels, kernel_size=1, stride=1
            ),
            nn.Conv1d(  # 3x3 depthwise conv: provides positional information (Mix-FFN)
                in_channels=channels,
                out_channels=channels,
                kernel_size=3,
                stride=1,
                padding=1,
                groups=channels,
            ),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv1d(  # fc2
                in_channels=channels, out_channels=embed_dim, kernel_size=1
            ),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # x: (B, N, C) tokens -> (B, C, N) for Conv1d, then back to (B, N, C)
        out = x.transpose(1, 2)
        out = self.layers(out)
        out = out.transpose(1, 2)
        return out
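

# A minimal shape-check sketch (illustrative only, not used by the model; the
# function name and tensor sizes are assumptions). Mix-FFN takes tokens of
# shape (B, N, C) and returns the same shape; the 3x3 depthwise conv above is
# what the SegFormer paper uses in place of explicit positional encodings.
def _demo_mixffn():
    ffn = MixFFN(embed_dim=64, channels=256)
    tokens = torch.randn(2, 128, 64)  # (B, N, C)
    out = ffn(tokens)
    assert out.shape == tokens.shape  # (2, 128, 64): shape is preserved
    return out.shape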


class EfficientMultiheadAttention(nn.Module):
    """
    Adopts the Spatial-Reduction Attention (SRA) used in PVT (Pyramid Vision Transformer).
    In the variable names, ``sr`` is short for Spatial-Reduction.
    """

    def __init__(
        self, embed_dim, num_heads=8, attn_drop=0.0, proj_drop=0.0, sr_ratio=1
    ):
        super().__init__()
        assert (
            embed_dim % num_heads == 0
        ), f"dim {embed_dim} should be divided by num_heads {num_heads}."
        self.num_heads = num_heads
        head_dim = embed_dim // num_heads
        self.scale = head_dim**-0.5
        self.q = nn.Linear(embed_dim, embed_dim, bias=False)
        self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=False)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            # Strided conv that shortens the key/value sequence by sr_ratio.
            self.sr = nn.Conv1d(
                embed_dim, embed_dim, kernel_size=sr_ratio, stride=sr_ratio
            )
            self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        B, N, C = x.shape
        q = self.q(x)
        q = rearrange(q, "b n (h c) -> b h n c", h=self.num_heads)
        if self.sr_ratio > 1:
            # Reduce the sequence length before computing K and V.
            x_ = x.transpose(1, 2)
            x_ = self.sr(x_).transpose(1, 2)
            x_ = self.norm(x_)
            kv = self.kv(x_)
        else:
            kv = self.kv(x)
        kv = rearrange(
            kv,
            "b n (two_heads h c) -> two_heads b h n c",
            two_heads=2,
            h=self.num_heads,
        )
        k, v = kv[0], kv[1]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2)
        x = x.reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
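

# An illustrative sketch (not part of the model; the name and sizes are made
# up) of what the spatial reduction buys: with sr_ratio = r, the strided
# Conv1d shortens the key/value sequence from N to roughly N / r, so the
# attention matrix is (N x N/r) instead of (N x N) while the query resolution
# is kept.
def _demo_efficient_attention():
    attn = EfficientMultiheadAttention(embed_dim=64, num_heads=4, sr_ratio=4)
    tokens = torch.randn(2, 256, 64)  # (B, N, C) with N = 256
    out = attn(tokens)
    assert out.shape == tokens.shape  # K/V were reduced to N / 4 = 64 internally
    return out.shape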


class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ffn_channels, dropout=0.2, sr_ratio=1):
        super().__init__()
        self.attn = nn.Sequential(
            nn.LayerNorm(embed_dim),
            EfficientMultiheadAttention(
                embed_dim=embed_dim,
                num_heads=num_heads,
                attn_drop=dropout,
                proj_drop=dropout,
                sr_ratio=sr_ratio,
            ),
        )
        self.ffn = nn.Sequential(
            nn.LayerNorm(embed_dim),
            MixFFN(embed_dim=embed_dim, channels=ffn_channels, dropout=dropout),
        )

    def forward(self, x):
        x = x + self.attn(x)
        x = x + self.ffn(x)
        return x


class PatchEmbed(nn.Module):
    def __init__(
        self,
        in_channels=1,
        embed_dim=1024,
        kernel_size=7,
        stride=4,
        padding=3,
        bias=False,
    ):
        super().__init__()
        self.projection = nn.Conv1d(
            in_channels=in_channels,
            out_channels=embed_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, x: torch.Tensor):
        return self.projection(x).transpose(1, 2)
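

# Illustrative note (assumes a 1-D signal of shape (B, in_channels, L); the
# name and sizes below are made up): the overlapping patch embedding is a
# strided Conv1d, so the token count follows the usual conv length formula
#   N = floor((L + 2 * padding - kernel_size) / stride) + 1,
# e.g. the stage-1 defaults (kernel 7, stride 4, padding 3) turn L = 1024 into
# N = 256 tokens.
def _demo_patch_embed():
    embed = PatchEmbed(in_channels=1, embed_dim=32, kernel_size=7, stride=4, padding=3)
    signal = torch.randn(2, 1, 1024)  # (B, C_in, L)
    tokens = embed(signal)  # (B, N, embed_dim)
    assert tokens.shape == (2, 256, 32)
    return tokens.shape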


class MiT(nn.Module):
    """MixVisionTransformer"""

    def __init__(
        self,
        embed_dim=512,
        num_blocks=[2, 2, 6, 2],
        num_heads=[1, 2, "ceil"],
        sr_ratios=[1, 2, "ceil"],
        mlp_ratio=4,
        dropout=0.2,
    ):
        super().__init__()
        num_stages = len(num_blocks)
        # num_heads / sr_ratios are given as [base, growth_factor, "ceil" | "floor"]
        # and expanded into one value per stage (a geometric schedule).
        round_func = getattr(math, num_heads[2])  # math.ceil or math.floor
        num_heads = [
            round_func((num_heads[0] * math.pow(num_heads[1], itr)))
            for itr in range(num_stages)
        ]
        round_func = getattr(math, sr_ratios[2])  # math.ceil or math.floor
        sr_ratios = [
            round_func(sr_ratios[0] * math.pow(sr_ratios[1], itr))
            for itr in range(num_stages)
        ]
        # Earlier stages have the longest sequences, so they get the largest sr_ratio.
        sr_ratios.reverse()
        self.embed_dims = [embed_dim * num_head for num_head in num_heads]
        patch_kernel_sizes = [7]  # [7, 3, 3, ...]
        patch_kernel_sizes.extend([3] * (num_stages - 1))
        patch_strides = [4]  # [4, 2, 2, ...]
        patch_strides.extend([2] * (num_stages - 1))
        patch_paddings = [3]  # [3, 1, 1, ...]
        patch_paddings.extend([1] * (num_stages - 1))
        in_channels = 1
        self.stages = nn.ModuleList()
        for i, num_block in enumerate(num_blocks):
            patch_embed = PatchEmbed(
                in_channels=in_channels,
                embed_dim=self.embed_dims[i],
                kernel_size=patch_kernel_sizes[i],
                stride=patch_strides[i],
                padding=patch_paddings[i],
            )
            blocks = nn.ModuleList(
                [
                    TransformerBlock(
                        embed_dim=self.embed_dims[i],
                        num_heads=num_heads[i],
                        ffn_channels=mlp_ratio * self.embed_dims[i],
                        dropout=dropout,
                        sr_ratio=sr_ratios[i],
                    )
                    for _ in range(num_block)
                ]
            )
            in_channels = self.embed_dims[i]
            norm = nn.LayerNorm(self.embed_dims[i])
            self.stages.append(nn.ModuleList([patch_embed, blocks, norm]))

    def forward(self, x):
        outs = []
        for stage in self.stages:
            x = stage[0](x)  # patch embed: (B, C, L) -> (B, N_i, C_i)
            for block in stage[1]:  # transformer blocks
                x = block(x)
            x = stage[2](x)  # norm
            x = x.transpose(1, 2)  # (B, C_i, N_i) for the next stage / decode head
            outs.append(x)
        return outs
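

# A small worked example (illustrative; the arguments are made up) of the
# schedules expanded in MiT.__init__: with 4 stages, num_heads=[1, 2, "ceil"]
# becomes [1, 2, 4, 8] heads, sr_ratios=[1, 2, "ceil"] becomes [1, 2, 4, 8]
# and is reversed to [8, 4, 2, 1] so the long early sequences get the
# strongest K/V reduction, and embed_dims = embed_dim * num_heads,
# e.g. 32 -> [32, 64, 128, 256].
def _demo_mit():
    mit = MiT(
        embed_dim=32,
        num_blocks=[2, 2, 2, 2],
        num_heads=[1, 2, "ceil"],
        sr_ratios=[1, 2, "ceil"],
    )
    signal = torch.randn(2, 1, 1024)  # (B, 1, L)
    feats = mit(signal)  # one (B, C_i, N_i) feature map per stage
    # Expected shapes: (2, 32, 256), (2, 64, 128), (2, 128, 64), (2, 256, 32).
    return [f.shape for f in feats]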


class SegFormer(nn.Module):
    def __init__(self, config):
        super().__init__()
        embed_dim = int(config.embed_dim)
        num_blocks = config.num_blocks
        num_heads = config.num_heads
        assert len(num_heads) == 3 and num_heads[2] in ["floor", "ceil"]
        sr_ratios = config.sr_ratios
        assert len(sr_ratios) == 3 and sr_ratios[2] in ["floor", "ceil"]
        mlp_ratio = int(config.mlp_ratio)
        dropout = float(config.dropout)
        decoder_channels = int(config.decoder_channels)
        self.interpolate_mode = str(config.interpolate_mode)
        output_size = int(config.output_size)
        self.MiT = MiT(embed_dim, num_blocks, num_heads, sr_ratios, mlp_ratio, dropout)
        num_stages = len(num_blocks)
        # All-MLP decode head: per-stage 1x1 convs, upsample, concat, fuse, classify.
        self.decode_mlps = nn.ModuleList(
            [
                nn.Conv1d(self.MiT.embed_dims[i], decoder_channels, 1, bias=False)
                for i in range(num_stages)
            ]
        )
        self.decode_fusion = nn.Conv1d(
            decoder_channels * num_stages, decoder_channels, 1, bias=False
        )
        self.cls = nn.Conv1d(decoder_channels, output_size, 1, bias=False)

    def forward(self, input: torch.Tensor, y=None):
        output = input
        output = self.MiT(output)  # list of per-stage features (B, C_i, N_i)
        for i, (_output, decode_mlp) in enumerate(zip(output, self.decode_mlps)):
            _output = decode_mlp(_output)
            if i != 0:
                # Upsample every stage to the stage-1 token resolution.
                _output = F.interpolate(
                    _output, size=output[0].shape[2], mode=self.interpolate_mode
                )
            output[i] = _output
        output = torch.concat(output, dim=1)
        output = self.decode_fusion(output)
        output = self.cls(output)
        # Upsample back to the input length for per-position predictions.
        return F.interpolate(output, size=input.shape[2], mode=self.interpolate_mode)
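

# A minimal end-to-end smoke test. The real project presumably builds `config`
# elsewhere, so the SimpleNamespace below is only a stand-in with made-up
# values; for 1-D tensors, "linear" is a valid F.interpolate mode.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        embed_dim=32,
        num_blocks=[2, 2, 2, 2],
        num_heads=[1, 2, "ceil"],
        sr_ratios=[1, 2, "ceil"],
        mlp_ratio=4,
        dropout=0.1,
        decoder_channels=128,
        interpolate_mode="linear",
        output_size=5,  # number of output channels / classes
    )
    model = SegFormer(config).eval()
    signal = torch.randn(2, 1, 1024)  # (B, 1, L)
    with torch.no_grad():
        logits = model(signal)  # (B, output_size, L)
    print(logits.shape)  # torch.Size([2, 5, 1024])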