# Source path: rightnow / SpikeT / model / S2DepthNet.py
# Hosting-page residue, kept as comments for provenance:
#   "zzzzzeee's picture" / "Upload 101 files" / "5fc3d65 verified"
import torch.nn as nn
import torch
import torch.nn.functional as F
from SpikeT.model.model import BaseERGB2Depth
from SpikeT.model.encoder_transformer import LongSpikeStreamEncoderConv
from SpikeT.model.submodules import ResidualBlock, ConvLayer, UpsampleConvLayer
def skip_concat(x1, x2):
    """Skip connection that concatenates ``x1`` and ``x2`` along dimension 1.

    :param x1: first tensor (the decoder feature when used by the model below).
    :param x2: second tensor (the encoder feature when used by the model below).
    :return: ``torch.cat([x1, x2], dim=1)``; both inputs must agree on every
        dimension except dimension 1.
    """
    return torch.cat([x1, x2], dim=1)
def skip_sum(x1, x2):
    """Skip connection that adds the two inputs.

    :param x1: first operand.
    :param x2: second operand.
    :return: ``x1 + x2``.
    """
    return x1 + x2
def identity(x1, x2=None):
    """Skip connection that ignores the skip input.

    Accepts the same two-argument call shape as ``skip_sum`` / ``skip_concat``
    so it can be used interchangeably with them.

    :param x1: value to pass through.
    :param x2: ignored.
    :return: ``x1`` unchanged (the same object).
    """
    return x1
