|
|
|
from SpikeT.base import BaseModel |
|
import torch.nn as nn |
|
import torch |
|
from os.path import join |
|
from SpikeT.model.submodules import \ |
|
ConvLSTM, ResidualBlock, ConvLayer, \ |
|
UpsampleConvLayer, TransposedConvLayer |
|
|
|
|
|
def skip_concat(x1, x2): |
|
return torch.cat([x1, x2], dim=1) |
|
|
|
|
|
def skip_sum(x1, x2): |
|
return x1 + x2 |
|
|
|
|
|
def identity(x1, x2=None): |
|
return x1 |
|
|
|
|
|
class BaseERGB2Depth(BaseModel): |
|
def __init__(self, config): |
|
super().__init__(config) |
|
|
|
assert ('num_bins_rgb' in config) |
|
self.num_bins_rgb = int(config['num_bins_rgb']) |
|
assert ('num_bins_events' in config) |
|
self.num_bins_events = int(config['num_bins_events']) |
|
|
|
try: |
|
self.skip_type = str(config['skip_type']) |
|
except KeyError: |
|
self.skip_type = 'sum' |
|
|
|
try: |
|
self.state_combination = str(config['state_combination']) |
|
except KeyError: |
|
self.state_combination = 'sum' |
|
|
|
try: |
|
self.num_encoders = int(config['num_encoders']) |
|
except KeyError: |
|
self.num_encoders = 4 |
|
|
|
try: |
|
self.base_num_channels = int(config['base_num_channels']) |
|
except KeyError: |
|
self.base_num_channels = 32 |
|
|
|
try: |
|
self.num_residual_blocks = int(config['num_residual_blocks']) |
|
except KeyError: |
|
self.num_residual_blocks = 2 |
|
|
|
try: |
|
self.recurrent_block_type = str(config['recurrent_block_type']) |
|
except KeyError: |
|
self.recurrent_block_type = 'convlstm' |
|
|
|
try: |
|
self.norm = str(config['norm']) |
|
except KeyError: |
|
self.norm = None |
|
|
|
try: |
|
self.use_upsample_conv = bool(config['use_upsample_conv']) |
|
except KeyError: |
|
self.use_upsample_conv = True |
|
|
|
try: |
|
self.every_x_rgb_frame = config['every_x_rgb_frame'] |
|
except KeyError: |
|
self.every_x_rgb_frame = 1 |
|
|
|
try: |
|
self.baseline = config['baseline'] |
|
except KeyError: |
|
self.baseline = False |
|
|
|
try: |
|
self.loss_composition = config['loss_composition'] |
|
except KeyError: |
|
self.loss_composition = False |
|
|
|
self.kernel_size = int(config.get('kernel_size', 5)) |
|
self.gpu = torch.device('cuda:' + str(config['gpu'])) |
|
|
|
|
|
|