File size: 2,350 Bytes
c0ec7e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from collections.abc import Sequence

from torch import nn, cat


class MLP1(nn.Sequential):
    """Fully connected network built as a flat ``nn.Sequential``.

    Every hidden layer contributes three consecutive modules:
    ``Linear -> Dropout -> activation``. The final layer is a bare ``Linear``
    with no dropout or activation after it.

    Args:
        input_channels: ``in_features`` of the first ``Linear``.
        hidden_channels: Widths of the hidden layers, in order. Any sequence
            of ints (list, tuple, ...) is accepted. An empty sequence yields a
            single ``Linear(input_channels, out_channels)``.
        out_channels: ``out_features`` of the last ``Linear``.
        activation: Activation class, instantiated once per hidden layer.
        dropout: Dropout probability applied after every hidden ``Linear``
            (a ``Dropout`` module is inserted even when it is 0.0, so the
            module layout does not depend on this value).
    """

    def __init__(self,
                 input_channels: int,
                 hidden_channels: Sequence[int],
                 out_channels: int,
                 activation: type[nn.Module] = nn.ReLU,
                 dropout: float = 0.0):
        # Unpacking (rather than list concatenation) accepts tuples and other
        # sequences, not only lists.
        widths = [input_channels, *hidden_channels, out_channels]

        layers: list[nn.Module] = []
        # Each consecutive (in, out) pair except the last one is a hidden layer.
        for in_features, out_features in zip(widths[:-2], widths[1:-1]):
            layers.extend((
                nn.Linear(in_features, out_features),
                nn.Dropout(dropout),
                activation(),
            ))
        # Output projection: no dropout/activation after the final Linear.
        layers.append(nn.Linear(widths[-2], widths[-1]))

        super().__init__(*layers)


class MLP2(nn.Sequential):
    """Fully connected network with a hand-written ``forward``.

    Every ``Linear`` except the last is followed by dropout and then ReLU.
    The output of the last ``Linear`` is returned unchanged.

    NOTE(review): this subclasses ``nn.Sequential`` but registers its children
    as the attributes ``dropout`` and ``layers`` instead of passing them to
    ``nn.Sequential.__init__``. Sequential-style ``len()``, iteration and
    indexing therefore see those two children, not the individual ``Linear``
    layers. ``nn.Module`` may have been the intended base class.

    Args:
        input_channels: ``in_features`` of the first ``Linear``.
        hidden_channels: Widths of the hidden layers, in order (must be a
            list, since it is concatenated with other lists).
        out_channels: ``out_features`` of the last ``Linear``.
        dropout: Dropout probability used by the single shared ``Dropout``.
    """

    def __init__(self,
                 input_channels,
                 hidden_channels: list[int],
                 out_channels: int,
                 dropout: float = 0.0):
        super().__init__()
        # One Dropout module, reused after every hidden Linear.
        self.dropout = nn.Dropout(dropout)
        sizes = [input_channels] + hidden_channels + [out_channels]
        # Consecutive size pairs give (in_features, out_features) per layer.
        self.layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(sizes, sizes[1:])
        )

    def forward(self, x):
        final_depth = len(self.layers) - 1
        for depth, linear in enumerate(self.layers):
            x = linear(x)
            # Hidden layers get dropout then ReLU; the last layer is left raw.
            if depth != final_depth:
                x = nn.functional.relu(self.dropout(x))
        return x


class LazyMLP(nn.Sequential):
    """Fully connected network whose first input width is inferred lazily.

    Built from ``nn.LazyLinear`` layers, so no input width is passed in.
    Each hidden width contributes ``LazyLinear -> Dropout -> activation``.
    A final ``LazyLinear`` maps to ``out_channels`` with nothing after it.

    Args:
        out_channels: ``out_features`` of the final ``LazyLinear``.
        hidden_channels: Widths of the hidden layers, in order.
        activation: Activation class, instantiated once per hidden layer.
        dropout: Dropout probability applied after every hidden layer.
    """

    def __init__(
            self,
            out_channels: int,
            hidden_channels: list[int],
            activation: type[nn.Module] = nn.ReLU,
            dropout: float = 0.0
    ):
        stack = []
        for width in hidden_channels:
            # Tuple elements are built left to right, preserving module order.
            stack.extend((
                nn.LazyLinear(out_features=width),
                nn.Dropout(dropout),
                activation(),
            ))
        head = nn.LazyLinear(out_features=out_channels)

        super().__init__(*stack, head)


class ConcatMLP(LazyMLP):
    """``LazyMLP`` that takes several tensors and concatenates them first.

    All positional inputs are joined with ``cat`` along dimension 1, and the
    result goes through the inherited ``LazyMLP`` layer stack.
    NOTE(review): presumably inputs are (batch, features), so dim 1 is the
    feature axis the first ``LazyLinear`` consumes; confirm with callers.
    """

    def forward(self, *inputs):
        joined = cat(list(inputs), dim=1)
        return super().forward(joined)


# class ConcatMLP(MLP1):
#     def forward(self, *inputs):
#         x = cat([*inputs], 1)
#         for module in self:
#             x = module(x)
#         return x