class S2DepthTransformerUNetConv(BaseERGB2Depth):
    """Encoder/decoder network producing a single-channel, sigmoid-activated map.

    Pipeline (see ``forward``): ``LongSpikeStreamEncoderConv`` encoder ->
    residual blocks -> upsampling decoders with skip connections -> 1x1
    prediction layer -> sigmoid.

    Besides ``config``, the class reads the following attributes, which are
    expected to be set by ``BaseERGB2Depth.__init__`` (not visible in this
    file): ``base_num_channels``, ``num_bins_rgb``, ``skip_type``,
    ``num_residual_blocks``, ``norm``, ``baseline``, ``state_combination`` and
    ``gpu``.
    """

    def __init__(self, config):
        """Build encoder, residual blocks, decoders and prediction layer.

        :param config: mapping providing ``"swin_depths"``,
            ``"swin_num_heads"``, ``"swin_patch_size"``,
            ``"swin_out_indices"`` and ``"ape"``; ``"new_v"`` is optional and
            defaults to 0. It is also forwarded to the base class.
        :raises AssertionError: if ``base_num_channels`` is not a multiple of 48.
        :raises KeyError: if a required config key is missing, or if
            ``skip_type`` is not one of ``'sum'``, ``'concat'``, ``'no_skip'``
            or ``None``.
        """
        super().__init__(config)

        # NOTE: kept as an assert so the raised exception type is unchanged;
        # be aware that it is stripped when Python runs with -O.
        assert self.base_num_channels % 48 == 0, (
            "base_num_channels must be a multiple of 48, got %r"
            % (self.base_num_channels,)
        )

        self.depths = [int(d) for d in config["swin_depths"]]
        # One encoder stage per entry in "swin_depths".
        self.num_encoders = len(self.depths)
        self.num_heads = [int(h) for h in config["swin_num_heads"]]
        self.patch_size = [int(p) for p in config["swin_patch_size"]]
        self.out_indices = [int(i) for i in config["swin_out_indices"]]
        self.ape = config["ape"]

        # "new_v" is optional; indexing (rather than .get) is kept so that any
        # config object supporting __getitem__ keeps working.
        try:
            self.num_v = config["new_v"]
        except KeyError:
            self.num_v = 0

        # Channel count doubles per encoder stage, starting at base_num_channels.
        self.max_num_channels = self.base_num_channels * 2 ** (self.num_encoders - 1)

        self.activation = torch.sigmoid
        self.num_output_channels = 1

        self.encoder = LongSpikeStreamEncoderConv(
            patch_size=self.patch_size,
            in_chans=self.num_bins_rgb,
            embed_dim=self.base_num_channels,
            depths=self.depths,
            num_heads=self.num_heads,
            out_indices=self.out_indices,
            new_version=self.num_v,
        )

        self.UpsampleLayer = UpsampleConvLayer

        # Select how decoder features are merged with encoder features.
        if self.skip_type == 'sum':
            self.apply_skip_connection = skip_sum
        elif self.skip_type == 'concat':
            self.apply_skip_connection = skip_concat
        elif self.skip_type == 'no_skip' or self.skip_type is None:
            self.apply_skip_connection = identity
        else:
            raise KeyError('Could not identify skip_type, please add "skip_type":'
                           ' "sum", "concat" or "no_skip" to config["model"]')

        self.build_resblocks()
        self.build_decoders()
        self.build_prediction_layer()

    def build_resblocks(self):
        """Create ``num_residual_blocks`` residual blocks at ``max_num_channels``."""
        self.resblocks = nn.ModuleList([
            ResidualBlock(self.max_num_channels, self.max_num_channels, norm=self.norm)
            for _ in range(self.num_residual_blocks)
        ])

    def build_decoders(self):
        """Create one upsampling layer per encoder stage, widest stage first."""
        # Sizes run from max_num_channels down to base_num_channels.
        decoder_input_sizes = [
            self.base_num_channels * 2 ** stage
            for stage in reversed(range(self.num_encoders))
        ]

        # NOTE(review): forward_decoder feeds the first decoder without applying
        # a skip connection, and `identity` does not widen its input. The
        # `2 * input_size` sizing below therefore looks like it only matches
        # decoders 1..n-1 under 'concat'. Presumably only skip_type == 'sum' is
        # exercised -- TODO confirm against UpsampleConvLayer's signature.
        self.decoders = nn.ModuleList()
        for input_size in decoder_input_sizes:
            in_channels = input_size if self.skip_type == 'sum' else 2 * input_size
            self.decoders.append(
                self.UpsampleLayer(in_channels, input_size // 2,
                                   kernel_size=5, padding=2, norm=self.norm)
            )

    def build_prediction_layer(self):
        """Create the final kernel-size-1 layer producing ``num_output_channels``."""
        # NOTE(review): forward_decoder passes the last decoder's output
        # straight to `pred` (no skip with an input head). The non-'sum' size
        # of 2 * base_num_channels presumably dates from a variant that
        # concatenated a head feature here -- TODO confirm.
        if self.skip_type == 'sum':
            in_channels = self.base_num_channels // 2
        else:
            in_channels = 2 * self.base_num_channels
        self.pred = ConvLayer(in_channels, self.num_output_channels, 1,
                              activation=None, norm=self.norm)

    def forward_decoder(self, super_states):
        """Run residual blocks, decoders and prediction head on encoder outputs.

        :param super_states: indexable sequence of encoder outputs, one per
            encoder stage; the last entry seeds the decoder. When
            ``baseline`` is falsy and ``state_combination == "convlstm"``,
            each entry is itself indexable and its element 0 is used.
        :return: ``sigmoid(pred(x))`` for the final decoder output ``x``.
        """
        # Loop-invariant: decides whether each state must be unwrapped via [0].
        unwrap_first = not bool(self.baseline) and self.state_combination == "convlstm"

        def feature(state):
            return state[0] if unwrap_first else state

        # The last super state is the decoder input.
        x = feature(super_states[-1])

        for resblock in self.resblocks:
            x = resblock(x)

        for i, decoder in enumerate(self.decoders):
            # The first decoder consumes the residual-block output directly;
            # later ones first merge in the matching encoder stage.
            if i > 0:
                skip = feature(super_states[self.num_encoders - i - 1])
                x = self.apply_skip_connection(x, skip)
            x = decoder(x)

        return self.activation(self.pred(x))

    def forward(self, item, prev_super_states, prev_states_lstm):
        """Predict from the ``"image"`` entry of ``item``.

        :param item: mapping whose ``"image"`` value is moved to ``self.gpu``
            and fed to the encoder (documented upstream as N x C x H x W).
        :param prev_super_states: unused by this model.
        :param prev_states_lstm: returned unchanged.
        :return: tuple ``(predictions_dict, {'image': None}, prev_states_lstm)``
            where ``predictions_dict["image"]`` is the sigmoid-activated
            prediction (documented upstream as N x 1 x H x W).
        """
        spike_tensor = item["image"].to(self.gpu)
        encoded_xs = self.encoder(spike_tensor)
        prediction = self.forward_decoder(encoded_xs)

        predictions_dict = {"image": prediction}
        return predictions_dict, {'image': None}, prev_states_lstm