""" paper: https://arxiv.org/abs/2012.15840 - ref - encoder: - https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/backbones/vit.py - https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit_1d.py - decoder: - https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/models/decode_heads/setr_up_head.py - https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/models/decode_heads/setr_mla_head.py - encoder: ViT 와 구조가 동일하며, PatchEmbed 의 경우 patch_size를 kernel_size와 stride 로 하는 Conv1d를 사용 - decoder: upsample 하는 방식으로 다음 두가지를 사용 (scale_factor: 특정 배수만큼 upsample / size: 특정 크기와 동일한 크기로 upsample) - naive: 원본 길이로 size 방식 upsample - pup: scale_factor 방식으로 수행하다가 마지막에 원본 길이로 size 방식으로 upsample - mla: 총 두 단계로 수행하며, 첫번째 단계에서 transformer block 의 결과들을 scale_factor 방식으로 수행하고 두번째 단계에서 첫번째 결과들을 concat 한 후 size 방식으로 upsample """ import math import torch from torch import nn from einops import rearrange class FeedForward(nn.Module): def __init__(self, dim, hidden_dim, dropout=0.0): super().__init__() self.net = nn.Sequential( nn.LayerNorm(dim), nn.Linear(dim, hidden_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden_dim, dim), nn.Dropout(dropout), ) def forward(self, x): return self.net(x) class Attention(nn.Module): def __init__(self, dim, heads=8, dim_head=64, dropout=0.0): super().__init__() inner_dim = dim_head * heads project_out = not (heads == 1 and dim_head == dim) self.heads = heads self.scale = dim_head**-0.5 self.norm = nn.LayerNorm(dim) self.attend = nn.Softmax(dim=-1) self.dropout = nn.Dropout(dropout) self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) self.to_out = ( nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) if project_out else nn.Identity() ) def forward(self, x): x = self.norm(x) qkv = self.to_qkv(x).chunk(3, dim=-1) q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv) dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale attn = self.attend(dots) attn = self.dropout(attn) out = torch.matmul(attn, v) out = rearrange(out, "b h n d -> b n (h d)") return self.to_out(out) # ========== 여기까지 https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit_1d.py 차용 ========== # ========== 아래부터 setr 원본 참고 https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/backbones/vit.py ========== class TransformerBlock(nn.Module): def __init__( self, dim, num_attn_heads, attn_head_dim, mlp_dim, attn_dropout=0.0, ffn_dropout=0.0, ): super().__init__() self.attn = Attention( dim, heads=num_attn_heads, dim_head=attn_head_dim, dropout=attn_dropout ) self.ffn = FeedForward(dim, mlp_dim, dropout=ffn_dropout) def forward(self, x): x = self.attn(x) + x x = self.ffn(x) + x return x class PatchEmbed(nn.Module): def __init__( self, embed_dim=1024, kernel_size=16, bias=False, ): super().__init__() self.projection = nn.Conv1d( in_channels=1, out_channels=embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=bias, ) def forward(self, x: torch.Tensor): return self.projection(x).transpose(1, 2) class SETR(nn.Module): def __init__(self, config): super().__init__() embed_dim = int(config.embed_dim) data_len = int(config.data_len) # ECGPQRSTDataset.second, hz 에 맞춰서 patch_size = int(config.patch_size) assert data_len % patch_size == 0 num_patches = data_len // patch_size patch_bias = bool(config.patch_bias) dropout = float(config.dropout) # pos_dropout_p: float = config.pos_dropout_p # 파라미터라 너무 많으므로 우선 dropout 개수는 하나로 사용 num_layers = int(config.num_layers) # transformer block 개수 num_attn_heads = 
        attn_head_dim = int(config.attn_head_dim)
        mlp_dim = int(config.mlp_dim)
        # attn_dropout: float = config.attn_dropout
        # ffn_dropout: float = config.ffn_dropout
        interpolate_mode = str(config.interpolate_mode)
        dec_conf: dict = config.dec_conf
        assert len(dec_conf) == 1
        self.dec_mode: str = list(dec_conf.keys())[0]
        assert self.dec_mode in ["naive", "pup", "mla"]
        self.dec_param: dict = dec_conf[self.dec_mode]
        output_size = int(config.output_size)

        # patch embedding
        self.patch_embed = PatchEmbed(
            embed_dim=embed_dim,
            kernel_size=patch_size,
            bias=patch_bias,
        )

        # positional embedding
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, embed_dim))
        self.pos_dropout = nn.Dropout(p=dropout)

        # transformer encoder
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                TransformerBlock(
                    dim=embed_dim,
                    num_attn_heads=num_attn_heads,
                    attn_head_dim=attn_head_dim,
                    mlp_dim=mlp_dim,
                    attn_dropout=dropout,
                    ffn_dropout=dropout,
                )
            )

        # decoder
        self.dec_layers = nn.ModuleList()
        if self.dec_mode == "naive":
            self.dec_layers.append(nn.Upsample(size=data_len, mode=interpolate_mode))
            dec_out_channel = embed_dim
        elif self.dec_mode == "pup":
            self.dec_layers.append(nn.LayerNorm(embed_dim))
            dec_up_scale = int(self.dec_param["up_scale"])
            # number of stages that can be upsampled in scale_factor mode; the remainder is covered by a final size-mode upsample
            available_up_count = int(math.log(data_len // num_patches, dec_up_scale))
            pup_channels = int(self.dec_param["channels"])
            dec_in_channel = embed_dim
            dec_out_channel = pup_channels
            dec_kernel_size = int(self.dec_param["kernel_size"])
            dec_num_convs_by_layer = int(self.dec_param["num_convs_by_layer"])
            assert dec_kernel_size in [1, 3]  # same restriction as the original code
            for i in range(available_up_count + 1):
                for _ in range(dec_num_convs_by_layer):
                    self.dec_layers.append(
                        nn.Conv1d(
                            dec_in_channel,
                            dec_out_channel,
                            kernel_size=dec_kernel_size,
                            stride=1,
                            padding=(dec_kernel_size - 1) // 2,
                        )
                    )
                    dec_in_channel = dec_out_channel
                if i < available_up_count:
                    self.dec_layers.append(
                        nn.Upsample(scale_factor=dec_up_scale, mode=interpolate_mode)
                    )
                else:  # last upsample
                    self.dec_layers.append(
                        nn.Upsample(size=data_len, mode=interpolate_mode)
                    )
        else:  # mla
            dec_up_scale = int(self.dec_param["up_scale"])
            # upsampling an intermediate transformer output by up_scale must not exceed the original
            # length, otherwise the final size-mode upsample would be meaningless
            assert data_len >= dec_up_scale * num_patches
            dec_output_step = int(self.dec_param["output_step"])
            assert num_layers % dec_output_step == 0
            dec_num_convs_by_layer = int(self.dec_param["num_convs_by_layer"])
            dec_kernel_size = int(self.dec_param["kernel_size"])
            mid_feature_cnt = num_layers // dec_output_step
            mla_channel = int(self.dec_param["channels"])
            for _ in range(mid_feature_cnt):
                # conv-upsample applied to the feature map extracted every output_step transformer blocks
                dec_in_channel = embed_dim
                dec_layers_each_upsample = []
                for _ in range(dec_num_convs_by_layer):
                    dec_layers_each_upsample.append(
                        nn.Conv1d(
                            dec_in_channel,
                            mla_channel,
                            kernel_size=dec_kernel_size,
                            stride=1,
                            padding=(dec_kernel_size - 1) // 2,
                        )
                    )
                    dec_in_channel = mla_channel
                dec_layers_each_upsample.append(
                    nn.Upsample(scale_factor=dec_up_scale, mode=interpolate_mode)
                )
                self.dec_layers.append(nn.Sequential(*dec_layers_each_upsample))
            # last decoder layer: upsample after concatenating the intermediate feature maps
            self.dec_layers.append(nn.Upsample(size=data_len, mode=interpolate_mode))
            # the mid_feature_cnt feature maps produced by self.dec_layers are concatenated channel-wise,
            # so the correspondingly larger channel count must be used as in_channels of self.cls below
            dec_out_channel = mla_channel * mid_feature_cnt
        self.cls = nn.Conv1d(dec_out_channel, output_size, 1, bias=False)

    def forward(self, input: torch.Tensor, y=None):
        output = input
        # patch embedding
        output = self.patch_embed(output)
        # positional embedding
        output += self.pos_embed
        output = self.pos_dropout(output)

        outputs = []
        # transformer encoder
        for i, layer in enumerate(self.layers):
            output = layer(output)
            if self.dec_mode == "mla":
                if (i + 1) % int(self.dec_param["output_step"]) == 0:
                    outputs.append(output.transpose(1, 2))
        if self.dec_mode != "mla":  # for mla the outputs were already collected above
            outputs.append(output.transpose(1, 2))

        # decoder
        if self.dec_mode == "naive":
            assert len(outputs) == 1
            output = outputs[0]
            output = self.dec_layers[0](output)
        elif self.dec_mode == "pup":
            assert len(outputs) == 1
            output = outputs[0]
            pup_norm = self.dec_layers[0]
            output = pup_norm(output.transpose(1, 2)).transpose(1, 2)
            for dec_layer in self.dec_layers[1:]:
                output = dec_layer(output)
        else:  # mla
            dec_output_step = int(self.dec_param["output_step"])
            mid_feature_cnt = len(self.layers) // dec_output_step
            assert len(outputs) == mid_feature_cnt
            for i in range(len(outputs)):
                outputs[i] = self.dec_layers[i](outputs[i])
            output = torch.cat(outputs, dim=1)
            output = self.dec_layers[-1](output)
        return self.cls(output)  # (batch, output_size, data_len)
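

# A minimal usage sketch, not part of the original module: it shows how SETR is instantiated and
# what shapes flow through it. The config fields match what __init__ reads, but the concrete values
# (data_len, patch_size, interpolate_mode="linear", the pup decoder settings, output_size) are
# assumptions for illustration; the real project presumably derives them from its own config and
# from ECGPQRSTDataset.second / hz.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        embed_dim=256,
        data_len=2000,  # hypothetical: e.g. 4 s at 500 Hz
        patch_size=20,  # data_len must be divisible by patch_size -> 100 patches
        patch_bias=False,
        dropout=0.1,
        num_layers=8,
        num_attn_heads=8,
        attn_head_dim=32,
        mlp_dim=512,
        interpolate_mode="linear",  # 1D interpolation mode for nn.Upsample
        dec_conf={
            "pup": {
                "up_scale": 2,
                "channels": 64,
                "kernel_size": 3,
                "num_convs_by_layer": 1,
            }
        },
        output_size=5,  # hypothetical number of segmentation classes
    )

    model = SETR(config)
    x = torch.randn(2, 1, config.data_len)  # (batch, channel=1, length)
    logits = model(x)  # (batch, output_size, data_len)
    print(logits.shape)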