# Source: ECG_Delineation / res/impl/HRNetV2.py
# Uploaded by wogh2012 -- commit aefacda ("refactor: add implementations")
"""
paper: https://arxiv.org/abs/1904.04514
ref: https://github.com/HRNet/HRNet-Semantic-Segmentation/blob/HRNet-OCR/lib/models/seg_hrnet.py
"""
import torch
import torch.nn as nn
from torch.functional import F
import math
def _gen_same_length_conv(in_channel, out_channel, kernel_size=1, dilation=1):
    """Create a stride-1 ``nn.Conv1d`` whose output length equals its input length.

    Used for the feature-extracting convolutions inside the residual blocks.

    Args:
        in_channel: Number of input channels.
        out_channel: Number of output channels.
        kernel_size: Convolution kernel size.
        dilation: Convolution dilation.

    Returns:
        A bias-free ``nn.Conv1d`` with ``stride=1``.
    """
    total_padding = dilation * (kernel_size - 1)
    # Symmetric integer padding keeps the length only when the total padding
    # is even.  With an odd total (even kernel_size combined with odd
    # dilation) ``total_padding // 2`` per side would shorten the output by
    # one sample and break the residual additions in the blocks, so let
    # PyTorch pad asymmetrically through ``padding="same"`` in that case.
    if total_padding % 2 == 0:
        padding = total_padding // 2
    else:
        padding = "same"
    return nn.Conv1d(
        in_channel,
        out_channel,
        kernel_size=kernel_size,
        stride=1,
        padding=padding,
        dilation=dilation,
        bias=False,
    )
def _gen_downsample(in_channel, out_channel):
    """Create the bias-free conv (kernel_size 3, stride 2, padding 1) used for 2x downsampling."""
    conv_options = {"kernel_size": 3, "stride": 2, "padding": 1, "bias": False}
    return nn.Conv1d(in_channel, out_channel, **conv_options)
def _gen_channel_change_conv(in_channel, out_channel):
    """Create the bias-free pointwise conv (kernel_size 1, stride 1) that only changes the channel count."""
    pointwise = nn.Conv1d(
        in_channels=in_channel,
        out_channels=out_channel,
        kernel_size=1,
        stride=1,
        bias=False,
    )
    return pointwise
class BasicBlock(nn.Module):
    """ResNet-style basic block; the channel count goes inplanes -> planes.

    Two same-length convolutions (each followed by batch norm) form the main
    path.  The shortcut is the identity when ``inplanes == planes`` and a
    pointwise conv otherwise.
    """

    expansion = 1

    def __init__(self, inplanes, planes, kernel_size=3, dilation=1):
        super().__init__()
        # Main path: conv -> bn -> relu -> conv -> bn.
        self.conv1 = _gen_same_length_conv(inplanes, planes, kernel_size, dilation)
        self.bn1 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU()
        self.conv2 = _gen_same_length_conv(planes, planes, kernel_size, dilation)
        self.bn2 = nn.BatchNorm1d(planes)
        # Shortcut path: only project when the channel count changes.
        if inplanes == planes:
            self.make_residual = nn.Identity()
        else:
            self.make_residual = _gen_channel_change_conv(inplanes, planes)

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        shortcut = self.make_residual(x)
        features = self.relu(self.bn1(self.conv1(x)))
        features = self.bn2(self.conv2(features))
        return self.relu(features + shortcut)
