|
import torch.nn as nn |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
from SpikeT.model.swin_transformer_3d import SwinTransformer3D |
|
|
|
|
|
class LongSpikeStreamEncoderConv(nn.Module):
    """Spike-stream encoder: a 3D Swin backbone followed by per-block 1x1 convs.

    The backbone (``SwinTransformer3D``) returns one feature map per stage.
    ``forward`` splits each stage's map into ``num_blocks`` chunks along
    dim 2 (the axis unpacked as ``T``), flattens each chunk to a 4-D tensor,
    projects it with its own 1x1 ``nn.Conv2d`` and concatenates the
    projections along the channel dimension.

    Args:
        patch_size: Forwarded to ``SwinTransformer3D``. ``patch_size[0]``
            also sets ``num_blocks = in_chans // patch_size[0]``.
        in_chans: Forwarded to ``SwinTransformer3D``; must be at least
            ``patch_size[0]`` so that ``num_blocks >= 1``.
        embed_dim: Channel width of stage 0; stage ``i`` is assumed to have
            ``embed_dim * 2**i`` channels (recorded in ``self.out_channels``).
        depths: Per-stage depths; its length sets the number of stages.
            Defaults to ``[2, 2, 6]`` (a fresh list per instance).
        num_heads: Per-stage head counts, forwarded to the backbone.
            Defaults to ``[3, 6, 12]`` (a fresh list per instance).
        patch_norm: Stored on the instance only (see the NOTE in __init__).
        out_indices: Forwarded unchanged to ``SwinTransformer3D``.
        frozen_stages: Forwarded unchanged to ``SwinTransformer3D``.
        new_version: Forwarded unchanged to ``SwinTransformer3D``.

    Raises:
        ValueError: if ``in_chans // patch_size[0]`` is less than 1.
    """

    def __init__(
        self,
        patch_size=(32, 2, 2),
        in_chans=128,
        embed_dim=96,
        depths=None,
        num_heads=None,
        patch_norm=False,
        out_indices=(0, 1, 2),
        frozen_stages=-1,
        new_version=3,
    ):
        super().__init__()

        # Fresh lists per instance: list literals in the signature would be a
        # single object shared by every instance (and by the backbone).
        if depths is None:
            depths = [2, 2, 6]
        if num_heads is None:
            num_heads = [3, 6, 12]

        # Number of chunks each stage's feature map is split into along dim 2.
        self.num_blocks = in_chans // patch_size[0]
        if self.num_blocks < 1:
            # Without this, conv construction below fails with ZeroDivisionError.
            raise ValueError(
                f"in_chans ({in_chans}) must be >= patch_size[0] "
                f"({patch_size[0]}) so that at least one temporal block exists"
            )

        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.depths = depths
        self.num_heads = num_heads
        # NOTE(review): patch_norm is stored but never passed to
        # SwinTransformer3D below, so it currently has no effect here.
        # Confirm the backbone accepts it before forwarding it.
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages

        self.num_encoders = len(self.depths)
        # Channel width of each stage: doubles per stage starting at embed_dim.
        self.out_channels = [self.embed_dim * (2 ** i) for i in range(self.num_encoders)]

        self.swin3d = SwinTransformer3D(
            patch_size=self.patch_size,
            in_chans=self.in_chans,
            embed_dim=self.embed_dim,
            depths=self.depths,
            num_heads=self.num_heads,
            out_indices=self.out_indices,
            frozen_stages=self.frozen_stages,
            new_version=new_version,
        )

        # Alias of num_blocks; not read elsewhere in this class.
        self.patches_T = self.num_blocks

        # conv_layers[i][k] projects chunk k of stage i from out_channels[i]
        # to out_channels[i] // num_blocks channels. Construction order (and
        # thus state-dict keys / parameter init order) is stage-major.
        self.conv_layers = nn.ModuleList(
            nn.ModuleList(
                nn.Conv2d(channels, channels // self.num_blocks, 1)
                for _ in range(self.num_blocks)
            )
            for channels in self.out_channels
        )

    def forward(self, inputs):
        """Encode ``inputs`` and return one fused feature map per stage.

        Args:
            inputs: A 4-D tensor, passed straight to ``self.swin3d``. Its
                axis meaning depends on SwinTransformer3D (presumably
                (B, C, H, W) with C spike frames -- TODO confirm).

        Returns:
            A list with ``num_encoders`` tensors. Entry ``i`` has
            ``num_blocks * (out_channels[i] // num_blocks)`` channels, which
            equals ``out_channels[i]`` only when ``num_blocks`` divides it.

        Raises:
            ValueError: if ``inputs`` is not 4-D, or if a stage's dim-2 size
                cannot be split into ``num_blocks`` chunks.
        """
        # The original code enforced 4-D input implicitly via tuple unpacking;
        # keep that contract, but make it explicit.
        if inputs.dim() != 4:
            raise ValueError(f"expected a 4-D input, got shape {tuple(inputs.shape)}")

        features = self.swin3d(inputs)

        outs = []
        for i in range(self.num_encoders):
            # torch.chunk may return FEWER chunks than requested when the
            # dim-2 size is smaller than / not compatible with num_blocks.
            chunks = features[i].chunk(self.num_blocks, 2)
            if len(chunks) != self.num_blocks:
                raise ValueError(
                    f"stage {i}: dim 2 of size {features[i].shape[2]} split into "
                    f"{len(chunks)} chunks, expected {self.num_blocks}"
                )
            B, _, _, H, W = chunks[0].shape
            # Flatten (C, T) of each chunk into channels; the conv expects
            # out_channels[i] inputs, so C * T must equal out_channels[i]
            # (i.e. presumably T == 1 per chunk -- TODO confirm upstream).
            projected = [
                conv(chunk.reshape(B, -1, H, W))
                for conv, chunk in zip(self.conv_layers[i], chunks)
            ]
            outs.append(torch.cat(projected, dim=1))

        return outs
|
|