File size: 5,173 Bytes
5fc3d65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import torch.nn as nn
import torch
import torch.nn.functional as F


from SpikeT.model.model import BaseERGB2Depth
from SpikeT.model.encoder_transformer import LongSpikeStreamEncoderConv
from SpikeT.model.submodules import ResidualBlock, ConvLayer, UpsampleConvLayer


def skip_concat(x1, x2):
    """Merge two feature maps by concatenation along the channel axis (dim 1)."""
    return torch.cat((x1, x2), dim=1)


def skip_sum(x1, x2):
    """Merge two feature maps by element-wise addition."""
    merged = x1 + x2
    return merged


def identity(x1, x2=None):
    """No-op skip connection: return the first argument, ignoring the second."""
    return x1


class S2DepthTransformerUNetConv(BaseERGB2Depth):
    """Spike-stream to depth estimation network.

    A Swin-style transformer encoder (``LongSpikeStreamEncoderConv``)
    produces multi-scale features; a convolutional UNet-style decoder with
    configurable skip connections ("sum", "concat" or "no_skip") upsamples
    them back to a single-channel depth map squashed to [0, 1] by a sigmoid.
    """

    def __init__(self, config):
        """Build encoder, residual bottleneck, decoders and prediction head.

        :param config: mapping with keys "swin_depths", "swin_num_heads",
            "swin_patch_size", "swin_out_indices" and "ape"; "new_v" is
            optional and defaults to 0 (encoder version selector).
        """
        super(S2DepthTransformerUNetConv, self).__init__(config)
        # base_num_channels comes from the base class; the encoder's channel
        # layout requires it to be a multiple of 48.
        assert self.base_num_channels % 48 == 0

        self.depths = [int(i) for i in config["swin_depths"]]
        self.num_encoders = len(self.depths)
        self.num_heads = [int(i) for i in config["swin_num_heads"]]
        self.patch_size = [int(i) for i in config["swin_patch_size"]]
        self.out_indices = [int(i) for i in config["swin_out_indices"]]
        self.ape = config["ape"]
        # "new_v" selects an encoder variant; older configs omit it.
        self.num_v = config.get("new_v", 0)

        # Channel count at the deepest encoder stage (doubles per stage).
        self.max_num_channels = self.base_num_channels * pow(2, self.num_encoders - 1)
        # Final activation squashes the prediction to [0, 1].
        self.activation = torch.sigmoid
        self.num_output_channels = 1

        self.encoder = LongSpikeStreamEncoderConv(
            patch_size=self.patch_size,
            in_chans=self.num_bins_rgb,
            embed_dim=self.base_num_channels,
            depths=self.depths,
            num_heads=self.num_heads,
            out_indices=self.out_indices,
            new_version=self.num_v,
        )

        self.UpsampleLayer = UpsampleConvLayer

        if self.skip_type == 'sum':
            self.apply_skip_connection = skip_sum
        elif self.skip_type == 'concat':
            self.apply_skip_connection = skip_concat
        elif self.skip_type == 'no_skip' or self.skip_type is None:
            self.apply_skip_connection = identity
        else:
            raise KeyError('Could not identify skip_type, please add "skip_type":'
                           ' "sum", "concat" or "no_skip" to config["model"]')

        self.build_resblocks()
        self.build_decoders()
        self.build_prediction_layer()

    def build_resblocks(self):
        """Create the residual bottleneck operating at the deepest resolution."""
        self.resblocks = nn.ModuleList()
        for _ in range(self.num_residual_blocks):
            self.resblocks.append(
                ResidualBlock(self.max_num_channels, self.max_num_channels, norm=self.norm))

    def build_decoders(self):
        """Create one upsampling decoder per encoder stage, deepest first.

        With "concat" skips each decoder sees twice the channels, hence the
        doubled input size for every skip type except "sum".
        """
        decoder_input_sizes = list(
            reversed([self.base_num_channels * pow(2, i) for i in range(self.num_encoders)]))

        self.decoders = nn.ModuleList()
        for input_size in decoder_input_sizes:
            self.decoders.append(
                self.UpsampleLayer(input_size if self.skip_type == 'sum' else 2 * input_size,
                                   input_size // 2,
                                   kernel_size=5, padding=2, norm=self.norm))

    def build_prediction_layer(self):
        """Create the 1x1 conv head mapping decoder features to one channel."""
        self.pred = ConvLayer(
            self.base_num_channels // 2 if self.skip_type == 'sum' else 2 * self.base_num_channels,
            self.num_output_channels, 1, activation=None, norm=self.norm)

    def forward_decoder(self, super_states):
        """Decode the encoder feature pyramid into a depth map.

        :param super_states: list of per-stage encoder outputs, shallow to
            deep; with the "convlstm" state combination each entry is a tuple
            whose first element holds the features.
        :return: N x 1 x H x W prediction with values in [0, 1].
        """
        # The deepest encoder state seeds the decoder.
        if not bool(self.baseline) and self.state_combination == "convlstm":
            x = super_states[-1][0]
        else:
            x = super_states[-1]

        # Residual bottleneck at the lowest resolution.
        for resblock in self.resblocks:
            x = resblock(x)

        # Upsample stage by stage, fusing the matching encoder state
        # (the deepest state was already consumed, so stage 0 has no skip).
        for i, decoder in enumerate(self.decoders):
            if i == 0:
                x = decoder(x)
            elif not bool(self.baseline) and self.state_combination == "convlstm":
                x = decoder(self.apply_skip_connection(x, super_states[self.num_encoders - i - 1][0]))
            else:
                x = decoder(self.apply_skip_connection(x, super_states[self.num_encoders - i - 1]))

        # Prediction head + sigmoid to map into [0, 1].
        img = self.activation(self.pred(x))

        return img

    def forward(self, item, prev_super_states, prev_states_lstm):
        """Run one forward pass on a batch item.

        :param item: dict containing an "image" spike tensor of shape
            N x C x H x W (moved to ``self.gpu``).
        :param prev_super_states: unused; kept for interface compatibility.
        :param prev_states_lstm: passed through unchanged.
        :return: (predictions dict with key "image" holding an N x 1 x H x W
            tensor in [0, 1], {'image': None}, prev_states_lstm).
        """
        predictions_dict = {}
        spike_tensor = item["image"].to(self.gpu)
        encoded_xs = self.encoder(spike_tensor)
        prediction = self.forward_decoder(encoded_xs)
        predictions_dict["image"] = prediction

        return predictions_dict, {'image': None}, prev_states_lstm