from torch import cat, nn


class MLP1(nn.Sequential):
    """MLP built eagerly as an nn.Sequential: Linear -> Dropout -> activation
    for each hidden layer, followed by a plain Linear output layer."""

    def __init__(
        self,
        input_channels: int,
        hidden_channels: list[int],
        out_channels: int,
        activation: type[nn.Module] = nn.ReLU,
        dropout: float = 0.0,
    ):
        layers = []
        num_layers = len(hidden_channels) + 1
        dims = [input_channels] + hidden_channels + [out_channels]
        for i in range(num_layers):
            layers.append(nn.Linear(dims[i], dims[i + 1]))
            if i != num_layers - 1:
                # Hidden layers get dropout and a non-linearity; the output layer does not.
                layers.append(nn.Dropout(dropout))
                layers.append(activation())
        super().__init__(*layers)


class MLP2(nn.Module):
    """Equivalent MLP that keeps the Linear layers in an nn.ModuleList and applies
    dropout and ReLU explicitly in forward(). It inherits from nn.Module rather than
    nn.Sequential because it defines its own forward pass."""

    def __init__(
        self,
        input_channels: int,
        hidden_channels: list[int],
        out_channels: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        num_layers = len(hidden_channels) + 1
        dims = [input_channels] + hidden_channels + [out_channels]
        self.layers = nn.ModuleList(
            [nn.Linear(dims[i], dims[i + 1]) for i in range(num_layers)]
        )

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            if i == len(self.layers) - 1:
                # No dropout or activation after the output layer.
                x = layer(x)
            else:
                x = nn.functional.relu(self.dropout(layer(x)))
        return x


class LazyMLP(nn.Sequential):
    """MLP whose input width is inferred on the first forward pass via nn.LazyLinear,
    so no input_channels argument is needed."""

    def __init__(
        self,
        out_channels: int,
        hidden_channels: list[int],
        activation: type[nn.Module] = nn.ReLU,
        dropout: float = 0.0,
    ):
        layers = []
        for hidden_dim in hidden_channels:
            layers.append(nn.LazyLinear(out_features=hidden_dim))
            layers.append(nn.Dropout(dropout))
            layers.append(activation())
        layers.append(nn.LazyLinear(out_features=out_channels))
        super().__init__(*layers)


class ConcatMLP(LazyMLP):
    """LazyMLP that concatenates its inputs along the feature dimension before the
    forward pass; the lazy layers mean the combined width never has to be declared."""

    def forward(self, *inputs):
        x = cat(inputs, dim=1)
        return super().forward(x)


# Alternative built on MLP1 instead of LazyMLP; it requires input_channels to equal
# the sum of the concatenated feature widths:
# class ConcatMLP(MLP1):
#     def forward(self, *inputs):
#         x = cat(inputs, dim=1)
#         for module in self:
#             x = module(x)
#         return x
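

# Usage sketch, kept behind a __main__ guard so importing this module stays
# side-effect free. The tensor shapes and hyperparameters below are illustrative
# assumptions, not values taken from the classes above.
if __name__ == "__main__":
    import torch

    x = torch.randn(8, 16)  # assumed batch of 8 samples with 16 features
    y = torch.randn(8, 24)  # assumed second input with 24 features for ConcatMLP

    mlp = MLP1(input_channels=16, hidden_channels=[32, 32], out_channels=4, dropout=0.1)
    print(mlp(x).shape)  # torch.Size([8, 4])

    # The lazy variants infer their input width on the first call, so ConcatMLP does
    # not need to be told that the concatenated features are 16 + 24 = 40 wide.
    concat_mlp = ConcatMLP(out_channels=4, hidden_channels=[32, 32], dropout=0.1)
    print(concat_mlp(x, y).shape)  # torch.Size([8, 4])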