import torch.nn as nn import torch import torch.nn.functional as F from swin_transformer_3d import SwinTransformer3D class LongSpikeStreamEncoderConv(nn.Module): def __init__( self, # num_blocks, # block_channel, patch_size=(32,2,2), in_chans=128, embed_dim=96, depths=[2,2,6], num_heads=[3,6,12], patch_norm=False, out_indices=(0,1,2), frozen_stages=-1, new_version=3, ): super(LongSpikeStreamEncoderConv, self).__init__() self.num_blocks = in_chans // patch_size[0] # self.out_num_depths = self.num_blocks - 1 # self.block_channel = block_channel self.patch_size = patch_size self.in_chans = in_chans self.embed_dim = embed_dim self.depths = depths self.num_heads = num_heads self.patch_norm = patch_norm self.out_indices = out_indices self.frozen_stages = frozen_stages self.num_encoders = len(self.depths) self.out_channels = [self.embed_dim*(2**i) for i in range(self.num_encoders)] self.swin3d = SwinTransformer3D( patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim, depths=self.depths, num_heads=self.num_heads, out_indices=self.out_indices, frozen_stages=self.frozen_stages, new_version=new_version, ) self.patches_T = self.num_blocks # self.patch_T = self.patches_T // self.num_blocks # 1 self.conv_layers = nn.ModuleList() for i in range(self.num_encoders): conv_layer_i = nn.ModuleList() for ti in range(self.num_blocks): conv_layer_i.append(nn.Conv2d(self.out_channels[i], self.out_channels[i] // self.num_blocks, 1)) self.conv_layers.append(conv_layer_i) def forward(self, inputs): B, C, H, W = inputs.shape features = self.swin3d(inputs) outs = [] for i in range(self.num_encoders): out_layer_i = [] features_i = features[i].chunk(self.num_blocks, 2) B, C, T, H, W = features_i[0].shape # features_i = features_i.reshape(B, -1, H, W) for k in range(self.num_blocks): feature_k = features_i[k].reshape(B, -1, H, W) # B,C,H,W out_k = self.conv_layers[i][k](feature_k) out_layer_i.append(out_k) out_i = torch.cat(out_layer_i, dim=1) outs.append(out_i) return outs