# refactor: add implementations (commit aefacda)
"""
paper: https://arxiv.org/abs/2012.15840
- ref
- encoder:
- https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/backbones/vit.py
- https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit_1d.py
- decoder:
- https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/models/decode_heads/setr_up_head.py
- https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/models/decode_heads/setr_mla_head.py
- encoder: ViT ์™€ ๊ตฌ์กฐ๊ฐ€ ๋™์ผํ•˜๋ฉฐ, PatchEmbed ์˜ ๊ฒฝ์šฐ patch_size๋ฅผ kernel_size์™€ stride ๋กœ ํ•˜๋Š” Conv1d๋ฅผ ์‚ฌ์šฉ
- decoder: upsample ํ•˜๋Š” ๋ฐฉ์‹์œผ๋กœ ๋‹ค์Œ ๋‘๊ฐ€์ง€๋ฅผ ์‚ฌ์šฉ (scale_factor: ํŠน์ • ๋ฐฐ์ˆ˜๋งŒํผ upsample / size: ํŠน์ • ํฌ๊ธฐ์™€ ๋™์ผํ•œ ํฌ๊ธฐ๋กœ upsample)
- naive: ์›๋ณธ ๊ธธ์ด๋กœ size ๋ฐฉ์‹ upsample
- pup: scale_factor ๋ฐฉ์‹์œผ๋กœ ์ˆ˜ํ–‰ํ•˜๋‹ค๊ฐ€ ๋งˆ์ง€๋ง‰์— ์›๋ณธ ๊ธธ์ด๋กœ size ๋ฐฉ์‹์œผ๋กœ upsample
- mla: ์ด ๋‘ ๋‹จ๊ณ„๋กœ ์ˆ˜ํ–‰ํ•˜๋ฉฐ, ์ฒซ๋ฒˆ์งธ ๋‹จ๊ณ„์—์„œ transformer block ์˜ ๊ฒฐ๊ณผ๋“ค์„ scale_factor ๋ฐฉ์‹์œผ๋กœ ์ˆ˜ํ–‰ํ•˜๊ณ  ๋‘๋ฒˆ์งธ ๋‹จ๊ณ„์—์„œ ์ฒซ๋ฒˆ์งธ ๊ฒฐ๊ณผ๋“ค์„ concat ํ•œ ํ›„ size ๋ฐฉ์‹์œผ๋กœ upsample
"""
import math
import torch
from torch import nn
from einops import rearrange
class FeedForward(nn.Module):
    """Pre-norm MLP block: LayerNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP token-wise; output shape matches ``x``."""
        return self.net(x)
class Attention(nn.Module):
    """Pre-norm multi-head self-attention over the token axis."""

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        # The output projection is skipped only in the single-head,
        # matching-width case, where it would be redundant.
        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head**-0.5
        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        if project_out:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        """Attend over ``x`` of shape (batch, tokens, dim); same shape out."""
        batch, tokens, _ = x.shape
        normed = self.norm(x)
        qkv = self.to_qkv(normed).chunk(3, dim=-1)
        # (b, n, h*d) -> (b, h, n, d) for each of q, k, v
        q, k, v = (
            t.view(batch, tokens, self.heads, -1).transpose(1, 2) for t in qkv
        )
        scores = q @ k.transpose(-1, -2) * self.scale
        weights = self.dropout(self.attend(scores))
        context = weights @ v
        # (b, h, n, d) -> (b, n, h*d)
        merged = context.transpose(1, 2).reshape(batch, tokens, -1)
        return self.to_out(merged)
# ========== everything above borrowed from https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit_1d.py ==========
# ========== everything below references the original SETR: https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/backbones/vit.py ==========
class TransformerBlock(nn.Module):
    """One pre-norm transformer encoder layer: self-attention then an MLP,
    each wrapped in a residual connection."""

    def __init__(
        self,
        dim,
        num_attn_heads,
        attn_head_dim,
        mlp_dim,
        attn_dropout=0.0,
        ffn_dropout=0.0,
    ):
        super().__init__()
        self.attn = Attention(
            dim, heads=num_attn_heads, dim_head=attn_head_dim, dropout=attn_dropout
        )
        self.ffn = FeedForward(dim, mlp_dim, dropout=ffn_dropout)

    def forward(self, x):
        """Apply attention and MLP sub-layers with residual additions."""
        attended = x + self.attn(x)
        return attended + self.ffn(attended)
