|
import torch.nn as nn |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
|
|
from SpikeT.model.model import BaseERGB2Depth |
|
from SpikeT.model.encoder_transformer import LongSpikeStreamEncoderConv |
|
from SpikeT.model.submodules import ResidualBlock, ConvLayer, UpsampleConvLayer |
|
|
|
|
|
def skip_concat(x1, x2):
    """Skip connection that concatenates two feature maps along the channel axis."""
    merged = torch.cat([x1, x2], dim=1)
    return merged
|
|
|
|
|
def skip_sum(x1, x2):
    """Skip connection that merges two feature maps by element-wise addition."""
    merged = x1 + x2
    return merged
|
|
|
|
|
def identity(x1, x2=None):
    """No-op skip connection: returns the first input, ignoring the second."""
    del x2  # accepted only so the call signature matches the other skip helpers
    return x1
|
|
|
|
|
class S2DepthTransformerUNetConv(BaseERGB2Depth):
    """Spike-stream-to-depth model: a Swin-style transformer encoder
    (``LongSpikeStreamEncoderConv``) followed by a convolutional UNet
    decoder with configurable skip connections.

    Config keys read here: ``swin_depths``, ``swin_num_heads``,
    ``swin_patch_size``, ``swin_out_indices``, ``ape`` and the optional
    ``new_v``. Attributes such as ``base_num_channels``, ``num_bins_rgb``,
    ``skip_type``, ``norm``, ``num_residual_blocks``, ``baseline``,
    ``state_combination`` and ``gpu`` are assumed to be provided by
    ``BaseERGB2Depth`` — TODO confirm against the base class.
    """

    def __init__(self, config):
        super(S2DepthTransformerUNetConv, self).__init__(config)
        # NOTE(review): embedding dim must be divisible by 48 — presumably a
        # Swin head/window constraint of the encoder; confirm there.
        assert self.base_num_channels % 48 == 0

        # Transformer-encoder hyper-parameters; coerced to int in case the
        # config stores them as strings.
        self.depths = [int(i) for i in config["swin_depths"]]
        self.num_encoders = len(self.depths)
        self.num_heads = [int(i) for i in config["swin_num_heads"]]
        self.patch_size = [int(i) for i in config["swin_patch_size"]]
        self.out_indices = [int(i) for i in config["swin_out_indices"]]
        self.ape = config["ape"]
        # "new_v" is optional; dict.get replaces the original try/except KeyError.
        self.num_v = config.get("new_v", 0)

        # Channel count at the deepest encoder stage (doubles each stage).
        self.max_num_channels = self.base_num_channels * pow(2, self.num_encoders - 1)
        # Final activation squashes predictions into [0, 1].
        # (Was getattr(torch, 'sigmoid') — a direct reference is clearer.)
        self.activation = torch.sigmoid

        self.num_output_channels = 1  # single-channel depth map

        self.encoder = LongSpikeStreamEncoderConv(
            patch_size=self.patch_size,
            in_chans=self.num_bins_rgb,
            embed_dim=self.base_num_channels,
            depths=self.depths,
            num_heads=self.num_heads,
            out_indices=self.out_indices,
            new_version=self.num_v,
        )

        self.UpsampleLayer = UpsampleConvLayer

        # Select how decoder activations are merged with encoder features.
        if self.skip_type == 'sum':
            self.apply_skip_connection = skip_sum
        elif self.skip_type == 'concat':
            self.apply_skip_connection = skip_concat
        elif self.skip_type == 'no_skip' or self.skip_type is None:
            self.apply_skip_connection = identity
        else:
            raise KeyError('Could not identify skip_type, please add "skip_type":'
                           ' "sum", "concat" or "no_skip" to config["model"]')

        self.build_resblocks()
        self.build_decoders()
        self.build_prediction_layer()

    def build_resblocks(self):
        """Stack residual blocks operating at the bottleneck resolution."""
        self.resblocks = nn.ModuleList()
        for _ in range(self.num_residual_blocks):
            self.resblocks.append(
                ResidualBlock(self.max_num_channels, self.max_num_channels, norm=self.norm))

    def build_decoders(self):
        """Build one upsampling decoder per encoder stage, deepest first.

        With 'concat' skips each decoder input has twice the stage's
        channels; with 'sum' (or no skip) it has the stage's own count.
        """
        decoder_input_sizes = list(reversed(
            [self.base_num_channels * pow(2, i) for i in range(self.num_encoders)]))

        self.decoders = nn.ModuleList()
        for input_size in decoder_input_sizes:
            self.decoders.append(self.UpsampleLayer(
                input_size if self.skip_type == 'sum' else 2 * input_size,
                input_size // 2,
                kernel_size=5, padding=2, norm=self.norm))

    def build_prediction_layer(self):
        """1x1 convolution producing the raw (pre-sigmoid) depth prediction."""
        self.pred = ConvLayer(
            self.base_num_channels // 2 if self.skip_type == 'sum' else 2 * self.base_num_channels,
            self.num_output_channels, 1, activation=None, norm=self.norm)

    def forward_decoder(self, super_states):
        """Decode encoder feature maps into a depth image in [0, 1].

        :param super_states: per-stage encoder outputs, ordered shallow to
            deep; in the non-baseline convlstm configuration each entry is a
            tuple whose first element is the feature tensor.
        :return: activated prediction tensor from the 1x1 head.
        """
        # convlstm states are (tensor, cell) pairs; unwrap the tensor.
        if not bool(self.baseline) and self.state_combination == "convlstm":
            x = super_states[-1][0]
        else:
            x = super_states[-1]

        # Residual refinement at the bottleneck.
        for resblock in self.resblocks:
            x = resblock(x)

        # Upsample stage by stage, merging matching encoder features
        # (skipped for the first decoder, which has no deeper counterpart).
        for i, decoder in enumerate(self.decoders):
            if i == 0:
                x = decoder(x)
            else:
                if not bool(self.baseline) and self.state_combination == "convlstm":
                    x = decoder(self.apply_skip_connection(x, super_states[self.num_encoders - i - 1][0]))
                else:
                    x = decoder(self.apply_skip_connection(x, super_states[self.num_encoders - i - 1]))

        img = self.activation(self.pred(x))

        return img

    def forward(self, item, prev_super_states, prev_states_lstm):
        """Run one prediction step.

        :param item: dict with key "image" holding the N x C x H x W spike
            tensor (moved to ``self.gpu`` here).
        :param prev_super_states: unused by this model; kept for interface
            compatibility with stateful variants.
        :param prev_states_lstm: passed through unchanged.
        :return: (predictions dict with key "image" holding an N x 1 x H x W
            tensor in [0, 1], {'image': None}, prev_states_lstm).
        """
        predictions_dict = {}
        # (Removed a leftover debug print of item.keys().)
        spike_tensor = item["image"].to(self.gpu)

        encoded_xs = self.encoder(spike_tensor)

        prediction = self.forward_decoder(encoded_xs)
        predictions_dict["image"] = prediction

        return predictions_dict, {'image': None}, prev_states_lstm
|
|
|
|