Spaces:
Sleeping
Sleeping
""" | |
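1-D adaptation of HRNetV2 (High-Resolution Network) for sequence segmentation.

paper: https://arxiv.org/abs/1904.04514 ("High-Resolution Representations for Labeling Pixels and Regions")
ref: https://github.com/HRNet/HRNet-Semantic-Segmentation/blob/HRNet-OCR/lib/models/seg_hrnet.py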
""" | |
import torch | |
import torch.nn as nn | |
from torch.functional import F | |
import math | |
def _gen_same_length_conv(in_channel, out_channel, kernel_size=1, dilation=1):
    """Create a bias-free, stride-1 Conv1d whose output length equals its input length.

    Used for the feature-extracting convolutions inside the residual blocks,
    where the block output is added element-wise to its residual.

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels.
        kernel_size: kernel size (odd or even).
        dilation: dilation factor.

    Returns:
        ``nn.Conv1d`` with ``stride=1`` and ``bias=False``.
    """
    total_padding = dilation * (kernel_size - 1)
    if total_padding % 2 == 0:
        # Symmetric padding already preserves the length (same value as before).
        padding = total_padding // 2
    else:
        # An odd total cannot be split symmetrically: `total // 2` would shrink the
        # output by one sample and break the residual addition. PyTorch's "same"
        # mode pads asymmetrically so the length is preserved.
        padding = "same"
    return nn.Conv1d(
        in_channel,
        out_channel,
        kernel_size=kernel_size,
        stride=1,
        padding=padding,
        dilation=dilation,
        bias=False,
    )
def _gen_downsample(in_channel, out_channel):
    """Create a bias-free Conv1d that halves the sequence length.

    The layer uses ``kernel_size=3, stride=2, padding=1``. An input of length
    ``L`` therefore yields ``(L - 1) // 2 + 1`` samples.
    """
    conv_options = dict(kernel_size=3, stride=2, padding=1, bias=False)
    return nn.Conv1d(in_channel, out_channel, **conv_options)
def _gen_channel_change_conv(in_channel, out_channel):
    """Create a bias-free pointwise Conv1d (kernel 1, stride 1) that only remaps channels."""
    pointwise = nn.Conv1d(
        in_channels=in_channel,
        out_channels=out_channel,
        kernel_size=1,
        stride=1,
        bias=False,
    )
    return pointwise
class BasicBlock(nn.Module):
    """ResNet basic block for 1-D inputs, mapping ``inplanes`` -> ``planes`` channels.

    Main path: conv1 -> bn1 -> ReLU -> conv2 -> bn2. Both convolutions are stride-1
    convs built by ``_gen_same_length_conv`` with the given ``kernel_size`` and
    ``dilation``. The shortcut is the identity when ``inplanes == planes`` and a
    1x1 convolution otherwise. A final ReLU is applied after the residual sum.
    """

    # Output channels are ``planes * expansion``.
    expansion = 1

    def __init__(self, inplanes, planes, kernel_size=3, dilation=1):
        super().__init__()
        # Creation order and attribute names are kept stable: they determine
        # parameter initialisation order and state_dict keys.
        self.conv1 = _gen_same_length_conv(
            inplanes, planes, kernel_size=kernel_size, dilation=dilation
        )
        self.bn1 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU()
        self.conv2 = _gen_same_length_conv(
            planes, planes, kernel_size=kernel_size, dilation=dilation
        )
        self.bn2 = nn.BatchNorm1d(planes)
        if inplanes == planes:
            self.make_residual = nn.Identity()
        else:
            # Project the shortcut so it can be added to the main path.
            self.make_residual = _gen_channel_change_conv(inplanes, planes)

    def forward(self, x):
        shortcut = self.make_residual(x)
        hidden = self.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return self.relu(hidden + shortcut)
class Bottleneck(nn.Module):
    """ResNet bottleneck block for 1-D inputs, mapping ``inplanes`` -> ``planes * 4`` channels.

    Main path:

    * a 1x1 conv reduces to ``planes``;
    * a ``kernel_size``/``dilation`` conv operates at ``planes``;
    * a 1x1 conv expands to ``planes * expansion``.

    Each conv is followed by BatchNorm, and ReLU follows the first two. The
    shortcut is the identity when ``inplanes`` already equals the expanded
    width, otherwise a 1x1 convolution. ReLU is applied after the residual sum.
    """

    # Output channels are ``planes * expansion``.
    expansion = 4

    def __init__(self, inplanes, planes, kernel_size=3, dilation=1):
        super().__init__()
        expanded = planes * self.expansion
        # Creation order and attribute names are kept stable: they determine
        # parameter initialisation order and state_dict keys.
        self.conv1 = _gen_same_length_conv(inplanes, planes)
        self.bn1 = nn.BatchNorm1d(planes)
        self.conv2 = _gen_same_length_conv(
            planes, planes, kernel_size=kernel_size, dilation=dilation
        )
        self.bn2 = nn.BatchNorm1d(planes)
        self.conv3 = _gen_same_length_conv(planes, expanded)
        self.bn3 = nn.BatchNorm1d(expanded)
        self.relu = nn.ReLU()
        if inplanes == expanded:
            self.make_residual = nn.Identity()
        else:
            # Project the shortcut to the expanded width.
            self.make_residual = _gen_channel_change_conv(inplanes, expanded)

    def forward(self, x):
        shortcut = self.make_residual(x)
        hidden = x
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            hidden = self.relu(norm(conv(hidden)))
        hidden = self.bn3(self.conv3(hidden))
        return self.relu(hidden + shortcut)
class HRModule(nn.Module):
    """Parameters of one multi-resolution module of an HRNet stage.

    This class defines no ``forward``; the caller drives the data flow through two
    containers:

    * ``branches[b]``: ``nn.Sequential`` of ``num_blocks`` residual blocks for
      branch ``b``. The first block receives ``in_channels_by_stage[stage_idx][b]``
      channels; each block outputs that width times ``block_type.expansion``.
    * ``fusions[t][s]``: carries branch ``s``'s block output to branch ``t``.
      - ``t < s``: channel projection plus upsampling.
      - ``t == s``: identity, or a channel projection if the widths differ.
      - ``t > s``: ``t - s`` stride-2 convolutions.
      Every fusion emits ``in_channels_by_stage[stage_idx][t]`` channels.

    Args:
        stage_idx: stage index; the module has ``stage_idx + 1`` branches.
        num_blocks: residual blocks per branch.
        block_type_by_stage: per-stage block class (``BasicBlock`` or ``Bottleneck``).
        in_channels_by_stage: per-stage list of per-branch input channel counts.
        out_channels_by_stage: per-stage list of per-branch block-output channel counts.
        data_len_by_branch: per-branch sequence length, used as the fixed target
            size of the upsampling fusions.
        kernel_size: forwarded to every residual block.
        dilation: forwarded to every residual block.
        interpolate_mode: ``mode`` of the ``nn.Upsample`` layers.
    """

    def __init__(
        self,
        stage_idx,
        num_blocks,
        block_type_by_stage,
        in_channels_by_stage,
        out_channels_by_stage,
        data_len_by_branch,
        kernel_size,
        dilation,
        interpolate_mode,
    ):
        super().__init__()
        # Registered before any sub-module is built so parameter order is branches, then fusions.
        self.branches = nn.ModuleList()
        self.fusions = nn.ModuleList()

        block_cls = block_type_by_stage[stage_idx]
        widths_in = in_channels_by_stage[stage_idx]
        widths_out = out_channels_by_stage[stage_idx]
        branch_count = stage_idx + 1

        # Branches: the first block consumes the branch's input width; every later
        # block consumes the previous block's expanded output.
        for b in range(branch_count):
            width = widths_in[b]
            first_block = block_cls(width, width, kernel_size, dilation)
            later_blocks = [
                block_cls(width * block_cls.expansion, width, kernel_size, dilation)
                for _ in range(num_blocks - 1)
            ]
            self.branches.append(nn.Sequential(first_block, *later_blocks))

        # Fusions: one row per target branch, one entry per source branch.
        for target in range(branch_count):
            row = [
                self._make_fusion(
                    target,
                    source,
                    widths_in,
                    widths_out,
                    data_len_by_branch,
                    interpolate_mode,
                )
                for source in range(branch_count)
            ]
            self.fusions.append(nn.ModuleList(row))

    @staticmethod
    def _make_fusion(
        target, source, widths_in, widths_out, data_len_by_branch, interpolate_mode
    ):
        """Build the layer that maps branch ``source``'s output onto branch ``target``."""
        src = widths_out[source]
        dst = widths_in[target]
        if target == source:
            if src == dst:
                return nn.Identity()
            return nn.Sequential(
                _gen_channel_change_conv(src, dst),
                nn.BatchNorm1d(dst),
                nn.ReLU(),
            )
        if target < source:
            # Coarser source: match channels, then stretch to the target length.
            return nn.Sequential(
                _gen_channel_change_conv(src, dst),
                nn.BatchNorm1d(dst),
                nn.Upsample(size=data_len_by_branch[target], mode=interpolate_mode),
            )
        # Finer source: one 2x downsample per level of difference. The first
        # downsample also switches to the target's channel count.
        layers = [_gen_downsample(src, dst), nn.BatchNorm1d(dst)]
        for _ in range(target - source - 1):
            layers.extend([nn.ReLU(), _gen_downsample(dst, dst), nn.BatchNorm1d(dst)])
        return nn.Sequential(*layers)
class HRNetV2(nn.Module):
    """1-D HRNetV2: multi-resolution feature extractor plus a per-sample classifier head.

    Architecture (all layers are 1-D):

    * ``stem``: two stride-2 convolutions, so branch 0 runs at roughly 1/4 of the
      input length.
    * ``stages[s]``: ``num_modules[s]`` :class:`HRModule` instances over
      ``s + 1`` parallel branches. Branch ``b`` is downsampled ``b`` more times
      than branch 0.
    * ``transitions[s]`` (between stage ``s`` and ``s + 1``): re-projects each
      existing branch to the next stage's channel count. It also builds one new,
      lower-resolution branch as the sum of downsampled existing branches.
    * head: all branches are interpolated to branch 0's length and concatenated.
      A 1x1 ``cls`` conv then projects to ``output_size`` channels, and the
      result is interpolated back to the input length.

    Args:
        config: object exposing the following attributes:
            - ``data_len``, ``kernel_size``, ``dilation``, ``num_stages``,
              ``num_blocks``, ``stage1_channels``, ``num_channels_init``,
              ``output_size``.
            - ``num_modules``: per-stage list.
            - ``use_bottleneck``: per-stage list; ``1`` selects
              :class:`Bottleneck`, anything else :class:`BasicBlock`.
            - ``interpolate_mode``: used by ``F.interpolate`` and ``nn.Upsample``.
            ``data_len`` fixes the target size of the fusion upsampling layers,
            so inputs presumably must have exactly that length (TODO confirm).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Input length, chosen to match ECGPQRSTDataset's `second` and `hz`
        # (presumably second * hz).
        data_len = int(config.data_len)
        kernel_size = int(config.kernel_size)
        dilation = int(config.dilation)
        num_stages = int(config.num_stages)
        num_blocks = int(config.num_blocks)
        self.num_modules = config.num_modules  # HRModules per stage, e.g. [1, 1, 4, 3, ..]
        assert num_stages <= len(self.num_modules), (
            "config.num_modules must have an entry for every stage"
        )
        use_bottleneck = config.use_bottleneck  # per stage, e.g. [1, 0, 0, 0, ..]
        assert num_stages <= len(use_bottleneck), (
            "config.use_bottleneck must have an entry for every stage"
        )
        # The classifier width comes from the last stage, so one stage is required.
        assert num_stages >= 1, "config.num_stages must be at least 1"
        stage1_channels = int(config.stage1_channels)  # e.g. 64, 128
        num_channels_init = int(config.num_channels_init)  # e.g. 18, 32, 48
        self.interpolate_mode = config.interpolate_mode
        output_size = config.output_size  # e.g. 3 (p, qrs, t)

        # Stem: two 2x downsampling convolutions.
        # NOTE(review): the reference seg_hrnet.py stem applies ReLU after both
        # BatchNorms; here only after the second. Left unchanged because inserting
        # a layer would shift the Sequential indices (state_dict keys) of existing
        # checkpoints. Confirm whether this is intentional.
        self.stem = nn.Sequential(
            _gen_downsample(1, stage1_channels),
            nn.BatchNorm1d(stage1_channels),
            _gen_downsample(stage1_channels, stage1_channels),
            nn.BatchNorm1d(stage1_channels),
            nn.ReLU(),
        )
        for _ in range(2):  # length of branch 0 after the stem
            data_len = self._downsampled_length(data_len)

        # Compute per-stage metadata (block types, channel counts, branch
        # lengths) before building any modules.
        (
            block_type_by_stage,
            in_channels_by_stage,
            out_channels_by_stage,
            data_len_by_branch,
        ) = self._build_stage_meta(
            num_stages, use_bottleneck, stage1_channels, num_channels_init, data_len
        )

        self.stages = nn.ModuleList()
        for stage_idx in range(num_stages):
            self.stages.append(
                nn.ModuleList(
                    [
                        HRModule(
                            stage_idx,
                            num_blocks,
                            block_type_by_stage,
                            in_channels_by_stage,
                            out_channels_by_stage,
                            data_len_by_branch,
                            kernel_size,
                            dilation,
                            self.interpolate_mode,
                        )
                        for _ in range(self.num_modules[stage_idx])
                    ]
                )
            )

        # transitions[s] sits between stage s and stage s + 1.
        self.transitions = nn.ModuleList(
            [
                self._build_transition(
                    in_channels_by_stage[stage_idx], in_channels_by_stage[stage_idx + 1]
                )
                for stage_idx in range(num_stages - 1)
            ]
        )

        # Fused branch outputs carry each branch's *input* channel count, so the
        # head consumes the concatenation of the last stage's input channels.
        # (Previously this read a variable leaked from the metadata loop.)
        self.cls = nn.Conv1d(sum(in_channels_by_stage[-1]), output_size, 1, bias=False)

    @staticmethod
    def _downsampled_length(length):
        """Output length of a kernel-3, stride-2, padding-1 convolution."""
        return (length - 1) // 2 + 1

    @staticmethod
    def _build_stage_meta(
        num_stages, use_bottleneck, stage1_channels, num_channels_init, data_len
    ):
        """Compute per-stage block types and channel counts, and per-branch lengths.

        Channel layout:

        * Stage 0 has one branch with ``stage1_channels`` input channels.
        * Stage ``s >= 1`` has ``s + 1`` branches; branch ``b`` has
          ``num_channels_init * 2**b`` input channels.
        * Block-output widths are the input widths times the stage's block
          ``expansion``.

        Branch ``b >= 1`` is one stride-2 downsample shorter than branch ``b - 1``.

        Returns:
            ``(block_type_by_stage, in_channels_by_stage, out_channels_by_stage,
            data_len_by_branch)``
        """
        block_type_by_stage = []
        in_channels_by_stage = []
        out_channels_by_stage = []
        data_len_by_branch = []
        for stage_idx in range(num_stages):
            block_type = Bottleneck if use_bottleneck[stage_idx] == 1 else BasicBlock
            if stage_idx == 0:
                in_channels = [stage1_channels]
                data_len_by_branch.append(data_len)
            else:
                in_channels = [num_channels_init * 2**b for b in range(stage_idx + 1)]
                # Each new stage adds one branch at half the previous resolution.
                data_len_by_branch.append(
                    HRNetV2._downsampled_length(data_len_by_branch[-1])
                )
            block_type_by_stage.append(block_type)
            in_channels_by_stage.append(in_channels)
            out_channels_by_stage.append([c * block_type.expansion for c in in_channels])
        return (
            block_type_by_stage,
            in_channels_by_stage,
            out_channels_by_stage,
            data_len_by_branch,
        )

    @staticmethod
    def _build_transition(prev_channels, next_channels):
        """Build the transition from a stage with ``prev_channels`` branches to the next.

        Existing branches get either an identity or a 1x1 conv + BN + ReLU that
        switches to the next stage's width. The single new branch gets a
        ``ModuleList`` with one entry per previous branch. Each entry downsamples
        its source branch down to the new branch's resolution; the caller sums
        the entries.
        """
        transition = nn.ModuleList()
        for branch_idx, next_ch in enumerate(next_channels):
            if branch_idx < len(prev_channels):
                prev_ch = prev_channels[branch_idx]
                if prev_ch != next_ch:
                    transition.append(
                        nn.Sequential(
                            _gen_channel_change_conv(prev_ch, next_ch),
                            nn.BatchNorm1d(next_ch),
                            nn.ReLU(),
                        )
                    )
                else:
                    transition.append(nn.Identity())
                continue
            # New branch, created from every existing branch.
            from_branches = nn.ModuleList()
            for src_idx, src_ch in enumerate(prev_channels):
                layers = [_gen_downsample(src_ch, next_ch), nn.BatchNorm1d(next_ch)]
                for _ in range(branch_idx - src_idx - 1):
                    layers.extend(
                        [
                            nn.ReLU(),
                            _gen_downsample(next_ch, next_ch),
                            nn.BatchNorm1d(next_ch),
                        ]
                    )
                from_branches.append(nn.Sequential(*layers))
            transition.append(from_branches)
        return transition

    @staticmethod
    def _run_module(module, outputs):
        """Apply one HRModule: per-branch blocks, then all-to-all fusion by summation.

        NOTE(review): the reference HRNet applies ReLU to each fused sum; none is
        applied here.
        """
        branch_outputs = [branch(x) for branch, x in zip(module.branches, outputs)]
        return [
            sum([fusion(x) for fusion, x in zip(fusions_into_target, branch_outputs)])
            for fusions_into_target in module.fusions
        ]

    @staticmethod
    def _run_transition(transition, outputs):
        """Map a stage's branch outputs onto the next stage's (one extra) branches."""
        next_outputs = []
        for branch_idx, layer in enumerate(transition):
            if branch_idx < len(outputs):
                # Existing branch: identity or channel change.
                next_outputs.append(layer(outputs[branch_idx]))
            else:
                # New branch: sum of every existing branch, each downsampled to fit.
                next_outputs.append(sum([f(x) for f, x in zip(layer, outputs)]))
        return next_outputs

    def forward(self, input: torch.Tensor, y=None):
        """Run the network.

        Args:
            input: tensor fed to the stem, whose first conv has one input channel.
                Presumably shaped (batch, 1, length); TODO confirm.
            y: unused; kept for interface compatibility.

        Returns:
            Tensor with ``output_size`` channels, interpolated to ``input.shape[-1]``.
        """
        outputs = [self.stem(input)]
        last_stage = len(self.stages) - 1
        for stage_idx, stage in enumerate(self.stages):
            for module in stage:
                outputs = self._run_module(module, outputs)
            if stage_idx < last_stage:
                outputs = self._run_transition(self.transitions[stage_idx], outputs)

        # HRNetV2 head: bring every branch to branch 0's length and concatenate.
        target_len = outputs[0].shape[-1]
        merged = torch.cat(
            [
                F.interpolate(out, size=target_len, mode=self.interpolate_mode)
                for out in outputs
            ],
            dim=1,
        )
        return F.interpolate(
            self.cls(merged), size=input.shape[-1], mode=self.interpolate_mode
        )