File size: 2,448 Bytes
5fc3d65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91

from SpikeT.base import BaseModel
import torch.nn as nn
import torch
from os.path import join
from SpikeT.model.submodules import \
    ConvLSTM, ResidualBlock, ConvLayer, \
    UpsampleConvLayer, TransposedConvLayer


def skip_concat(x1, x2):
    """Join ``x1`` and ``x2`` along dimension 1 and return the result.

    Both inputs are handed to ``torch.cat`` in the order given, so the
    entries of ``x1`` come first along that dimension.
    NOTE(review): dim 1 is presumably the channel axis of (N, C, ...)
    feature maps -- confirm against the callers.
    """
    pair = (x1, x2)
    return torch.cat(pair, dim=1)


def skip_sum(x1, x2):
    """Combine the two inputs with the ``+`` operator and return the result."""
    merged = x1 + x2
    return merged


def identity(x1, x2=None):
    """Return ``x1`` unchanged; ``x2`` is accepted but never used.

    NOTE(review): the optional second argument presumably exists so this
    function can be called with the same two-argument shape as
    ``skip_concat`` / ``skip_sum`` -- confirm against the callers.
    """
    del x2  # intentionally ignored
    return x1


class BaseERGB2Depth(BaseModel):
    """Base model that copies its hyper-parameters from ``config`` onto ``self``.

    ``config`` must support ``in``, ``[]`` and ``.get`` (for example a
    plain ``dict``).

    Required keys
        ``num_bins_rgb``, ``num_bins_events``
            Asserted to be present, then converted with ``int``.
        ``gpu``
            Appended to ``'cuda:'`` to build the ``torch.device`` stored in
            ``self.gpu``; a missing key raises ``KeyError``.

    Optional keys (conversion, fallback when absent, noted value)
        ``skip_type``            str,  ``'sum'``,      noted ``'sum'``
        ``state_combination``    str,  ``'sum'``,      noted ``none``
        ``num_encoders``         int,  ``4``,          noted ``3``
        ``base_num_channels``    int,  ``32``,         noted ``32``
        ``num_residual_blocks``  int,  ``2``,          noted ``2``
        ``recurrent_block_type`` str,  ``'convlstm'``, noted ``none``
        ``norm``                 str,  ``None``,       noted ``'none'``
        ``use_upsample_conv``    bool, ``True``,       noted ``True``
        ``every_x_rgb_frame``    raw,  ``1``,          noted ``1``
        ``baseline``             raw,  ``False``,      noted ``e``
        ``loss_composition``     raw,  ``False``,      noted ``'image'``
        ``kernel_size``          int,  ``5``,          noted ``5``

    NOTE(review): the "noted" column reproduces the inline comments that
    previously sat next to each lookup; presumably they record the values of
    a reference configuration -- confirm before relying on them.
    """

    def __init__(self, config):
        """Read every setting listed in the class docstring from ``config``.

        Raises:
            AssertionError: ``num_bins_rgb`` or ``num_bins_events`` is absent.
            KeyError: ``gpu`` is absent.
        """
        super().__init__(config)

        # Required bin counts: each key is checked right before it is read.
        assert 'num_bins_rgb' in config
        self.num_bins_rgb = int(config['num_bins_rgb'])
        assert 'num_bins_events' in config
        self.num_bins_events = int(config['num_bins_events'])

        # Optional settings: absent keys fall back to the given default.
        self.skip_type = str(config.get('skip_type', 'sum'))
        self.state_combination = str(config.get('state_combination', 'sum'))
        self.num_encoders = int(config.get('num_encoders', 4))
        self.base_num_channels = int(config.get('base_num_channels', 32))
        self.num_residual_blocks = int(config.get('num_residual_blocks', 2))
        self.recurrent_block_type = str(
            config.get('recurrent_block_type', 'convlstm'))

        # An absent 'norm' gives the None object, whereas a present value is
        # always stringified (so an explicit None becomes the text 'None').
        self.norm = str(config['norm']) if 'norm' in config else None

        self.use_upsample_conv = bool(config.get('use_upsample_conv', True))

        # These three are stored exactly as found, without any conversion.
        self.every_x_rgb_frame = config.get('every_x_rgb_frame', 1)
        self.baseline = config.get('baseline', False)
        self.loss_composition = config.get('loss_composition', False)

        self.kernel_size = int(config.get('kernel_size', 5))

        # Only a device handle is built here; nothing is moved onto it.
        device_name = 'cuda:' + str(config['gpu'])
        self.gpu = torch.device(device_name)