File size: 2,617 Bytes
05c56b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch.nn as nn
import torch
import torch.nn.functional as F

from swin_transformer_3d import SwinTransformer3D


class LongSpikeStreamEncoderConv(nn.Module):
    """SwinTransformer3D backbone followed by per-temporal-block 1x1 convs.

    The input is passed through a ``SwinTransformer3D``. Each feature map it
    returns is split along dim 2 into ``num_blocks`` chunks, every chunk is
    projected by its own 1x1 ``nn.Conv2d``, and the projections are
    concatenated back along the channel dimension.

    Args:
        patch_size: Forwarded to ``SwinTransformer3D``. Its first element
            also divides ``in_chans`` to give ``num_blocks``.
        in_chans: Forwarded to ``SwinTransformer3D``.
        embed_dim: Forwarded to ``SwinTransformer3D``. Stage ``i`` is
            treated as having ``embed_dim * 2**i`` channels.
        depths: Forwarded to ``SwinTransformer3D``; its length defines the
            number of encoder stages handled here.
        num_heads: Forwarded to ``SwinTransformer3D``.
        patch_norm: Stored on the instance only (see note in ``__init__``).
        out_indices: Forwarded to ``SwinTransformer3D``.
        frozen_stages: Forwarded to ``SwinTransformer3D``.
        new_version: Forwarded to ``SwinTransformer3D``; its meaning is
            defined by that class and is not visible here.
    """

    def __init__(
        self,
        patch_size=(32, 2, 2),
        in_chans=128,
        embed_dim=96,
        depths=(2, 2, 6),
        num_heads=(3, 6, 12),
        patch_norm=False,
        out_indices=(0, 1, 2),
        frozen_stages=-1,
        new_version=3,
    ):
        super().__init__()

        self.num_blocks = in_chans // patch_size[0]
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        # Copy into fresh lists so instances never share (or mutate) the
        # default objects or a sequence still owned by the caller.
        self.depths = list(depths)
        self.num_heads = list(num_heads)
        # NOTE(review): patch_norm is stored but not forwarded to
        # SwinTransformer3D below -- confirm whether that is intended.
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages

        self.num_encoders = len(self.depths)
        self.out_channels = [
            self.embed_dim * (2 ** i) for i in range(self.num_encoders)
        ]

        self.swin3d = SwinTransformer3D(
            patch_size=self.patch_size,
            in_chans=self.in_chans,
            embed_dim=self.embed_dim,
            depths=self.depths,
            num_heads=self.num_heads,
            out_indices=self.out_indices,
            frozen_stages=self.frozen_stages,
            new_version=new_version,
        )

        self.patches_T = self.num_blocks

        # One 1x1 conv per (stage, temporal block). Each maps the stage's
        # channels down to channels // num_blocks, so concatenating the
        # num_blocks outputs gives back (up to integer division) the stage
        # channel count.
        self.conv_layers = nn.ModuleList(
            nn.ModuleList(
                nn.Conv2d(channels, channels // self.num_blocks, 1)
                for _ in range(self.num_blocks)
            )
            for channels in self.out_channels
        )

    def forward(self, inputs):
        """Encode ``inputs`` and return one fused feature map per stage.

        Args:
            inputs: 4-D tensor ``(B, C, H, W)``.

        Returns:
            A list with ``num_encoders`` tensors. Entry ``i`` is the
            channel-wise concatenation of the ``num_blocks`` 1x1-conv
            outputs for stage ``i``.

        Raises:
            ValueError: If ``inputs`` is not 4-dimensional.
        """
        if inputs.dim() != 4:
            raise ValueError(
                f"expected a 4-D input (B, C, H, W), got {inputs.dim()} dimensions"
            )

        features = self.swin3d(inputs)

        outs = []
        for i in range(self.num_encoders):
            # Split the stage output along dim 2 (temporal axis of the 5-D
            # feature map) into one chunk per block.
            chunks = features[i].chunk(self.num_blocks, 2)
            batch, _, _, height, width = chunks[0].shape
            stage_convs = self.conv_layers[i]
            # The reshape folds each chunk's temporal extent into channels.
            # The convs are built for the bare stage channel count, so this
            # presumably relies on each chunk having temporal length 1 --
            # TODO confirm against SwinTransformer3D's output shape.
            projected = [
                stage_convs[k](chunks[k].reshape(batch, -1, height, width))
                for k in range(self.num_blocks)
            ]
            outs.append(torch.cat(projected, dim=1))

        return outs