# File: MEIRa/pytorch_utils/modules.py
# Uploaded by KawshikManikantan ("upload_trial", revision 98e2ea5, 995 bytes).
import torch.nn as nn
class MLP(nn.Module):
    """Multi-layer perceptron built from ``nn.Linear`` layers and ReLU.

    The network is ``num_hidden_layers`` repetitions of
    ``Linear -> ReLU [-> drop_module]`` followed by one final ``Linear``
    that projects to ``output_size``. No activation is applied after the
    final layer.

    Args:
        input_size: Input feature size of the first linear layer.
        hidden_size: Output feature size of every hidden linear layer.
        output_size: Output feature size of the final linear layer.
        num_hidden_layers: Number of hidden ``Linear -> ReLU`` blocks. With
            ``0`` (or any value for which ``range()`` is empty) the network
            is a single linear layer from ``input_size`` to ``output_size``.
        bias: Forwarded as ``bias=`` to every ``nn.Linear``.
        drop_module: Optional module inserted after the ReLU of every hidden
            block (presumably a dropout layer, judging by the name). The
            same instance is reused in every block, not copied.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        num_hidden_layers: int = 1,
        bias: bool = False,
        drop_module=None,
    ):
        super().__init__()
        # One ReLU instance is shared by all hidden blocks.
        self.activation = nn.ReLU()
        self.drop_module = drop_module
        self.num_hidden_layers = num_hidden_layers

        # Plain Python list of the layers in execution order. It is kept as
        # a public attribute for backward compatibility; the modules are
        # actually registered through ``self.fc_layers`` below.
        self.layer_list = []

        in_features = input_size
        for _ in range(num_hidden_layers):
            self.layer_list.append(nn.Linear(in_features, hidden_size, bias=bias))
            self.layer_list.append(self.activation)
            if drop_module is not None:
                self.layer_list.append(drop_module)
            # Every layer after the first hidden one consumes hidden_size.
            in_features = hidden_size

        # Final projection; no activation or drop module follows it.
        self.layer_list.append(nn.Linear(in_features, output_size, bias=bias))
        self.fc_layers = nn.Sequential(*self.layer_list)

    def forward(self, mlp_input):
        """Apply the stacked layers to ``mlp_input`` and return the result."""
        return self.fc_layers(mlp_input)