import torch
import torch.nn as nn
from SpikeT.model.model import BaseERGB2Depth
from SpikeT.model.encoder_transformer import LongSpikeStreamEncoderConv
from SpikeT.model.submodules import ResidualBlock, ConvLayer, UpsampleConvLayer
def skip_concat(x1, x2):
    """Skip connection that concatenates features along the channel dimension."""
    return torch.cat([x1, x2], dim=1)
def skip_sum(x1, x2):
    """Skip connection that sums features element-wise."""
    return x1 + x2
def identity(x1, x2=None):
    """No-skip variant: ignore the encoder feature x2 and return x1 unchanged."""
    return x1
class S2DepthTransformerUNetConv(BaseERGB2Depth):
def __init__(self, config):
super(S2DepthTransformerUNetConv, self).__init__(config)
        assert self.base_num_channels % 48 == 0, "base_num_channels must be a multiple of 48"
        self.depths = [int(i) for i in config["swin_depths"]]
        self.num_encoders = len(self.depths)
        self.num_heads = [int(i) for i in config["swin_num_heads"]]
        self.patch_size = [int(i) for i in config["swin_patch_size"]]
        self.out_indices = [int(i) for i in config["swin_out_indices"]]
        self.ape = config["ape"]
        # Encoder version flag; defaults to 0 when "new_v" is absent from the config.
        self.num_v = config.get("new_v", 0)
        self.max_num_channels = self.base_num_channels * pow(2, self.num_encoders - 1)
        # depth predictions are squashed to [0, 1] with a sigmoid
        self.activation = getattr(torch, 'sigmoid')
        self.num_output_channels = 1
        # spike-stream transformer encoder (configured by the swin_* keys above)
        self.encoder = LongSpikeStreamEncoderConv(
patch_size=self.patch_size,
in_chans=self.num_bins_rgb,
embed_dim=self.base_num_channels,
depths=self.depths,
num_heads=self.num_heads,
out_indices=self.out_indices,
new_version=self.num_v,
)
self.UpsampleLayer = UpsampleConvLayer
if self.skip_type == 'sum':
self.apply_skip_connection = skip_sum
elif self.skip_type == 'concat':
self.apply_skip_connection = skip_concat
elif self.skip_type == 'no_skip' or self.skip_type is None:
self.apply_skip_connection = identity
else:
raise KeyError('Could not identify skip_type, please add "skip_type":'
' "sum", "concat" or "no_skip" to config["model"]')
self.build_resblocks()
self.build_decoders()
self.build_prediction_layer()
    def build_resblocks(self):
        """Stack residual blocks operating at the deepest (bottleneck) resolution."""
        self.resblocks = nn.ModuleList()
for i in range(self.num_residual_blocks):
self.resblocks.append(ResidualBlock(self.max_num_channels, self.max_num_channels, norm=self.norm))
    def build_decoders(self):
        """Build one upsampling decoder per encoder stage, halving the channel count at each step."""
        decoder_input_sizes = list(reversed([self.base_num_channels * pow(2, i) for i in range(self.num_encoders)]))
        self.decoders = nn.ModuleList()
for input_size in decoder_input_sizes:
self.decoders.append(self.UpsampleLayer(input_size if self.skip_type == 'sum' else 2 * input_size,
input_size // 2,
kernel_size=5, padding=2, norm=self.norm))
    def build_prediction_layer(self):
        """1x1 convolution mapping the final decoder output to a single-channel depth prediction."""
self.pred = ConvLayer(self.base_num_channels // 2 if self.skip_type == 'sum' else 2 * self.base_num_channels,
self.num_output_channels, 1, activation=None, norm=self.norm)
def forward_decoder(self, super_states):
        # The deepest (last) super state is the decoder input; with a ConvLSTM
        # state combination each entry is a tuple, so take its hidden tensor.
if not bool(self.baseline) and self.state_combination == "convlstm":
x = super_states[-1][0]
else:
x = super_states[-1]
# residual blocks
for resblock in self.resblocks:
x = resblock(x)
        # decoder: upsample stage by stage, fusing skip connections from shallower encoder stages
for i, decoder in enumerate(self.decoders):
if i == 0:
x = decoder(x)
else:
if not bool(self.baseline) and self.state_combination == "convlstm":
x = decoder(self.apply_skip_connection(x, super_states[self.num_encoders - i - 1][0]))
else:
x = decoder(self.apply_skip_connection(x, super_states[self.num_encoders - i - 1]))
        # tail: 1x1 prediction conv followed by sigmoid to map depth into [0, 1]
img = self.activation(self.pred(x))
return img
def forward(self, item, prev_super_states, prev_states_lstm):
        """
        :param item: dict holding the input spike tensor under key "image" (N x C x H x W)
        :param prev_super_states: previous super states (not used in this forward pass)
        :param prev_states_lstm: previous LSTM states, returned unchanged
        :return: (dict with the predicted depth map of size N x 1 x H x W, taking values
                  in [0, 1], a placeholder super-state dict, prev_states_lstm)
        """
predictions_dict = {}
        spike_tensor = item["image"].to(self.gpu)
        encoded_xs = self.encoder(spike_tensor)
prediction = self.forward_decoder(encoded_xs)
predictions_dict["image"] = prediction
return predictions_dict, {'image': None}, prev_states_lstm
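# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the model definition).
# The "swin_*" and "ape" keys mirror the ones read in __init__ above; the
# remaining keys (base_num_channels, num_bins_rgb, skip_type, norm,
# num_residual_blocks, baseline, state_combination, gpu) are assumptions about
# what BaseERGB2Depth expects and may be named or nested differently in the
# actual repository. Tensor sizes are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = {
        "swin_depths": [2, 2, 6, 2],
        "swin_num_heads": [3, 6, 12, 24],
        "swin_patch_size": [4, 4],
        "swin_out_indices": [0, 1, 2, 3],
        "ape": False,
        # assumed base-class keys:
        "base_num_channels": 96,   # must be a multiple of 48 (see assert in __init__)
        "num_bins_rgb": 1,
        "skip_type": "sum",
        "norm": None,
        "num_residual_blocks": 2,
        "baseline": True,
        "state_combination": "none",
        "gpu": "cpu",
    }
    model = S2DepthTransformerUNetConv(config)
    item = {"image": torch.randn(1, config["num_bins_rgb"], 256, 320)}  # N x C x H x W
    predictions, _, _ = model(item, prev_super_states=None, prev_states_lstm=None)
    print(predictions["image"].shape)  # expected: 1 x 1 x H x W, values in [0, 1]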