""" | |
paper: https://arxiv.org/abs/2012.15840 | |
- ref | |
- encoder: | |
- https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/backbones/vit.py | |
- https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit_1d.py | |
- decoder: | |
- https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/models/decode_heads/setr_up_head.py | |
- https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/models/decode_heads/setr_mla_head.py | |
- encoder: ViT ์ ๊ตฌ์กฐ๊ฐ ๋์ผํ๋ฉฐ, PatchEmbed ์ ๊ฒฝ์ฐ patch_size๋ฅผ kernel_size์ stride ๋ก ํ๋ Conv1d๋ฅผ ์ฌ์ฉ | |
- decoder: upsample ํ๋ ๋ฐฉ์์ผ๋ก ๋ค์ ๋๊ฐ์ง๋ฅผ ์ฌ์ฉ (scale_factor: ํน์ ๋ฐฐ์๋งํผ upsample / size: ํน์ ํฌ๊ธฐ์ ๋์ผํ ํฌ๊ธฐ๋ก upsample) | |
- naive: ์๋ณธ ๊ธธ์ด๋ก size ๋ฐฉ์ upsample | |
- pup: scale_factor ๋ฐฉ์์ผ๋ก ์ํํ๋ค๊ฐ ๋ง์ง๋ง์ ์๋ณธ ๊ธธ์ด๋ก size ๋ฐฉ์์ผ๋ก upsample | |
- mla: ์ด ๋ ๋จ๊ณ๋ก ์ํํ๋ฉฐ, ์ฒซ๋ฒ์งธ ๋จ๊ณ์์ transformer block ์ ๊ฒฐ๊ณผ๋ค์ scale_factor ๋ฐฉ์์ผ๋ก ์ํํ๊ณ ๋๋ฒ์งธ ๋จ๊ณ์์ ์ฒซ๋ฒ์งธ ๊ฒฐ๊ณผ๋ค์ concat ํ ํ size ๋ฐฉ์์ผ๋ก upsample | |
""" | |
import math

import torch
from torch import nn
from einops import rearrange
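
# A minimal sketch (not part of the original model) contrasting the two
# nn.Upsample modes described in the docstring; the shapes and the "linear"
# mode used here are illustrative assumptions.
def _demo_upsample_modes():
    x = torch.randn(2, 8, 64)  # (batch, channels, length)
    by_factor = nn.Upsample(scale_factor=4, mode="linear")(x)  # -> (2, 8, 256)
    to_size = nn.Upsample(size=2500, mode="linear")(x)  # -> (2, 8, 2500)
    return by_factor.shape, to_size.shape
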
class FeedForward(nn.Module):
    """Pre-norm MLP block: LayerNorm -> Linear -> GELU -> Linear."""

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    """Pre-norm multi-head self-attention."""

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        # The output projection is skipped when a single head already matches dim.
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head**-0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = (
            nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
            if project_out
            else nn.Identity()
        )

    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
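
# Illustrative shape sanity check (not part of the original file): tokens keep
# their (batch, n, dim) shape through the attention block; the sizes below are
# assumptions chosen for the demo.
def _demo_attention_shapes():
    attn = Attention(dim=64, heads=8, dim_head=64)
    tokens = torch.randn(2, 160, 64)  # (batch, num_patches, dim)
    return attn(tokens).shape  # torch.Size([2, 160, 64])
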
# ========== Everything above is adapted from https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit_1d.py ==========
# ========== Everything below follows the original SETR code: https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/backbones/vit.py ==========
class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block with residual connections."""

    def __init__(
        self,
        dim,
        num_attn_heads,
        attn_head_dim,
        mlp_dim,
        attn_dropout=0.0,
        ffn_dropout=0.0,
    ):
        super().__init__()
        self.attn = Attention(
            dim, heads=num_attn_heads, dim_head=attn_head_dim, dropout=attn_dropout
        )
        self.ffn = FeedForward(dim, mlp_dim, dropout=ffn_dropout)

    def forward(self, x):
        x = self.attn(x) + x
        x = self.ffn(x) + x
        return x

class PatchEmbed(nn.Module):
    """Patch embedding for 1-D signals: a Conv1d whose kernel_size and stride
    both equal the patch size, so each patch maps to one embed_dim token."""

    def __init__(
        self,
        embed_dim=1024,
        kernel_size=16,
        bias=False,
    ):
        super().__init__()
        self.projection = nn.Conv1d(
            in_channels=1,
            out_channels=embed_dim,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=bias,
        )

    def forward(self, x: torch.Tensor):
        # (batch, 1, data_len) -> (batch, num_patches, embed_dim)
        return self.projection(x).transpose(1, 2)
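
# Illustrative shape check (assumed sizes, not from the original file): a
# 2560-sample signal with patch_size=16 yields 2560 / 16 = 160 tokens.
def _demo_patch_embed_shapes():
    embed = PatchEmbed(embed_dim=1024, kernel_size=16)
    signal = torch.randn(2, 1, 2560)  # (batch, in_channels=1, data_len)
    return embed(signal).shape  # torch.Size([2, 160, 1024])
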
class SETR(nn.Module):
    def __init__(self, config):
        super().__init__()
        embed_dim = int(config.embed_dim)
        data_len = int(config.data_len)  # must match ECGPQRSTDataset.second and hz
        patch_size = int(config.patch_size)
        assert data_len % patch_size == 0
        num_patches = data_len // patch_size
        patch_bias = bool(config.patch_bias)
        dropout = float(config.dropout)
        # pos_dropout_p: float = config.pos_dropout_p  # too many hyperparameters already, so a single shared dropout value is used for now
        num_layers = int(config.num_layers)  # number of transformer blocks
        num_attn_heads = int(config.num_attn_heads)
        attn_head_dim = int(config.attn_head_dim)
        mlp_dim = int(config.mlp_dim)
        # attn_dropout: float = config.attn_dropout
        # ffn_dropout: float = config.ffn_dropout
        interpolate_mode = str(config.interpolate_mode)

        dec_conf: dict = config.dec_conf
        assert len(dec_conf) == 1
        self.dec_mode: str = list(dec_conf.keys())[0]
        assert self.dec_mode in ["naive", "pup", "mla"]
        self.dec_param: dict = dec_conf[self.dec_mode]
        output_size = int(config.output_size)
        # patch embedding
        self.patch_embed = PatchEmbed(
            embed_dim=embed_dim,
            kernel_size=patch_size,
            bias=patch_bias,
        )
        # positional embedding
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, embed_dim))
        self.pos_dropout = nn.Dropout(p=dropout)
        # transformer encoder
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                TransformerBlock(
                    dim=embed_dim,
                    num_attn_heads=num_attn_heads,
                    attn_head_dim=attn_head_dim,
                    mlp_dim=mlp_dim,
                    attn_dropout=dropout,
                    ffn_dropout=dropout,
                )
            )
        # decoder
        self.dec_layers = nn.ModuleList()
        if self.dec_mode == "naive":
            self.dec_layers.append(nn.Upsample(size=data_len, mode=interpolate_mode))
            dec_out_channel = embed_dim
        elif self.dec_mode == "pup":
            self.dec_layers.append(nn.LayerNorm(embed_dim))
            dec_up_scale = int(self.dec_param["up_scale"])
            # number of upsample stages that can use scale_factor mode;
            # the remainder is covered by a final size-mode upsample
            available_up_count = int(math.log(data_len // num_patches, dec_up_scale))
            pup_channels = int(self.dec_param["channels"])
            dec_in_channel = embed_dim
            dec_out_channel = pup_channels
            dec_kernel_size = int(self.dec_param["kernel_size"])
            dec_num_convs_by_layer = int(self.dec_param["num_convs_by_layer"])
            assert dec_kernel_size in [1, 3]  # same restriction as the original code
            for i in range(available_up_count + 1):
                for _ in range(dec_num_convs_by_layer):
                    self.dec_layers.append(
                        nn.Conv1d(
                            dec_in_channel,
                            dec_out_channel,
                            kernel_size=dec_kernel_size,
                            stride=1,
                            padding=(dec_kernel_size - 1) // 2,
                        )
                    )
                    dec_in_channel = dec_out_channel
                if i < available_up_count:
                    self.dec_layers.append(
                        nn.Upsample(scale_factor=dec_up_scale, mode=interpolate_mode)
                    )
                else:  # last upsample
                    self.dec_layers.append(
                        nn.Upsample(size=data_len, mode=interpolate_mode)
                    )
        else:  # mla
            dec_up_scale = int(self.dec_param["up_scale"])
            # upsampling an intermediate transformer output by up_scale must not
            # exceed the original length, otherwise the final upsample is pointless
            assert data_len >= dec_up_scale * num_patches
            dec_output_step = int(self.dec_param["output_step"])
            assert num_layers % dec_output_step == 0
            dec_num_convs_by_layer = int(self.dec_param["num_convs_by_layer"])
            dec_kernel_size = int(self.dec_param["kernel_size"])
            mid_feature_cnt = num_layers // dec_output_step
            mla_channel = int(self.dec_param["channels"])
            for _ in range(mid_feature_cnt):
                # conv-upsample branch applied to the feature map taken every
                # output_step blocks from the transformer encoder
                dec_in_channel = embed_dim
                dec_layers_each_upsample = []
                for _ in range(dec_num_convs_by_layer):
                    dec_layers_each_upsample.append(
                        nn.Conv1d(
                            dec_in_channel,
                            mla_channel,
                            kernel_size=dec_kernel_size,
                            stride=1,
                            padding=(dec_kernel_size - 1) // 2,
                        )
                    )
                    dec_in_channel = mla_channel
                dec_layers_each_upsample.append(
                    nn.Upsample(scale_factor=dec_up_scale, mode=interpolate_mode)
                )
                self.dec_layers.append(nn.Sequential(*dec_layers_each_upsample))
            # last decoder layer: upsample after concatenating the intermediate feature maps
            self.dec_layers.append(nn.Upsample(size=data_len, mode=interpolate_mode))
            # the mid_feature_cnt branch outputs are concatenated channel-wise, so
            # the grown channel count must be the in_channels of self.cls
            dec_out_channel = mla_channel * mid_feature_cnt

        self.cls = nn.Conv1d(dec_out_channel, output_size, 1, bias=False)
    def forward(self, input: torch.Tensor, y=None):
        output = input
        # patch embedding
        output = self.patch_embed(output)
        # positional embedding (out-of-place add: patch_embed returns a transposed view)
        output = output + self.pos_embed
        output = self.pos_dropout(output)

        outputs = []
        # transformer encoder
        for i, layer in enumerate(self.layers):
            output = layer(output)
            if self.dec_mode == "mla" and (i + 1) % int(self.dec_param["output_step"]) == 0:
                outputs.append(output.transpose(1, 2))
        if self.dec_mode != "mla":  # for mla, the outputs were already collected above
            outputs.append(output.transpose(1, 2))

        # decoder
        if self.dec_mode == "naive":
            assert len(outputs) == 1
            output = outputs[0]
            output = self.dec_layers[0](output)
        elif self.dec_mode == "pup":
            assert len(outputs) == 1
            output = outputs[0]
            pup_norm = self.dec_layers[0]
            output = pup_norm(output.transpose(1, 2)).transpose(1, 2)
            for dec_layer in self.dec_layers[1:]:
                output = dec_layer(output)
        else:  # mla
            dec_output_step = int(self.dec_param["output_step"])
            mid_feature_cnt = len(self.layers) // dec_output_step
            assert len(outputs) == mid_feature_cnt
            for i in range(len(outputs)):
                outputs[i] = self.dec_layers[i](outputs[i])
            output = torch.cat(outputs, dim=1)
            output = self.dec_layers[-1](output)
        return self.cls(output)
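
# A minimal usage sketch. The config values below (and the use of
# types.SimpleNamespace as a stand-in config object) are illustrative
# assumptions, not settings from the original project; the required field
# names mirror what SETR.__init__ reads above.
def _demo_setr_forward():
    from types import SimpleNamespace

    config = SimpleNamespace(
        embed_dim=256,
        data_len=2560,  # e.g. a 2560-sample signal
        patch_size=16,  # -> 2560 / 16 = 160 patches
        patch_bias=False,
        dropout=0.1,
        num_layers=4,
        num_attn_heads=8,
        attn_head_dim=32,
        mlp_dim=512,
        interpolate_mode="linear",
        dec_conf={
            "pup": {
                "up_scale": 2,
                "channels": 64,
                "kernel_size": 3,
                "num_convs_by_layer": 1,
            }
        },
        output_size=5,  # e.g. per-sample class logits
    )
    model = SETR(config)
    signal = torch.randn(2, 1, 2560)  # (batch, in_channels=1, data_len)
    return model(signal).shape  # torch.Size([2, 5, 2560])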