rightnow / SpikeT /model /model.py
zzzzzeee's picture
Upload 101 files
5fc3d65 verified
from SpikeT.base import BaseModel
import torch.nn as nn
import torch
from os.path import join
from SpikeT.model.submodules import \
ConvLSTM, ResidualBlock, ConvLayer, \
UpsampleConvLayer, TransposedConvLayer
def skip_concat(x1, x2):
return torch.cat([x1, x2], dim=1)
def skip_sum(x1, x2):
return x1 + x2
def identity(x1, x2=None):
return x1
class BaseERGB2Depth(BaseModel):
def __init__(self, config):
super().__init__(config)
assert ('num_bins_rgb' in config)
self.num_bins_rgb = int(config['num_bins_rgb'])
assert ('num_bins_events' in config)
self.num_bins_events = int(config['num_bins_events'])
try:
self.skip_type = str(config['skip_type']) # 'sum'
except KeyError:
self.skip_type = 'sum'
try:
self.state_combination = str(config['state_combination']) # none
except KeyError:
self.state_combination = 'sum'
try:
self.num_encoders = int(config['num_encoders']) # 3
except KeyError:
self.num_encoders = 4
try:
self.base_num_channels = int(config['base_num_channels']) # 32
except KeyError:
self.base_num_channels = 32
try:
self.num_residual_blocks = int(config['num_residual_blocks']) # 2
except KeyError:
self.num_residual_blocks = 2
try:
self.recurrent_block_type = str(config['recurrent_block_type']) # none
except KeyError:
self.recurrent_block_type = 'convlstm'
try:
self.norm = str(config['norm']) # 'none'
except KeyError:
self.norm = None
try:
self.use_upsample_conv = bool(config['use_upsample_conv']) # True
except KeyError:
self.use_upsample_conv = True
try:
self.every_x_rgb_frame = config['every_x_rgb_frame'] # 1
except KeyError:
self.every_x_rgb_frame = 1
try:
self.baseline = config['baseline'] # e
except KeyError:
self.baseline = False
try:
self.loss_composition = config['loss_composition'] # 'image'
except KeyError:
self.loss_composition = False
self.kernel_size = int(config.get('kernel_size', 5)) # 5
self.gpu = torch.device('cuda:' + str(config['gpu']))