class PatchEmbed(nn.Module):
    """Embed a single-channel 1D signal into non-overlapping patch tokens.

    A Conv1d with kernel_size == stride slices the signal into patches and
    projects each one to ``embed_dim``; the result is laid out as
    (batch, num_patches, embed_dim).
    """

    def __init__(
        self,
        embed_dim=1024,
        kernel_size=16,
        bias=False,
    ):
        super().__init__()
        self.projection = nn.Conv1d(1, embed_dim, kernel_size, stride=kernel_size, bias=bias)

    def forward(self, x: torch.Tensor):
        """Project (b, 1, L) -> (b, L // kernel_size, embed_dim)."""
        patches = self.projection(x)
        # (b, embed_dim, num_patches) -> (b, num_patches, embed_dim)
        return patches.permute(0, 2, 1)
class SETR(nn.Module):
    """SETR (SEgmentation TRansformer, https://arxiv.org/abs/2012.15840) adapted to 1D signals.

    Encoder: PatchEmbed + learned positional embedding + a stack of
    TransformerBlock layers.  Decoder: one of three upsampling schemes
    selected via ``config.dec_conf`` ("naive", "pup" or "mla"), followed by a
    1x1 Conv1d classification head.
    """

    def __init__(self, config):
        """Build the model from ``config``.

        ``config`` must provide: embed_dim, data_len, patch_size, patch_bias,
        dropout, num_layers, num_attn_heads, attn_head_dim, mlp_dim,
        interpolate_mode, dec_conf (a single-entry dict keyed by the decoder
        mode, whose value holds that mode's parameters) and output_size.
        """
        super().__init__()
        embed_dim = int(config.embed_dim)
        data_len = int(config.data_len)  # matched to ECGPQRSTDataset.second and hz
        patch_size = int(config.patch_size)
        assert data_len % patch_size == 0
        num_patches = data_len // patch_size
        patch_bias = bool(config.patch_bias)
        dropout = float(config.dropout)
        # pos_dropout_p: float = config.pos_dropout_p  # there would be too many hyperparameters, so a single shared dropout rate is used for now
        num_layers = int(config.num_layers)  # number of transformer blocks
        num_attn_heads = int(config.num_attn_heads)
        attn_head_dim = int(config.attn_head_dim)
        mlp_dim = int(config.mlp_dim)
        # attn_dropout: float = config.attn_dropout
        # ffn_dropout: float = config.ffn_dropout
        interpolate_mode = str(config.interpolate_mode)
        dec_conf: dict = config.dec_conf
        assert len(dec_conf) == 1
        self.dec_mode: str = list(dec_conf.keys())[0]
        assert self.dec_mode in ["naive", "pup", "mla"]
        self.dec_param: dict = dec_conf[self.dec_mode]
        output_size = int(config.output_size)
        # patch embedding
        self.patch_embed = PatchEmbed(
            embed_dim=embed_dim,
            kernel_size=patch_size,
            bias=patch_bias,
        )
        # positional embedding (one learned vector per patch token)
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, embed_dim))
        self.pos_dropout = nn.Dropout(p=dropout)
        # transformer encoder
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                TransformerBlock(
                    dim=embed_dim,
                    num_attn_heads=num_attn_heads,
                    attn_head_dim=attn_head_dim,
                    mlp_dim=mlp_dim,
                    attn_dropout=dropout,
                    ffn_dropout=dropout,
                )
            )
        # decoder
        self.dec_layers = nn.ModuleList()
        if self.dec_mode == "naive":
            self.dec_layers.append(nn.Upsample(size=data_len, mode=interpolate_mode))
            dec_out_channel = embed_dim
        elif self.dec_mode == "pup":
            self.dec_layers.append(nn.LayerNorm(embed_dim))
            dec_up_scale = int(self.dec_param["up_scale"])
            available_up_count = int(
                math.log(data_len // num_patches, dec_up_scale)
            )  # number of stages that can upsample by scale_factor; the remaining gap is closed by a final size-mode upsample
            pup_channels = int(self.dec_param["channels"])
            dec_in_channel = embed_dim
            dec_out_channel = pup_channels
            dec_kernel_size = int(self.dec_param["kernel_size"])
            dec_num_convs_by_layer = int(self.dec_param["num_convs_by_layer"])
            assert dec_kernel_size in [1, 3]  # same restriction as the reference implementation
            for i in range(available_up_count + 1):
                for _ in range(dec_num_convs_by_layer):
                    self.dec_layers.append(
                        nn.Conv1d(
                            dec_in_channel,
                            dec_out_channel,
                            kernel_size=dec_kernel_size,
                            stride=1,
                            padding=(dec_kernel_size - 1) // 2,
                        )
                    )
                    # after the first conv, subsequent convs keep the channel count
                    dec_in_channel = dec_out_channel
                if i < available_up_count:
                    self.dec_layers.append(
                        nn.Upsample(scale_factor=dec_up_scale, mode=interpolate_mode)
                    )
                else:  # last upsample
                    self.dec_layers.append(
                        nn.Upsample(size=data_len, mode=interpolate_mode)
                    )
        else:  # mla
            dec_up_scale = int(self.dec_param["up_scale"])
            assert (
                data_len >= dec_up_scale * num_patches
            )  # intermediate features upsampled by up_scale must stay no longer than the original, otherwise the final size-mode upsample is meaningless
            dec_output_step = int(self.dec_param["output_step"])
            assert num_layers % dec_output_step == 0
            dec_num_convs_by_layer = int(self.dec_param["num_convs_by_layer"])
            dec_kernel_size = int(self.dec_param["kernel_size"])
            mid_feature_cnt = num_layers // dec_output_step
            mla_channel = int(self.dec_param["channels"])
            for _ in range(mid_feature_cnt):
                # conv + upsample branch applied to the feature map taken every
                # ``output_step`` transformer blocks
                dec_in_channel = embed_dim
                dec_layers_each_upsample = []
                for _ in range(dec_num_convs_by_layer):
                    dec_layers_each_upsample.append(
                        nn.Conv1d(
                            dec_in_channel,
                            mla_channel,
                            kernel_size=dec_kernel_size,
                            stride=1,
                            padding=(dec_kernel_size - 1) // 2,
                        )
                    )
                    dec_in_channel = mla_channel
                dec_layers_each_upsample.append(
                    nn.Upsample(scale_factor=dec_up_scale, mode=interpolate_mode)
                )
                self.dec_layers.append(nn.Sequential(*dec_layers_each_upsample))
            # last decoder layer: upsample to the original length after concatenating the intermediate feature maps
            self.dec_layers.append(nn.Upsample(size=data_len, mode=interpolate_mode))
            dec_out_channel = (
                mla_channel * mid_feature_cnt
            )  # the branch outputs are concatenated channel-wise across mid_feature_cnt feature maps, so the channel count grows by that factor and must be used as in_channels of self.cls below
        self.cls = nn.Conv1d(dec_out_channel, output_size, 1, bias=False)

    def forward(self, input: torch.Tensor, y=None) -> torch.Tensor:
        """Encode then decode; ``input`` is (batch, 1, data_len) and the
        result is (batch, output_size, data_len).  ``y`` is accepted for
        interface compatibility and is unused here.
        """
        output = input
        # patch embedding: (b, 1, data_len) -> (b, num_patches, embed_dim)
        output = self.patch_embed(output)
        # positional embedding
        output += self.pos_embed
        output = self.pos_dropout(output)
        outputs = []
        # transformer encoder
        for i, layer in enumerate(self.layers):
            output = layer(output)
            if self.dec_mode == "mla":
                # collect every ``output_step``-th intermediate result as (b, c, n)
                if (i + 1) % int(self.dec_param["output_step"]) == 0:
                    outputs.append(output.transpose(1, 2))
        if self.dec_mode != "mla":  # for mla the outputs were already collected above
            outputs.append(output.transpose(1, 2))
        # decoder
        if self.dec_mode == "naive":
            assert len(outputs) == 1
            output = outputs[0]
            output = self.dec_layers[0](output)
        elif self.dec_mode == "pup":
            assert len(outputs) == 1
            output = outputs[0]
            pup_norm = self.dec_layers[0]
            # LayerNorm expects the embedding dim last: (b, c, n) -> (b, n, c) -> normalize -> back
            output = pup_norm(output.transpose(1, 2)).transpose(1, 2)
            for i, dec_layer in enumerate(self.dec_layers[1:]):
                output = dec_layer(output)
        else:  # mla
            dec_output_step = int(self.dec_param["output_step"])
            mid_feature_cnt = len(self.layers) // dec_output_step
            assert len(outputs) == mid_feature_cnt
            # one conv-upsample branch per collected intermediate feature map
            for i in range(len(outputs)):
                outputs[i] = self.dec_layers[i](outputs[i])
            output = torch.cat(outputs, dim=1)
            output = self.dec_layers[-1](output)
        return self.cls(output)