import torch
import torch.nn as nn
import torch.nn.functional as F

from dgl.nn.pytorch.conv import GINConv


class ApplyNodeFunc(nn.Module):
    """Update the node feature hv with MLP, BN and ReLU."""

    def __init__(self, mlp):
        super(ApplyNodeFunc, self).__init__()
        self.mlp = mlp
        self.bn = nn.BatchNorm1d(self.mlp.output_dim)

    def forward(self, h):
        h = self.mlp(h)
        h = self.bn(h)
        h = F.relu(h)
        return h


class MLP(nn.Module):
    """MLP with linear output"""

    def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
        """MLP layers construction

        Parameters
        ----------
        num_layers: int
            The number of linear layers
        input_dim: int
            The dimensionality of input features
        hidden_dim: int
            The dimensionality of hidden units at ALL layers
        output_dim: int
            The number of classes for prediction
        """
        super(MLP, self).__init__()
        self.linear_or_not = True  # default is linear model
        self.num_layers = num_layers
        self.output_dim = output_dim

        if num_layers < 1:
            raise ValueError("number of layers should be positive!")
        elif num_layers == 1:
            # Linear model
            self.linear = nn.Linear(input_dim, output_dim)
        else:
            # Multi-layer model
            self.linear_or_not = False
            self.linears = torch.nn.ModuleList()
            self.batch_norms = torch.nn.ModuleList()

            self.linears.append(nn.Linear(input_dim, hidden_dim))
            for layer in range(num_layers - 2):
                self.linears.append(nn.Linear(hidden_dim, hidden_dim))
            self.linears.append(nn.Linear(hidden_dim, output_dim))

            for layer in range(num_layers - 1):
                self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

    def forward(self, x):
        if self.linear_or_not:
            # If linear model
            return self.linear(x)
        else:
            # If MLP
            h = x
            for i in range(self.num_layers - 1):
                h = F.relu(self.batch_norms[i](self.linears[i](h)))
            return self.linears[-1](h)


class GIN(nn.Module):
    """GIN model"""

    def __init__(self, input_dim, hidden_dim, num_layers, num_mlp_layers=2,
                 dropout=0.1, learn_eps=False, neighbor_pooling_type='sum',
                 JK='sum'):
        """model parameters setting

        Parameters
        ----------
        input_dim: int
            The dimensionality of input features
        hidden_dim: int
            The dimensionality of hidden units at ALL layers
        num_layers: int
            The number of layers in the neural network; the model stacks
            num_layers - 1 GIN layers
        num_mlp_layers: int
            The number of linear layers in each MLP
        dropout: float
            dropout ratio applied after each GIN layer
        learn_eps: boolean
            If True, learn epsilon to distinguish center nodes from neighbors
            If False, aggregate neighbors and center nodes altogether.
        neighbor_pooling_type: str
            how to aggregate neighbors (sum, mean, or max)
        JK: str
            jumping-knowledge mode for combining per-layer node
            representations ('sum', 'max', 'concat', or 'last')
        """
        super(GIN, self).__init__()
        self.num_layers = num_layers
        self.learn_eps = learn_eps

        # List of GIN layers, each wrapping an MLP update function
        self.ginlayers = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for layer in range(self.num_layers - 1):
            if layer == 0:
                mlp = MLP(num_mlp_layers, input_dim, hidden_dim, hidden_dim)
            else:
                mlp = MLP(num_mlp_layers, hidden_dim, hidden_dim, hidden_dim)

            self.ginlayers.append(
                GINConv(ApplyNodeFunc(mlp), neighbor_pooling_type, 0,
                        self.learn_eps))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        self.drop = nn.Dropout(dropout)
        self.JK = JK

    def forward(self, g, Perturb=None):
        # Node features are popped from the graph, so the caller must
        # (re)attach g.ndata['h'] before every forward pass.
        h = g.ndata.pop('h').float()

        # hidden representation after each GIN layer (the input itself is
        # not included)
        hidden_rep = []
        for i in range(self.num_layers - 1):
            if i == 0 and Perturb is not None:
                # optional additive perturbation on the input features
                h = h + Perturb
            h = self.ginlayers[i](g, h)
            h = self.batch_norms[i](h)
            h = F.relu(h)
            h = self.drop(h)
            hidden_rep.append(h)

        # Jumping-knowledge readout over the per-layer node representations
        if self.JK == 'sum':
            return torch.stack(hidden_rep, dim=0).sum(dim=0)
        elif self.JK == 'max':
            return torch.stack(hidden_rep, dim=0).max(dim=0)[0]
        elif self.JK == 'concat':
            return torch.cat(hidden_rep, dim=1)
        elif self.JK == 'last':
            return hidden_rep[-1]
        else:
            raise ValueError("unknown JK mode: {}".format(self.JK))
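

# Sketch (hypothetical, not part of the original module) of one way the
# Perturb hook in GIN.forward can be used: a single ascent step of
# FLAG-style adversarial feature augmentation. The helper name, step size,
# and the stand-in objective below are illustrative assumptions.
def perturb_step(model, g, feats, step_size=1e-2):
    """Return a signed-gradient update for an input-feature perturbation."""
    perturb = torch.zeros_like(feats, requires_grad=True)
    g.ndata['h'] = feats
    out = model(g, Perturb=perturb)
    loss = out.pow(2).mean()  # stand-in objective; real code uses a task loss
    loss.backward()           # populates perturb.grad through h + Perturb
    return step_size * perturb.grad.sign()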
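

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): builds a small batched DGL
# graph and runs the encoder under each JK mode. The graph topology, feature
# dimension, and layer counts are illustrative assumptions.
if __name__ == "__main__":
    import dgl

    g1 = dgl.graph(([0, 1, 2], [1, 2, 0]))  # 3-node directed cycle
    g2 = dgl.graph(([0, 1], [1, 0]))         # 2-node pair
    bg = dgl.batch([g1, g2])
    feats = torch.randn(bg.num_nodes(), 16)

    for jk in ('sum', 'max', 'concat', 'last'):
        model = GIN(input_dim=16, hidden_dim=32, num_layers=4, JK=jk)
        model.eval()  # freeze BatchNorm statistics for this smoke test
        bg.ndata['h'] = feats  # forward() pops 'h', so reattach each pass
        out = model(bg)
        # 'concat' yields (num_layers - 1) * hidden_dim = 96 features per
        # node; the other modes keep hidden_dim = 32.
        print(jk, tuple(out.shape))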