class Bottleneck(nn.Module):
    """ResNet-style bottleneck block; the channel count goes inplanes -> planes * 4.

    The main path is pointwise conv -> (kernel_size, dilation) conv ->
    pointwise conv, each followed by batch norm.  The shortcut is the identity
    when ``inplanes == planes * expansion`` and a pointwise conv otherwise.
    """

    expansion = 4

    def __init__(self, inplanes, planes, kernel_size=3, dilation=1):
        super().__init__()
        expanded_planes = planes * self.expansion
        # Main path: reduce -> transform -> expand.
        self.conv1 = _gen_same_length_conv(inplanes, planes)
        self.bn1 = nn.BatchNorm1d(planes)
        self.conv2 = _gen_same_length_conv(planes, planes, kernel_size, dilation)
        self.bn2 = nn.BatchNorm1d(planes)
        self.conv3 = _gen_same_length_conv(planes, expanded_planes)
        self.bn3 = nn.BatchNorm1d(expanded_planes)
        self.relu = nn.ReLU()
        # Shortcut path: only project when the channel count changes.
        if inplanes == expanded_planes:
            self.make_residual = nn.Identity()
        else:
            self.make_residual = _gen_channel_change_conv(inplanes, expanded_planes)

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        shortcut = self.make_residual(x)
        features = self.relu(self.bn1(self.conv1(x)))
        features = self.relu(self.bn2(self.conv2(features)))
        features = self.bn3(self.conv3(features))
        return self.relu(features + shortcut)
class HRModule(nn.Module):
    """One multi-resolution module: parallel branches followed by cross-branch fusion.

    A module belonging to stage ``stage_idx`` owns ``stage_idx + 1`` branches.

    Attributes:
        branches: ``nn.ModuleList`` with one ``nn.Sequential`` of ``num_blocks``
            residual blocks per branch.
        fusions: ``nn.ModuleList`` of ``nn.ModuleList``; ``fusions[i][j]`` maps
            the output of branch ``j`` to the channel count (and length) that
            branch ``i`` expects.

    Args:
        stage_idx: Index of the stage this module belongs to.
        num_blocks: Number of residual blocks per branch.
        block_type_by_stage: Per-stage block class (``BasicBlock`` or
            ``Bottleneck``).
        in_channels_by_stage: Per-stage list of branch input channel counts.
        out_channels_by_stage: Per-stage list of branch output channel counts
            (the input counts multiplied by the block expansion).
        data_len_by_branch: Sequence length of each branch; used as the fixed
            target size of the upsampling fusions.
        kernel_size: Kernel size forwarded to the blocks.
        dilation: Dilation forwarded to the blocks.
        interpolate_mode: ``mode`` passed to ``nn.Upsample``.
    """

    def __init__(
        self,
        stage_idx,
        num_blocks,
        block_type_by_stage,
        in_channels_by_stage,
        out_channels_by_stage,
        data_len_by_branch,
        kernel_size,
        dilation,
        interpolate_mode,
    ):
        super().__init__()
        self.branches = nn.ModuleList()
        self.fusions = nn.ModuleList()
        num_branches = stage_idx + 1

        block_type = block_type_by_stage[stage_idx]  # a block class, not an instance
        in_channels = in_channels_by_stage[stage_idx]

        # Create the branches.
        for branch_idx in range(num_branches):
            channels = in_channels[branch_idx]
            blocks = [block_type(channels, channels, kernel_size, dilation)]
            # Every block emits ``channels * expansion`` channels, so the
            # following blocks must accept that many.
            for _ in range(1, num_blocks):
                blocks.append(
                    block_type(
                        channels * block_type.expansion,
                        channels,
                        kernel_size,
                        dilation,
                    )
                )
            self.branches.append(nn.Sequential(*blocks))

        out_channels = out_channels_by_stage[stage_idx]

        # Create the fusion layers: one per (target branch, source branch) pair.
        for target_idx in range(num_branches):
            fusions_to_target = nn.ModuleList()
            for source_idx in range(num_branches):
                # Only the upsampling case needs the target length.
                target_len = (
                    data_len_by_branch[target_idx]
                    if target_idx < source_idx
                    else None
                )
                fusions_to_target.append(
                    self._make_fusion(
                        target_idx - source_idx,
                        out_channels[source_idx],
                        in_channels[target_idx],
                        target_len,
                        interpolate_mode,
                    )
                )
            self.fusions.append(fusions_to_target)

    @staticmethod
    def _make_fusion(level_gap, source_channels, target_channels, target_len, interpolate_mode):
        """Build the layer that adapts a source-branch output for a target branch.

        ``level_gap`` is ``target_idx - source_idx``: negative means upsample,
        zero means same branch, positive means downsample ``level_gap`` times.
        """
        if level_gap < 0:
            # Match the channel count, then resize to the target branch length.
            return nn.Sequential(
                _gen_channel_change_conv(source_channels, target_channels),
                nn.BatchNorm1d(target_channels),
                nn.Upsample(size=target_len, mode=interpolate_mode),
            )
        if level_gap == 0:
            if source_channels == target_channels:
                return nn.Identity()
            return nn.Sequential(
                _gen_channel_change_conv(source_channels, target_channels),
                nn.BatchNorm1d(target_channels),
                nn.ReLU(),
            )
        # Downsample by 2x once per branch level of difference; the channel
        # count is matched to the target branch's input channels.
        layers = [
            _gen_downsample(source_channels, target_channels),
            nn.BatchNorm1d(target_channels),
        ]
        for _ in range(1, level_gap):
            layers.extend(
                [
                    nn.ReLU(),
                    _gen_downsample(target_channels, target_channels),
                    nn.BatchNorm1d(target_channels),
                ]
            )
        return nn.Sequential(*layers)

    def forward(self, inputs):
        """Run every branch on its input, then fuse the results across branches.

        Args:
            inputs: Sequence with one tensor per branch; ``inputs[i]`` is fed
                to ``branches[i]``.

        Returns:
            A list with one tensor per branch, where entry ``i`` is the sum
            over ``j`` of ``fusions[i][j](branches[j](inputs[j]))``.

        Raises:
            ValueError: If the number of inputs differs from the number of
                branches.
        """
        if len(inputs) != len(self.branches):
            raise ValueError(
                f"expected {len(self.branches)} branch inputs, got {len(inputs)}"
            )
        branch_outputs = [
            branch(branch_input)
            for branch, branch_input in zip(self.branches, inputs)
        ]
        return [
            sum(
                [
                    fusion(branch_output)
                    for fusion, branch_output in zip(fusions_to_target, branch_outputs)
                ]
            )
            for fusions_to_target in self.fusions
        ]
