from torch import cat

from deepscreen.models.components.mlp import MLP


class ConcatMLP(MLP):
    """MLP variant that accepts several tensors and concatenates them first.

    The positional inputs are joined with ``torch.cat`` along dimension 1,
    and the result is passed through each child module obtained by iterating
    over ``self``, in iteration order.
    """

    def forward(self, *inputs):
        """Concatenate ``inputs`` on dim 1, then apply every layer in sequence.

        Args:
            *inputs: Tensors to join along dimension 1. NOTE(review):
                presumably each is shaped (batch, features_i) with a shared
                batch size — confirm against callers.

        Returns:
            The output of the last child module, or the concatenated tensor
            itself if iterating ``self`` yields no modules.
        """
        hidden = cat(inputs, dim=1)
        # Iterating ``self`` yields the layers to apply; presumably MLP is an
        # nn.Sequential-style container — TODO confirm in mlp.py.
        for layer in self:
            hidden = layer(hidden)
        return hidden