import torch.nn as nn
import torch.nn.functional as F

from dgl.nn.pytorch import GraphConv

__all__ = ['GCN']

# pylint: disable=W0221, C0103

class GCNLayer(nn.Module):
    r"""Single GCN layer from `Semi-Supervised Classification with Graph Convolutional Networks
    <https://arxiv.org/abs/1609.02907>`__

    Parameters
    ----------
    in_feats : int
        Number of input node features.
    out_feats : int
        Number of output node features.
    gnn_norm : str
        The message passing normalizer, which can be `'right'`, `'both'` or `'none'`. The
        `'right'` normalizer divides the aggregated messages by each node's in-degree.
        The `'both'` normalizer corresponds to the symmetric adjacency normalization in
        the original GCN paper. The `'none'` normalizer simply sums the messages.
        Default to `'none'`.
    activation : activation function or None
        Activation function applied to the updated node representations.
        Default to None.
    residual : bool
        Whether to use a residual connection. Default to True.
    batchnorm : bool
        Whether to apply batch normalization to the output. Default to True.
    dropout : float
        The probability for dropout. Default to 0., i.e. no dropout is performed.
    """
    def __init__(self, in_feats, out_feats, gnn_norm='none', activation=None,
                 residual=True, batchnorm=True, dropout=0.):
        super(GCNLayer, self).__init__()

        self.activation = activation
        self.graph_conv = GraphConv(in_feats=in_feats, out_feats=out_feats,
                                    norm=gnn_norm, activation=activation,
                                    allow_zero_in_degree=True)
        self.dropout = nn.Dropout(dropout)
        self.residual = residual
        if residual:
            self.res_connection = nn.Linear(in_feats, out_feats)
        self.bn = batchnorm
        if batchnorm:
            self.bn_layer = nn.BatchNorm1d(out_feats)

    def reset_parameters(self):
        """Reinitialize model parameters."""
        self.graph_conv.reset_parameters()
        if self.residual:
            self.res_connection.reset_parameters()
        if self.bn:
            self.bn_layer.reset_parameters()

    def forward(self, g, feats):
        """Update node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        feats : FloatTensor of shape (N, M1)
            * N is the total number of nodes in the batch of graphs
            * M1 is the input node feature size, which must match in_feats in initialization

        Returns
        -------
        new_feats : FloatTensor of shape (N, M2)
            * M2 is the output node feature size, which must match out_feats in initialization
        """
        new_feats = self.graph_conv(g, feats)
        if self.residual:
            res_feats = self.res_connection(feats)
            # Guard against activation being None (the default); in that case
            # the residual branch stays linear instead of raising a TypeError.
            if self.activation is not None:
                res_feats = self.activation(res_feats)
            new_feats = new_feats + res_feats
        new_feats = self.dropout(new_feats)
        if self.bn:
            new_feats = self.bn_layer(new_feats)
        return new_feats
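
# A minimal usage sketch for GCNLayer (illustrative values, not from the
# source; assumes dgl and torch are installed):
#
#     import dgl
#     import torch
#     g = dgl.graph(([0, 1, 2], [1, 2, 0]), num_nodes=3)
#     layer = GCNLayer(in_feats=4, out_feats=8, activation=F.relu)
#     out = layer(g, torch.randn(3, 4))  # out has shape (3, 8)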


class GCN(nn.Module):
    r"""GCN from `Semi-Supervised Classification with Graph Convolutional Networks
    <https://arxiv.org/abs/1609.02907>`__

    Parameters
    ----------
    in_feats : int
        Number of input node features.
    hidden_feats : list of int
        ``hidden_feats[i]`` gives the size of node representations after the i-th GCN layer.
        ``len(hidden_feats)`` equals the number of GCN layers. By default, we use
        ``[64, 64]``.
    gnn_norm : list of str
        ``gnn_norm[i]`` gives the message passing normalizer for the i-th GCN layer, which
        can be `'right'`, `'both'` or `'none'`. The `'right'` normalizer divides the aggregated
        messages by each node's in-degree. The `'both'` normalizer corresponds to the symmetric
        adjacency normalization in the original GCN paper. The `'none'` normalizer simply sums
        the messages. ``len(gnn_norm)`` equals the number of GCN layers. By default, we use
        ``['none', 'none']``.
    activation : list of activation functions or None
        If not None, ``activation[i]`` gives the activation function to be used for
        the i-th GCN layer. ``len(activation)`` equals the number of GCN layers.
        By default, ReLU is applied for all GCN layers.
    residual : list of bool
        ``residual[i]`` decides if a residual connection is to be used for the i-th GCN layer.
        ``len(residual)`` equals the number of GCN layers. By default, a residual connection
        is used for each GCN layer.
    batchnorm : list of bool
        ``batchnorm[i]`` decides if batch normalization is to be applied on the output of
        the i-th GCN layer. ``len(batchnorm)`` equals the number of GCN layers. By default,
        batch normalization is applied for all GCN layers.
    dropout : list of float
        ``dropout[i]`` decides the dropout probability on the output of the i-th GCN layer.
        ``len(dropout)`` equals the number of GCN layers. By default, no dropout is
        performed for any layer.
    """
    def __init__(self, in_feats, hidden_feats=None, gnn_norm=None, activation=None,
                 residual=None, batchnorm=None, dropout=None):
        super(GCN, self).__init__()

        if hidden_feats is None:
            hidden_feats = [64, 64]

        n_layers = len(hidden_feats)
        if gnn_norm is None:
            gnn_norm = ['none' for _ in range(n_layers)]
        if activation is None:
            activation = [F.relu for _ in range(n_layers)]
        if residual is None:
            residual = [True for _ in range(n_layers)]
        if batchnorm is None:
            batchnorm = [True for _ in range(n_layers)]
        if dropout is None:
            dropout = [0. for _ in range(n_layers)]
        lengths = [len(hidden_feats), len(gnn_norm), len(activation),
                   len(residual), len(batchnorm), len(dropout)]
        assert len(set(lengths)) == 1, \
            'Expect the lengths of hidden_feats, gnn_norm, activation, ' \
            'residual, batchnorm and dropout to be the same, got {}'.format(lengths)

        self.hidden_feats = hidden_feats
        self.gnn_layers = nn.ModuleList()
        for i in range(n_layers):
            self.gnn_layers.append(GCNLayer(in_feats, hidden_feats[i], gnn_norm[i],
                                            activation[i], residual[i], batchnorm[i],
                                            dropout[i]))
            in_feats = hidden_feats[i]

    def reset_parameters(self):
        """Reinitialize model parameters."""
        for gnn in self.gnn_layers:
            gnn.reset_parameters()

    def forward(self, g, Perturb=None):
        """Update node representations.

        Node features are read (and removed) from ``g.ndata['h']``.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs, with input node features of shape (N, M1)
            stored in ``g.ndata['h']``
            * N is the total number of nodes in the batch of graphs
            * M1 is the input node feature size, which equals in_feats in initialization
        Perturb : FloatTensor of shape (N, M1) or None
            Optional additive perturbation applied to the input node features
            before the first GCN layer. Default to None.

        Returns
        -------
        feats : FloatTensor of shape (N, M2)
            * N is the total number of nodes in the batch of graphs
            * M2 is the output node representation size, which equals
              hidden_feats[-1] in initialization.
        """
        feats = g.ndata.pop('h').float()
        for i, gnn in enumerate(self.gnn_layers):
            if i == 0 and Perturb is not None:
                feats = feats + Perturb
            feats = gnn(g, feats)
        return feats
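

# A minimal usage sketch (illustrative values, not from the source; assumes
# dgl and torch are installed). GCN.forward reads node features from
# g.ndata['h'], so they must be set before the call.
if __name__ == '__main__':
    import dgl
    import torch

    # Two toy graphs batched together, each node carrying a 16-d feature.
    g1 = dgl.graph(([0, 1, 2], [1, 2, 0]), num_nodes=3)
    g2 = dgl.graph(([0, 1], [1, 0]), num_nodes=2)
    bg = dgl.batch([g1, g2])
    bg.ndata['h'] = torch.randn(bg.num_nodes(), 16)

    model = GCN(in_feats=16, hidden_feats=[32, 64], gnn_norm=['both', 'both'])
    out = model(bg)
    print(out.shape)  # torch.Size([5, 64])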