class HRNetV2(nn.Module):
    """1-D HRNetV2: a stem, multi-resolution stages joined by transitions, and a 1x1 head.

    ``config`` must expose the attributes ``data_len``, ``kernel_size``,
    ``dilation``, ``num_stages``, ``num_blocks``, ``num_modules``,
    ``use_bottleneck``, ``stage1_channels``, ``num_channels_init``,
    ``interpolate_mode`` and ``output_size``.

    Attributes:
        stem: Two stride-2 convolutions applied to the single-channel input.
        stages: ``nn.ModuleList``; entry ``s`` is an ``nn.ModuleList`` of
            ``num_modules[s]`` ``HRModule`` instances with ``s + 1`` branches.
        transitions: ``nn.ModuleList``; entry ``s`` adapts the outputs of stage
            ``s`` to stage ``s + 1`` and creates the additional branch.
        cls: Pointwise conv mapping the concatenated branch features to
            ``output_size`` channels.

    Raises:
        ValueError: If ``num_stages`` is smaller than 1 or larger than the
            length of ``num_modules`` or ``use_bottleneck``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        data_len = int(config.data_len)  # per the original note: follows ECGPQRSTDataset.second and hz
        kernel_size = int(config.kernel_size)
        dilation = int(config.dilation)
        num_stages = int(config.num_stages)
        num_blocks = int(config.num_blocks)

        # Explicit checks instead of ``assert`` so that validation also runs
        # when Python is started with -O.
        if num_stages < 1:
            raise ValueError(f"num_stages must be at least 1, got {num_stages}")
        self.num_modules = config.num_modules  # [1, 1, 4, 3, ..]
        if num_stages > len(self.num_modules):
            raise ValueError(
                f"num_stages ({num_stages}) exceeds len(num_modules) "
                f"({len(self.num_modules)})"
            )
        use_bottleneck = config.use_bottleneck  # [1, 0, 0, 0, ..]
        if num_stages > len(use_bottleneck):
            raise ValueError(
                f"num_stages ({num_stages}) exceeds len(use_bottleneck) "
                f"({len(use_bottleneck)})"
            )

        stage1_channels = int(config.stage1_channels)  # 64, 128
        num_channels_init = int(config.num_channels_init)  # 18, 32, 48
        self.interpolate_mode = config.interpolate_mode
        output_size = config.output_size  # 3 (p, qrs, t)

        # Stem.
        self.stem = nn.Sequential(
            nn.Conv1d(
                1, stage1_channels, kernel_size=3, stride=2, padding=1, bias=False
            ),
            nn.BatchNorm1d(stage1_channels),
            nn.Conv1d(
                stage1_channels,
                stage1_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm1d(stage1_channels),
            nn.ReLU(),
        )
        # Data length after the two stride-2 stem convolutions.
        for _ in range(2):
            data_len = self._halved_len(data_len)

        # Create meta: before building the network, compute each stage's block
        # type and branch in/out channel counts, plus each branch's length.
        in_channels_by_stage = []
        out_channels_by_stage = []
        block_type_by_stage = []
        data_len_by_branch = [data_len]
        for stage_idx in range(num_stages):
            block_type = Bottleneck if use_bottleneck[stage_idx] == 1 else BasicBlock
            if stage_idx == 0:
                stage_in_channels = [stage1_channels]
            else:
                stage_in_channels = [
                    num_channels_init * 2**idx for idx in range(stage_idx + 1)
                ]
                # Every stage after the first adds a branch at half the length.
                data_len_by_branch.append(self._halved_len(data_len_by_branch[-1]))
            stage_out_channels = [
                channels * block_type.expansion for channels in stage_in_channels
            ]
            block_type_by_stage.append(block_type)
            in_channels_by_stage.append(stage_in_channels)
            out_channels_by_stage.append(stage_out_channels)

        # Create stages.
        self.stages = nn.ModuleList()
        for stage_idx in range(num_stages):
            modules_by_stage = nn.ModuleList()
            for _ in range(self.num_modules[stage_idx]):
                modules_by_stage.append(
                    HRModule(
                        stage_idx,
                        num_blocks,
                        block_type_by_stage,
                        in_channels_by_stage,
                        out_channels_by_stage,
                        data_len_by_branch,
                        kernel_size,
                        dilation,
                        self.interpolate_mode,
                    )
                )
            self.stages.append(modules_by_stage)

        # Create transitions.  A transition sits between two stages: it changes
        # the channel count of the existing branches and creates the new branch.
        self.transitions = nn.ModuleList()
        for prev_stage_idx in range(num_stages - 1):
            prev_channels = in_channels_by_stage[prev_stage_idx]
            next_channels = in_channels_by_stage[prev_stage_idx + 1]
            new_branch_idx = prev_stage_idx + 1
            transition_by_stage = nn.ModuleList()

            # Branches that exist in both stages (same branch level).
            for branch_idx in range(new_branch_idx):
                if prev_channels[branch_idx] != next_channels[branch_idx]:
                    transition_by_stage.append(
                        nn.Sequential(
                            _gen_channel_change_conv(
                                prev_channels[branch_idx], next_channels[branch_idx]
                            ),
                            nn.BatchNorm1d(next_channels[branch_idx]),
                            nn.ReLU(),
                        )
                    )
                else:
                    transition_by_stage.append(nn.Identity())

            # The new branch is created from every existing branch, each
            # downsampled as many times as the branches are levels apart.
            new_channels = next_channels[new_branch_idx]
            transition_by_stage.append(
                nn.ModuleList(
                    [
                        self._make_downsample_chain(
                            prev_channels[source_idx],
                            new_channels,
                            new_branch_idx - source_idx,
                        )
                        for source_idx in range(new_branch_idx)
                    ]
                )
            )
            self.transitions.append(transition_by_stage)

        # The head consumes the concatenation of all last-stage branches.
        self.cls = nn.Conv1d(
            sum(in_channels_by_stage[-1]), output_size, 1, bias=False
        )

    @staticmethod
    def _halved_len(length):
        """Return the output length of a kernel-3, stride-2, padding-1 conv."""
        # Integer form of floor((length - 1) / 2 + 1).
        return (length - 1) // 2 + 1

    @staticmethod
    def _make_downsample_chain(in_channel, out_channel, num_halvings):
        """Stack ``num_halvings`` 2x-downsample convs (each followed by batch norm, ReLU in between)."""
        layers = [_gen_downsample(in_channel, out_channel), nn.BatchNorm1d(out_channel)]
        for _ in range(1, num_halvings):
            layers.extend(
                [
                    nn.ReLU(),
                    _gen_downsample(out_channel, out_channel),
                    nn.BatchNorm1d(out_channel),
                ]
            )
        return nn.Sequential(*layers)

    def forward(self, input: torch.Tensor, y=None):
        """Compute the per-sample output for ``input``.

        Args:
            input: Tensor fed to the stem, whose first conv takes 1 input
                channel -- presumably shaped (batch, 1, length); confirm
                against the caller.  NOTE(review): the fusion upsamplers in
                ``HRModule`` are built with fixed sizes derived from
                ``config.data_len``, so ``length`` is expected to match it.
            y: Unused.

        Returns:
            Tensor with ``output_size`` channels, interpolated back to the
            last-dimension size of ``input``.
        """
        outputs = [self.stem(input)]
        last_stage_idx = len(self.stages) - 1

        for stage_idx, stage in enumerate(self.stages):
            num_branches = stage_idx + 1
            for module_idx in range(self.num_modules[stage_idx]):
                hr_module = stage[module_idx]
                branch_outputs = [
                    hr_module.branches[branch_idx](outputs[branch_idx])
                    for branch_idx in range(num_branches)
                ]
                fused_outputs = []
                for target_idx in range(num_branches):
                    contributions = [
                        hr_module.fusions[target_idx][source_idx](
                            branch_outputs[source_idx]
                        )
                        for source_idx in range(num_branches)
                    ]
                    fused_outputs.append(sum(contributions))
                outputs = fused_outputs

            if stage_idx < last_stage_idx:
                # A transition holds one entry per branch of the next stage.
                # The leading entries only adapt the channel count of an
                # existing branch (or are Identity); the final entry is a
                # ModuleList that builds the new branch by summing the
                # downsampled fusion results of every existing branch.
                transition_outputs = []
                for trans_idx, transition in enumerate(self.transitions[stage_idx]):
                    if trans_idx < num_branches:
                        transition_outputs.append(transition(outputs[trans_idx]))
                    else:
                        transition_outputs.append(
                            sum(
                                [
                                    downsample(branch_output)
                                    for downsample, branch_output in zip(
                                        transition, outputs
                                    )
                                ]
                            )
                        )
                outputs = transition_outputs

        # HRNetV2 head: bring every branch to the highest-resolution length,
        # concatenate along channels, classify, and restore the input length.
        target_len = outputs[0].shape[-1]
        resized = [
            F.interpolate(branch_output, size=target_len, mode=self.interpolate_mode)
            for branch_output in outputs
        ]
        merged = torch.cat(resized, dim=1)
        return F.interpolate(
            self.cls(merged), size=input.shape[-1], mode=self.interpolate_mode
        )