Spaces:
Runtime error
Runtime error
File size: 961 Bytes
1c3eb47 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import torch
import torch.nn as nn
from mmpl.registry import MODELS
@MODELS.register_module()
class LinearProj(nn.Module):
    """MLP that projects features from ``in_channels`` to ``out_channels``.

    The layer stack (stored as ``self.layers``, an ``nn.Sequential``) is::

        Linear(in_channels, base_channels), ReLU
        [Linear(base_channels, base_channels), ReLU] * num_inner_layers
        Linear(base_channels, out_channels)

    The final ``Linear`` is not followed by an activation.

    Args:
        in_channels: Input feature size (``nn.Linear`` acts on the last
            dimension of the input).
        out_channels: Output feature size.
        base_channels: Hidden width. When ``None`` it defaults to
            ``out_channels``.
        num_inner_layers: Number of extra ``Linear`` + ``ReLU`` pairs between
            the first and the last projection. ``0`` yields a two-layer MLP;
            negative values behave like ``0`` (``range`` is empty).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        base_channels: "int | None" = None,
        num_inner_layers: int = 1,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_inner_layers = num_inner_layers
        if base_channels is None:
            base_channels = out_channels
        self.base_channels = base_channels

        # NOTE: module order inside ``self.layers`` fixes the state_dict keys
        # (``layers.0.weight``, ``layers.2.weight``, ...); keep it stable so
        # existing checkpoints keep loading.
        # In-place ReLU is used only directly after a Linear, overwriting that
        # Linear's fresh output tensor.
        layers = [nn.Linear(in_channels, base_channels), nn.ReLU(inplace=True)]
        for _ in range(num_inner_layers):
            layers.append(nn.Linear(base_channels, base_channels))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(base_channels, out_channels))
        self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to ``x`` (projection over its last dimension)."""
        return self.layers(x)
|