libokj's picture
Upload 110 files
c0ec7e6
raw
history blame
229 Bytes
from torch import cat
from deepscreen.models.components.mlp import MLP
class ConcatMLP(MLP):
    """``MLP`` variant that takes several input tensors and joins them first.

    The positional inputs are concatenated along dimension 1 with
    ``torch.cat``. The result is then passed through every submodule yielded
    by iterating ``self``, in iteration order.
    """

    def forward(self, *inputs):
        """Concatenate ``inputs`` along dim 1, then apply each submodule in turn.

        Args:
            *inputs: Tensors to concatenate. They are presumably shaped
                (batch, features_i) with a shared batch size; TODO confirm
                against callers.

        Returns:
            The output of the last submodule yielded by iterating ``self``.
            If that iteration yields nothing, the concatenated tensor itself
            is returned.
        """
        # torch.cat accepts any sequence of tensors, so the tuple can be
        # passed as-is.
        hidden = cat(inputs, dim=1)
        for layer in self:
            hidden = layer(hidden)
        return hidden