""" paper: https://arxiv.org/abs/2105.15203 - ref: - encoder: - https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/backbones/mix_transformer.py - https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/backbones/mit.py - decoder: - https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/decode_heads/segformer_head.py - https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/segformer_head.py """ import torch from torch import nn from torch.functional import F import math from einops import rearrange class MixFFN(nn.Module): def __init__(self, embed_dim, channels, dropout=0.0): super().__init__() self.layers = nn.Sequential( nn.Conv1d( # fc1 in_channels=embed_dim, out_channels=channels, kernel_size=1, stride=1 ), nn.Conv1d( # position embed (depthwise-separable) in_channels=channels, out_channels=channels, kernel_size=3, stride=1, padding=1, groups=channels, ), nn.GELU(), nn.Dropout(dropout), nn.Conv1d( # fc2 in_channels=channels, out_channels=embed_dim, kernel_size=1 ), nn.Dropout(dropout), ) def forward(self, x): out = x.transpose(1, 2) out = self.layers(out) out = out.transpose(1, 2) return out class EfficientMultiheadAttention(nn.Module): """ PVT(Pyramid Vision Transformer)에서 사용한 Spatial-Reduction Attention 을 차용 변수명 중 sr 은 Spatial-Reduction 의 약어 """ def __init__( self, embed_dim, num_heads=8, attn_drop=0.0, proj_drop=0.0, sr_ratio=1 ): super().__init__() assert ( embed_dim % num_heads == 0 ), f"dim {embed_dim} should be divided by num_heads {num_heads}." self.num_heads = num_heads head_dim = embed_dim // num_heads self.scale = head_dim**-0.5 self.q = nn.Linear(embed_dim, embed_dim, bias=False) self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=False) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(embed_dim, embed_dim) self.proj_drop = nn.Dropout(proj_drop) self.sr_ratio = sr_ratio if sr_ratio > 1: self.sr = nn.Conv1d( embed_dim, embed_dim, kernel_size=sr_ratio, stride=sr_ratio ) self.norm = nn.LayerNorm(embed_dim) def forward(self, x): B, N, C = x.shape q = self.q(x) q = rearrange(q, "b n (h c) -> b h n c", h=self.num_heads) if self.sr_ratio > 1: x_ = x.transpose(1, 2) x_ = self.sr(x_).transpose(1, 2) x_ = self.norm(x_) kv = self.kv(x_) kv = rearrange( kv, "b n (two_heads h c) -> two_heads b h n c", two_heads=2, h=self.num_heads, ) else: kv = self.kv(x) kv = rearrange( kv, "b n (two_heads h c) -> two_heads b h n c", two_heads=2, h=self.num_heads, ) k, v = kv[0], kv[1] attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2) x = x.reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x class TransformerBlock(nn.Module): def __init__(self, embed_dim, num_heads, ffn_channels, dropout=0.2, sr_ratio=1): super().__init__() self.attn = nn.Sequential( nn.LayerNorm(embed_dim), EfficientMultiheadAttention( embed_dim=embed_dim, num_heads=num_heads, attn_drop=dropout, proj_drop=dropout, sr_ratio=sr_ratio, ), ) self.ffn = nn.Sequential( nn.LayerNorm(embed_dim), MixFFN(embed_dim=embed_dim, channels=ffn_channels, dropout=dropout), ) def forward(self, x): x = x + self.attn(x) x = x + self.ffn(x) return x class PatchEmbed(nn.Module): def __init__( self, in_channels=1, embed_dim=1024, kernel_size=7, stride=4, padding=3, bias=False, ): super().__init__() self.projection = nn.Conv1d( in_channels=in_channels, out_channels=embed_dim, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias, ) def forward(self, x: torch.Tensor): return 


class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ffn_channels, dropout=0.2, sr_ratio=1):
        super().__init__()
        self.attn = nn.Sequential(
            nn.LayerNorm(embed_dim),
            EfficientMultiheadAttention(
                embed_dim=embed_dim,
                num_heads=num_heads,
                attn_drop=dropout,
                proj_drop=dropout,
                sr_ratio=sr_ratio,
            ),
        )
        self.ffn = nn.Sequential(
            nn.LayerNorm(embed_dim),
            MixFFN(embed_dim=embed_dim, channels=ffn_channels, dropout=dropout),
        )

    def forward(self, x):
        x = x + self.attn(x)  # pre-norm residual attention
        x = x + self.ffn(x)  # pre-norm residual Mix-FFN
        return x


class PatchEmbed(nn.Module):
    def __init__(
        self,
        in_channels=1,
        embed_dim=1024,
        kernel_size=7,
        stride=4,
        padding=3,
        bias=False,
    ):
        super().__init__()
        self.projection = nn.Conv1d(
            in_channels=in_channels,
            out_channels=embed_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, x: torch.Tensor):
        return self.projection(x).transpose(1, 2)  # (B, C_in, L) -> (B, N, embed_dim)


class MiT(nn.Module):
    """MixVisionTransformer"""

    def __init__(
        self,
        embed_dim=512,
        num_blocks=[2, 2, 6, 2],
        num_heads=[1, 2, "ceil"],
        sr_ratios=[1, 2, "ceil"],
        mlp_ratio=4,
        dropout=0.2,
    ):
        super().__init__()
        num_stages = len(num_blocks)

        # num_heads / sr_ratios are given as (base, growth factor, rounding mode)
        # and expanded into one value per stage.
        round_func = getattr(math, num_heads[2])  # math.ceil or math.floor
        num_heads = [
            round_func((num_heads[0] * math.pow(num_heads[1], itr)))
            for itr in range(num_stages)
        ]
        round_func = getattr(math, sr_ratios[2])  # math.ceil or math.floor
        sr_ratios = [
            round_func(sr_ratios[0] * math.pow(sr_ratios[1], itr))
            for itr in range(num_stages)
        ]
        sr_ratios.reverse()  # strongest reduction in the earliest (longest) stage

        self.embed_dims = [embed_dim * num_head for num_head in num_heads]

        patch_kernel_sizes = [7]  # [7, 3, 3, ..]
        patch_kernel_sizes.extend([3] * (num_stages - 1))
        patch_strides = [4]  # [4, 2, 2, ..]
        patch_strides.extend([2] * (num_stages - 1))
        patch_paddings = [3]  # [3, 1, 1, ..]
        patch_paddings.extend([1] * (num_stages - 1))

        in_channels = 1
        self.stages = nn.ModuleList()
        for i, num_block in enumerate(num_blocks):
            patch_embed = PatchEmbed(
                in_channels=in_channels,
                embed_dim=self.embed_dims[i],
                kernel_size=patch_kernel_sizes[i],
                stride=patch_strides[i],
                padding=patch_paddings[i],
            )
            blocks = nn.ModuleList(
                [
                    TransformerBlock(
                        embed_dim=self.embed_dims[i],
                        num_heads=num_heads[i],
                        ffn_channels=mlp_ratio * self.embed_dims[i],
                        dropout=dropout,
                        sr_ratio=sr_ratios[i],
                    )
                    for _ in range(num_block)
                ]
            )
            in_channels = self.embed_dims[i]
            norm = nn.LayerNorm(self.embed_dims[i])
            self.stages.append(nn.ModuleList([patch_embed, blocks, norm]))

    def forward(self, x):
        outs = []
        for stage in self.stages:
            x = stage[0](x)  # patch embed
            for block in stage[1]:  # transformer blocks
                x = block(x)
            x = stage[2](x)  # norm
            x = x.transpose(1, 2)  # (B, N, C) -> (B, C, N) for the next Conv1d stage
            outs.append(x)
        return outs


class SegFormer(nn.Module):
    def __init__(self, config):
        super().__init__()
        embed_dim = int(config.embed_dim)
        num_blocks = config.num_blocks
        num_heads = config.num_heads
        assert len(num_heads) == 3 and num_heads[2] in ["floor", "ceil"]
        sr_ratios = config.sr_ratios
        assert len(sr_ratios) == 3 and sr_ratios[2] in ["floor", "ceil"]
        mlp_ratio = int(config.mlp_ratio)
        dropout = float(config.dropout)
        decoder_channels = int(config.decoder_channels)
        self.interpolate_mode = str(config.interpolate_mode)
        output_size = int(config.output_size)

        self.MiT = MiT(embed_dim, num_blocks, num_heads, sr_ratios, mlp_ratio, dropout)

        num_stages = len(num_blocks)
        self.decode_mlps = nn.ModuleList(
            [
                nn.Conv1d(self.MiT.embed_dims[i], decoder_channels, 1, bias=False)
                for i in range(num_stages)
            ]
        )
        self.decode_fusion = nn.Conv1d(
            decoder_channels * num_stages, decoder_channels, 1, bias=False
        )
        self.cls = nn.Conv1d(decoder_channels, output_size, 1, bias=False)

    def forward(self, input: torch.Tensor, y=None):
        output = input
        output = self.MiT(output)
        for i, (_output, decode_mlp) in enumerate(zip(output, self.decode_mlps)):
            _output = decode_mlp(_output)  # unify channel dims across stages
            if i != 0:
                # upsample every stage to the resolution of the first stage
                _output = F.interpolate(
                    _output, size=output[0].shape[2], mode=self.interpolate_mode
                )
            output[i] = _output
        output = torch.concat(output, dim=1)
        output = self.decode_fusion(output)
        output = self.cls(output)
        return F.interpolate(output, size=input.shape[2], mode=self.interpolate_mode)
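

# A minimal smoke-test sketch of how the model might be instantiated. The config values
# below, and the use of SimpleNamespace as the config object, are assumptions for this
# example only; they are not taken from the paper or the reference repositories.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        embed_dim=8,  # per-head width; stage dims become embed_dim * num_heads[i]
        num_blocks=[2, 2, 2, 2],
        num_heads=[1, 2, "ceil"],  # expands to 1, 2, 4, 8 heads per stage
        sr_ratios=[1, 2, "ceil"],  # expands to 1, 2, 4, 8 and is reversed inside MiT
        mlp_ratio=4,
        dropout=0.1,
        decoder_channels=64,
        interpolate_mode="linear",
        output_size=2,  # number of segmentation classes
    )
    model = SegFormer(config)
    signal = torch.randn(2, 1, 256)  # (batch, in_channels, sequence length)
    logits = model(signal)
    print(logits.shape)  # expected: torch.Size([2, 2, 256])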