import torch
import torch.nn as nn
import torch.nn.functional as F


class HighwayNetwork(nn.Module):
    """Single highway layer mixing a ReLU transform with an identity carry.

    Computes::

        g = sigmoid(W2(x))
        y = g * relu(W1(x)) + (1 - g) * x

    Both projections are ``size -> size`` linear layers, so the output has the
    same shape as the input (required for the ``(1 - g) * x`` carry term).

    Args:
        size: Feature dimension of the input's last axis (and of the output).
    """

    def __init__(self, size: int) -> None:
        super().__init__()
        # W1: transform branch H(x); W2: gate branch producing g in (0, 1).
        self.W1 = nn.Linear(size, size)
        self.W2 = nn.Linear(size, size)
        # Zero the transform bias without touching `.data` (keeps autograd
        # bookkeeping clean). Only W1's bias is reset; W2 keeps nn.Linear's
        # default initialization.
        # NOTE(review): the Highway Networks paper suggests a *negative* gate
        # bias to favor carry early in training; confirm the W1-only reset is
        # intended.
        nn.init.zeros_(self.W1.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the highway layer.

        Args:
            x: Tensor whose last dimension equals ``size``; any leading
                dimensions are handled elementwise by ``nn.Linear``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        transformed = F.relu(self.W1(x))
        gate = torch.sigmoid(self.W2(x))
        # Convex combination of the transformed signal and the unchanged input.
        return gate * transformed + (1. - gate) * x