import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
# ResNet blocks
class ResnetBlockFC(nn.Module):
    """Pre-activation residual block built from fully connected layers.

    The residual branch computes ``fc_1(ReLU(fc_0(ReLU(x))))``. Its result is
    added to ``x`` directly when input and output widths match, or to a
    bias-free linear projection of ``x`` otherwise.

    Args:
        size_in (int): dimension of the input features.
        size_out (int, optional): dimension of the output features; falls
            back to ``size_in`` when omitted.
        size_h (int, optional): width of the hidden layer; falls back to
            ``min(size_in, size_out)`` when omitted.
    """

    def __init__(self, size_in, size_out=None, size_h=None):
        super().__init__()

        # Resolve the optional widths.
        size_out = size_in if size_out is None else size_out
        size_h = min(size_in, size_out) if size_h is None else size_h

        self.size_in, self.size_h, self.size_out = size_in, size_h, size_out

        # Residual branch. Construction order (fc_0, fc_1, then the optional
        # projection) fixes the order in which default parameter
        # initialization draws from the RNG.
        self.fc_0 = nn.Linear(size_in, size_h)
        self.fc_1 = nn.Linear(size_h, size_out)
        self.actvn = nn.ReLU()

        # Identity skip when widths agree; otherwise project without a bias.
        self.shortcut = (
            None if size_in == size_out
            else nn.Linear(size_in, size_out, bias=False)
        )

        # Only the weight is zeroed: at initialization the residual branch
        # outputs fc_1's (default-initialized) bias, independent of the input.
        nn.init.zeros_(self.fc_1.weight)

    def forward(self, x):
        """Apply the block to ``x``, whose last dimension must be ``size_in``.

        Returns a tensor whose last dimension is ``size_out``.
        """
        hidden = self.fc_0(self.actvn(x))
        residual = self.fc_1(self.actvn(hidden))
        skip = x if self.shortcut is None else self.shortcut(x)
        return skip + residual