""" paper: https://arxiv.org/abs/1904.04514 ref: https://github.com/HRNet/HRNet-Semantic-Segmentation/blob/HRNet-OCR/lib/models/seg_hrnet.py """ import torch import torch.nn as nn from torch.functional import F import math def _gen_same_length_conv(in_channel, out_channel, kernel_size=1, dilation=1): """길이가 변하지 않는 conv 생성, block 내에서 feature 를 추출하는 convolution 에서 사용""" return nn.Conv1d( in_channel, out_channel, kernel_size=kernel_size, stride=1, padding=(dilation * (kernel_size - 1)) // 2, dilation=dilation, bias=False, ) def _gen_downsample(in_channel, out_channel): """kernel_size:3, stride:2, padding:1 인 2배 downsample 하는 conv 생성""" return nn.Conv1d( in_channel, out_channel, kernel_size=3, stride=2, padding=1, bias=False ) def _gen_channel_change_conv(in_channel, out_channel): """kernel_size:1, stride:1 인 channel 변경하는 conv 생성""" return nn.Conv1d(in_channel, out_channel, kernel_size=1, stride=1, bias=False) class BasicBlock(nn.Module): """resnet 의 basic block 으로 channel 변화는 inplanes -> planes""" expansion = 1 def __init__(self, inplanes, planes, kernel_size=3, dilation=1): super().__init__() self.conv1 = _gen_same_length_conv(inplanes, planes, kernel_size, dilation) self.bn1 = nn.BatchNorm1d(planes) self.relu = nn.ReLU() self.conv2 = _gen_same_length_conv(planes, planes, kernel_size, dilation) self.bn2 = nn.BatchNorm1d(planes) self.make_residual = ( _gen_channel_change_conv(inplanes, planes) if inplanes != planes else nn.Identity() ) def forward(self, x): out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) residual = self.make_residual(x) out = out + residual out = self.relu(out) return out class Bottleneck(nn.Module): """resnet 의 Bottleneck block 으로 channel 변화는 inplanes -> planes * 4""" expansion = 4 def __init__(self, inplanes, planes, kernel_size=3, dilation=1): super().__init__() self.conv1 = _gen_same_length_conv(inplanes, planes) self.bn1 = nn.BatchNorm1d(planes) self.conv2 = _gen_same_length_conv(planes, planes, kernel_size, dilation) self.bn2 = nn.BatchNorm1d(planes) self.conv3 = _gen_same_length_conv(planes, planes * self.expansion) self.bn3 = nn.BatchNorm1d(planes * self.expansion) self.relu = nn.ReLU() self.make_residual = ( _gen_channel_change_conv(inplanes, planes * self.expansion) if inplanes != planes * self.expansion else nn.Identity() ) def forward(self, x): out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) residual = self.make_residual(x) out = out + residual out = self.relu(out) return out class HRModule(nn.Module): def __init__( self, stage_idx, num_blocks, block_type_by_stage, in_channels_by_stage, out_channels_by_stage, data_len_by_branch, kernel_size, dilation, interpolate_mode, ): super().__init__() self.branches = nn.ModuleList() self.fusions = nn.ModuleList() block_type: BasicBlock | Bottleneck = block_type_by_stage[stage_idx] in_channels = in_channels_by_stage[stage_idx] for i in range(stage_idx + 1): # branch 생성 blocks_by_branch = [] _channels = in_channels[i] blocks_by_branch.append( block_type(_channels, _channels, kernel_size, dilation) ) for _ in range(1, num_blocks): blocks_by_branch.append( block_type( _channels * block_type.expansion, _channels, kernel_size, dilation, ) ) self.branches.append(nn.Sequential(*blocks_by_branch)) out_channels = out_channels_by_stage[stage_idx] for i in range(stage_idx + 1): fusion_by_branch = nn.ModuleList() for j in range(stage_idx + 1): if i < j: 

class BasicBlock(nn.Module):
    """ResNet basic block; channels change inplanes -> planes."""

    expansion = 1

    def __init__(self, inplanes, planes, kernel_size=3, dilation=1):
        super().__init__()
        self.conv1 = _gen_same_length_conv(inplanes, planes, kernel_size, dilation)
        self.bn1 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU()
        self.conv2 = _gen_same_length_conv(planes, planes, kernel_size, dilation)
        self.bn2 = nn.BatchNorm1d(planes)
        self.make_residual = (
            _gen_channel_change_conv(inplanes, planes)
            if inplanes != planes
            else nn.Identity()
        )

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        residual = self.make_residual(x)
        out = out + residual
        out = self.relu(out)
        return out


class Bottleneck(nn.Module):
    """ResNet bottleneck block; channels change inplanes -> planes * 4."""

    expansion = 4

    def __init__(self, inplanes, planes, kernel_size=3, dilation=1):
        super().__init__()
        self.conv1 = _gen_same_length_conv(inplanes, planes)
        self.bn1 = nn.BatchNorm1d(planes)
        self.conv2 = _gen_same_length_conv(planes, planes, kernel_size, dilation)
        self.bn2 = nn.BatchNorm1d(planes)
        self.conv3 = _gen_same_length_conv(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm1d(planes * self.expansion)
        self.relu = nn.ReLU()
        self.make_residual = (
            _gen_channel_change_conv(inplanes, planes * self.expansion)
            if inplanes != planes * self.expansion
            else nn.Identity()
        )

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        residual = self.make_residual(x)
        out = out + residual
        out = self.relu(out)
        return out


class HRModule(nn.Module):
    """One HR module: stage_idx + 1 parallel branches plus a fusion layer that
    exchanges information across resolutions. This class is only a container;
    HRNetV2.forward drives branches and fusions directly, so it defines no
    forward of its own."""

    def __init__(
        self,
        stage_idx,
        num_blocks,
        block_type_by_stage,
        in_channels_by_stage,
        out_channels_by_stage,
        data_len_by_branch,
        kernel_size,
        dilation,
        interpolate_mode,
    ):
        super().__init__()
        self.branches = nn.ModuleList()
        self.fusions = nn.ModuleList()

        block_type: type[BasicBlock] | type[Bottleneck] = block_type_by_stage[
            stage_idx
        ]
        in_channels = in_channels_by_stage[stage_idx]
        for i in range(stage_idx + 1):
            # build the branches
            blocks_by_branch = []
            _channels = in_channels[i]
            blocks_by_branch.append(
                block_type(_channels, _channels, kernel_size, dilation)
            )
            for _ in range(1, num_blocks):
                blocks_by_branch.append(
                    block_type(
                        _channels * block_type.expansion,
                        _channels,
                        kernel_size,
                        dilation,
                    )
                )
            self.branches.append(nn.Sequential(*blocks_by_branch))

        out_channels = out_channels_by_stage[stage_idx]
        for i in range(stage_idx + 1):
            fusion_by_branch = nn.ModuleList()
            for j in range(stage_idx + 1):
                if i < j:
                    # lower-resolution source: change channels, then upsample
                    fusion_by_branch.append(
                        nn.Sequential(
                            _gen_channel_change_conv(out_channels[j], in_channels[i]),
                            nn.BatchNorm1d(in_channels[i]),
                            nn.Upsample(
                                size=data_len_by_branch[i], mode=interpolate_mode
                            ),
                        )
                    )
                elif i == j:
                    if out_channels[i] != in_channels[j]:
                        fusion_by_branch.append(
                            nn.Sequential(
                                _gen_channel_change_conv(
                                    out_channels[i], in_channels[j]
                                ),
                                nn.BatchNorm1d(in_channels[j]),
                                nn.ReLU(),
                            )
                        )
                    else:
                        fusion_by_branch.append(nn.Identity())
                else:
                    # downsample by 2x per branch-level gap; match channels to
                    # this branch's in_channels
                    downsamples = [
                        _gen_downsample(out_channels[j], in_channels[i]),
                        nn.BatchNorm1d(in_channels[i]),
                    ]
                    for _ in range(1, i - j):
                        downsamples.extend(
                            [
                                nn.ReLU(),
                                _gen_downsample(in_channels[i], in_channels[i]),
                                nn.BatchNorm1d(in_channels[i]),
                            ]
                        )
                    fusion_by_branch.append(nn.Sequential(*downsamples))
            self.fusions.append(fusion_by_branch)
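
# How the fusion grid is wired (illustrative note, not in the original source):
# for a two-branch module (stage_idx=1) with BasicBlock branches, in_channels
# [18, 36], and data_len_by_branch [500, 250], HRModule holds
#
#     fusions[0][0]: nn.Identity()                              (channels match)
#     fusions[0][1]: 1x1 conv 36->18 + BN + Upsample(size=500)  (up-fuse)
#     fusions[1][0]: stride-2 conv 18->36 + BN                  (down-fuse)
#     fusions[1][1]: nn.Identity()
#
# and HRNetV2.forward computes outputs[i] = sum_j fusions[i][j](branch_j_out),
# so every branch receives information from every other resolution.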

class HRNetV2(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        data_len = int(config.data_len)  # matches ECGPQRSTDataset.second * hz
        kernel_size = int(config.kernel_size)
        dilation = int(config.dilation)
        num_stages = int(config.num_stages)
        num_blocks = int(config.num_blocks)
        self.num_modules = config.num_modules  # e.g., [1, 1, 4, 3, ...]
        assert num_stages <= len(self.num_modules)
        use_bottleneck = config.use_bottleneck  # e.g., [1, 0, 0, 0, ...]
        assert num_stages <= len(use_bottleneck)
        stage1_channels = int(config.stage1_channels)  # e.g., 64 or 128
        num_channels_init = int(config.num_channels_init)  # e.g., 18, 32, 48
        self.interpolate_mode = config.interpolate_mode
        output_size = config.output_size  # 3 classes (p, qrs, t)

        # stem
        self.stem = nn.Sequential(
            nn.Conv1d(
                1, stage1_channels, kernel_size=3, stride=2, padding=1, bias=False
            ),
            nn.BatchNorm1d(stage1_channels),
            nn.Conv1d(
                stage1_channels,
                stage1_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm1d(stage1_channels),
            nn.ReLU(),
        )
        for _ in range(2):
            # data length after the stem's two stride-2 convs
            data_len = math.floor((data_len - 1) / 2 + 1)

        # create meta: precompute each stage's in/out channels, block type, and
        # branch lengths before building the network
        in_channels_by_stage = []
        out_channels_by_stage = []
        block_type_by_stage = []
        for stage_idx in range(num_stages):
            block_type_each_stage = (
                Bottleneck if use_bottleneck[stage_idx] == 1 else BasicBlock
            )
            if stage_idx == 0:
                in_channels_each_stage = [stage1_channels]
                out_channels_each_stage = [
                    stage1_channels * block_type_each_stage.expansion
                ]
                data_len_by_branch = [data_len]
            else:
                in_channels_each_stage = [
                    num_channels_init * 2**idx for idx in range(stage_idx + 1)
                ]
                out_channels_each_stage = [
                    (num_channels_init * 2**idx) * block_type_each_stage.expansion
                    for idx in range(stage_idx + 1)
                ]
                data_len_by_branch.append(
                    math.floor((data_len_by_branch[-1] - 1) / 2 + 1)
                )
            block_type_by_stage.append(block_type_each_stage)
            in_channels_by_stage.append(in_channels_each_stage)
            out_channels_by_stage.append(out_channels_each_stage)

        # create stages
        self.stages = nn.ModuleList()
        for stage_idx in range(num_stages):
            modules_by_stage = nn.ModuleList()
            for _ in range(self.num_modules[stage_idx]):
                modules_by_stage.append(
                    HRModule(
                        stage_idx,
                        num_blocks,
                        block_type_by_stage,
                        in_channels_by_stage,
                        out_channels_by_stage,
                        data_len_by_branch,
                        kernel_size,
                        dilation,
                        self.interpolate_mode,
                    )
                )
            self.stages.append(modules_by_stage)

        # create transitions. Here stage_idx refers to the previous stage; a
        # transition either adapts channels between consecutive stages or
        # creates a new branch.
        self.transitions = nn.ModuleList()
        for stage_idx in range(num_stages - 1):
            transition_by_stage = nn.ModuleList()
            psc = in_channels_by_stage[stage_idx]  # psc: prev_stage_channels
            nsc = in_channels_by_stage[stage_idx + 1]  # nsc: next_stage_channels
            for nsbi in range(stage_idx + 2):  # nsbi: next_stage_branch_idx
                if nsbi < stage_idx + 1:
                    # same branch level
                    if psc[nsbi] != nsc[nsbi]:
                        transition_by_stage.append(
                            nn.Sequential(
                                _gen_channel_change_conv(psc[nsbi], nsc[nsbi]),
                                nn.BatchNorm1d(nsc[nsbi]),
                                nn.ReLU(),
                            )
                        )
                    else:
                        transition_by_stage.append(nn.Identity())
                else:
                    # create a new branch from the existing branches
                    transition_from_branches = nn.ModuleList()
                    for psbi in range(nsbi):  # psbi: prev_stage_branch_idx
                        transition_from_one_branch = [
                            _gen_downsample(psc[psbi], nsc[nsbi]),
                            nn.BatchNorm1d(nsc[nsbi]),
                        ]
                        for _ in range(1, nsbi - psbi):
                            transition_from_one_branch.extend(
                                [
                                    nn.ReLU(),
                                    _gen_downsample(nsc[nsbi], nsc[nsbi]),
                                    nn.BatchNorm1d(nsc[nsbi]),
                                ]
                            )
                        transition_from_branches.append(
                            nn.Sequential(*transition_from_one_branch)
                        )
                    transition_by_stage.append(transition_from_branches)
            self.transitions.append(transition_by_stage)

        self.cls = nn.Conv1d(sum(in_channels_by_stage[-1]), output_size, 1, bias=False)

    def forward(self, input: torch.Tensor, y=None):
        output: torch.Tensor = input
        output = self.stem(output)
        outputs = [output]
        for stage_idx, stage in enumerate(self.stages):
            for module_idx in range(self.num_modules[stage_idx]):
                # run each branch of the module
                for branch_idx in range(stage_idx + 1):
                    outputs[branch_idx] = stage[module_idx].branches[branch_idx](
                        outputs[branch_idx]
                    )
                # fuse: destination branch dst sums contributions from every
                # source branch src
                fusion_outputs = []
                for dst in range(stage_idx + 1):
                    fusion_output_from_branches = []
                    for src in range(stage_idx + 1):
                        fusion_output_from_branch: torch.Tensor = stage[
                            module_idx
                        ].fusions[dst][src](outputs[src])
                        fusion_output_from_branches.append(fusion_output_from_branch)
                    fusion_outputs.append(sum(fusion_output_from_branches))
                outputs = fusion_outputs

            if stage_idx < len(self.stages) - 1:
                # a transition holds one Sequential or ModuleList per branch of
                # the next stage. The leading Sequentials only adapt channels to
                # the next stage or pass through (Identity); the last ModuleList
                # builds a new branch by summing the downsampled fusion outputs
                # of every existing branch.
                transition_outputs = []
                for trans_idx, transition in enumerate(self.transitions[stage_idx]):
                    if trans_idx < stage_idx + 1:
                        transition_outputs.append(transition(outputs[trans_idx]))
                    else:
                        transition_outputs.append(
                            sum(
                                transition_from_each_branch(output)
                                for transition_from_each_branch, output in zip(
                                    transition, outputs
                                )
                            )
                        )
                outputs = transition_outputs

        # HRNetV2 head: upsample every branch to the highest resolution,
        # concatenate along channels, classify, then restore the input length
        outputs = [
            F.interpolate(output, size=outputs[0].shape[-1], mode=self.interpolate_mode)
            for output in outputs
        ]
        output = torch.cat(outputs, dim=1)
        return F.interpolate(
            self.cls(output), size=input.shape[-1], mode=self.interpolate_mode
        )
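
# --- Usage sketch (illustrative; not part of the original file) ---
# A minimal smoke test, assuming a config object that simply exposes the
# attributes read in HRNetV2.__init__. The SimpleNamespace stand-in and the
# concrete values (e.g. data_len=2000, as for a 4-second ECG at 500 Hz) are
# hypothetical, not taken from the source project.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        data_len=2000,  # assumed: ECGPQRSTDataset.second * hz = 4 * 500
        kernel_size=3,
        dilation=1,
        num_stages=4,
        num_blocks=4,
        num_modules=[1, 1, 4, 3],  # HRModules per stage
        use_bottleneck=[1, 0, 0, 0],  # Bottleneck blocks only in the first stage
        stage1_channels=64,
        num_channels_init=18,
        interpolate_mode="linear",
        output_size=3,  # p, qrs, t
    )
    model = HRNetV2(config).eval()
    x = torch.randn(2, 1, config.data_len)  # (batch, lead, samples)
    with torch.no_grad():
        logits = model(x)
    # per-sample class logits restored to the input resolution
    assert logits.shape == (2, config.output_size, config.data_len)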