from torch import nn
import torch.nn.functional as F


class LinLayers(nn.Module):

    def __init__(self, args):
        super().__init__()

        in_dim = args.in_dim
        # out_dim is assumed to be provided on args; the original left it unbound,
        # but the no-hidden-layer branch below needs it for its single Linear block.
        out_dim = args.out_dim
        hidden_layers = args.hidden_layers
        activations = args.activations
        batchnorms = args.bns
        dropouts = args.dropouts

        layers = nn.ModuleList()

        if hidden_layers:
            # Per-layer settings are lists and must match the number of hidden layers.
            assert len(hidden_layers) == len(activations) == len(batchnorms) == len(dropouts), 'dimensions mismatch!'

            old_dim = in_dim
            for idx, layer in enumerate(hidden_layers):
                # One block per hidden layer: Linear -> (BatchNorm) -> (activation) -> (Dropout).
                sub_layers = [nn.Linear(old_dim, layer)]
                if batchnorms[idx]:
                    sub_layers.append(nn.BatchNorm1d(num_features=layer))
                if activations[idx]:
                    sub_layers.append(activations[idx])
                if dropouts[idx]:
                    sub_layers.append(nn.Dropout(p=dropouts[idx]))
                old_dim = layer
                layers.append(nn.Sequential(*sub_layers))
        else:
            # No hidden layers: a single in_dim -> out_dim block. Here the activation,
            # batchnorm and dropout args are treated as single values rather than lists.
            layers.append(nn.Linear(in_dim, out_dim))
            if batchnorms:
                layers.append(nn.BatchNorm1d(num_features=out_dim))
            if activations:
                layers.append(activations)
            if dropouts:
                layers.append(nn.Dropout(p=dropouts))

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        x = self.layers(x)
        return x

    '''
    def _check_dimensions(self):
        if isinstance(self.hidden_layers, list):
            assert len(self.hidden_layers) == len(self.activations)
            assert len(self.hidden_layers) == len(self.batchnorms)
            assert len(self.hidden_layers) == len(self.dropouts)
    '''
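

# Example usage: a minimal smoke-test sketch, not part of the original module.
# It assumes `args` can be any namespace-like object exposing the fields read in
# __init__ (in_dim, out_dim, hidden_layers, activations, bns, dropouts); the
# concrete sizes and settings below are illustrative only.
if __name__ == '__main__':
    import torch
    from types import SimpleNamespace

    args = SimpleNamespace(
        in_dim=16,
        out_dim=4,                           # only used when hidden_layers is empty
        hidden_layers=[32, 8],               # two hidden blocks: 16 -> 32 -> 8
        activations=[nn.ReLU(), nn.ReLU()],  # one activation module (or None) per hidden layer
        bns=[True, False],                   # per-layer BatchNorm1d switches
        dropouts=[0.5, 0.0],                 # per-layer dropout probabilities (0 disables)
    )

    model = LinLayers(args)
    out = model(torch.randn(10, 16))         # batch of 10 samples, 16 features each
    print(out.shape)                         # torch.Size([10, 8])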