import torch.nn as nn
import torch.nn.functional as F

from dgl.nn.pytorch import GraphConv

__all__ = ['GCN']

# pylint: disable=W0221, C0103
class GCNLayer(nn.Module):
r"""Single GCN layer from `Semi-Supervised Classification with Graph Convolutional Networks
<https://arxiv.org/abs/1609.02907>`__
Parameters
----------
in_feats : int
Number of input node features.
out_feats : int
Number of output node features.
gnn_norm : str
The message passing normalizer, which can be `'right'`, `'both'` or `'none'`. The
`'right'` normalizer divides the aggregated messages by each node's in-degree.
The `'both'` normalizer corresponds to the symmetric adjacency normalization in
the original GCN paper. The `'none'` normalizer simply sums the messages.
        Defaults to 'none'.
    activation : callable or None
        Activation function applied to the output and, when ``residual`` is True,
        to the residual branch. Defaults to None.
    residual : bool
        Whether to use a residual connection. Defaults to True.
    batchnorm : bool
        Whether to apply batch normalization to the output. Defaults to True.
    dropout : float
        Dropout probability on the output. Defaults to 0., i.e. no dropout
        is performed.
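
    Examples
    --------
    A minimal usage sketch; the toy cycle graph and feature sizes below are
    illustrative only, not part of the original module.

    >>> import dgl
    >>> import torch
    >>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
    >>> layer = GCNLayer(in_feats=8, out_feats=16, activation=F.relu)
    >>> out = layer(g, torch.randn(4, 8))
    >>> out.shape
    torch.Size([4, 16])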
"""
def __init__(self, in_feats, out_feats, gnn_norm='none', activation=None,
residual=True, batchnorm=True, dropout=0.):
super(GCNLayer, self).__init__()
self.activation = activation
        self.graph_conv = GraphConv(in_feats=in_feats, out_feats=out_feats,
                                    norm=gnn_norm, activation=activation,
                                    allow_zero_in_degree=True)
self.dropout = nn.Dropout(dropout)
self.residual = residual
if residual:
self.res_connection = nn.Linear(in_feats, out_feats)
self.bn = batchnorm
if batchnorm:
self.bn_layer = nn.BatchNorm1d(out_feats)

    def reset_parameters(self):
"""Reinitialize model parameters."""
self.graph_conv.reset_parameters()
if self.residual:
self.res_connection.reset_parameters()
if self.bn:
self.bn_layer.reset_parameters()

    def forward(self, g, feats):
        """Update node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        feats : FloatTensor of shape (N, M1)
            * N is the total number of nodes in the batch of graphs
            * M1 is the input node feature size, which must match in_feats in initialization

        Returns
        -------
        new_feats : FloatTensor of shape (N, M2)
            * M2 is the output node feature size, which must match out_feats in initialization
        """
new_feats = self.graph_conv(g, feats)
        if self.residual:
            # Project the input features and add them as a residual connection.
            # Guard against activation being None, which would otherwise crash here.
            res_feats = self.res_connection(feats)
            if self.activation is not None:
                res_feats = self.activation(res_feats)
            new_feats = new_feats + res_feats
new_feats = self.dropout(new_feats)
if self.bn:
new_feats = self.bn_layer(new_feats)
return new_feats


class GCN(nn.Module):
    r"""GCN from `Semi-Supervised Classification with Graph Convolutional Networks
    <https://arxiv.org/abs/1609.02907>`__

Parameters
----------
in_feats : int
Number of input node features.
hidden_feats : list of int
``hidden_feats[i]`` gives the size of node representations after the i-th GCN layer.
``len(hidden_feats)`` equals the number of GCN layers. By default, we use
``[64, 64]``.
gnn_norm : list of str
``gnn_norm[i]`` gives the message passing normalizer for the i-th GCN layer, which
can be `'right'`, `'both'` or `'none'`. The `'right'` normalizer divides the aggregated
messages by each node's in-degree. The `'both'` normalizer corresponds to the symmetric
adjacency normalization in the original GCN paper. The `'none'` normalizer simply sums
the messages. ``len(gnn_norm)`` equals the number of GCN layers. By default, we use
``['none', 'none']``.
activation : list of activation functions or None
If not None, ``activation[i]`` gives the activation function to be used for
the i-th GCN layer. ``len(activation)`` equals the number of GCN layers.
By default, ReLU is applied for all GCN layers.
residual : list of bool
``residual[i]`` decides if residual connection is to be used for the i-th GCN layer.
``len(residual)`` equals the number of GCN layers. By default, residual connection
is performed for each GCN layer.
batchnorm : list of bool
``batchnorm[i]`` decides if batch normalization is to be applied on the output of
the i-th GCN layer. ``len(batchnorm)`` equals the number of GCN layers. By default,
batch normalization is applied for all GCN layers.
dropout : list of float
``dropout[i]`` decides the dropout probability on the output of the i-th GCN layer.
``len(dropout)`` equals the number of GCN layers. By default, no dropout is
performed for all layers.
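
    Examples
    --------
    A minimal usage sketch; the toy graph below is illustrative only. Note
    that ``forward`` reads the input node features from ``g.ndata['h']``.

    >>> import dgl
    >>> import torch
    >>> g = dgl.graph(([0, 1, 2], [1, 2, 0]))
    >>> g.ndata['h'] = torch.randn(3, 8)
    >>> model = GCN(in_feats=8, hidden_feats=[16, 16])
    >>> feats = model(g)
    >>> feats.shape
    torch.Size([3, 16])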
"""
def __init__(self, in_feats, hidden_feats=None, gnn_norm=None, activation=None,
residual=None, batchnorm=None, dropout=None):
super(GCN, self).__init__()
if hidden_feats is None:
hidden_feats = [64, 64]
n_layers = len(hidden_feats)
if gnn_norm is None:
gnn_norm = ['none' for _ in range(n_layers)]
if activation is None:
activation = [F.relu for _ in range(n_layers)]
if residual is None:
residual = [True for _ in range(n_layers)]
if batchnorm is None:
batchnorm = [True for _ in range(n_layers)]
if dropout is None:
dropout = [0. for _ in range(n_layers)]
lengths = [len(hidden_feats), len(gnn_norm), len(activation),
len(residual), len(batchnorm), len(dropout)]
assert len(set(lengths)) == 1, 'Expect the lengths of hidden_feats, gnn_norm, ' \
'activation, residual, batchnorm and dropout to ' \
'be the same, got {}'.format(lengths)
self.hidden_feats = hidden_feats
self.gnn_layers = nn.ModuleList()
for i in range(n_layers):
self.gnn_layers.append(GCNLayer(in_feats, hidden_feats[i], gnn_norm[i], activation[i],
residual[i], batchnorm[i], dropout[i]))
in_feats = hidden_feats[i]

    def reset_parameters(self):
"""Reinitialize model parameters."""
for gnn in self.gnn_layers:
gnn.reset_parameters()

    def forward(self, g, perturb=None):
        """Update node representations.

        The input node features are read (and removed) from ``g.ndata['h']``,
        so they must be set on the graph before calling this method.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs, with input node features of shape
            (N, M1) stored in ``g.ndata['h']``
            * N is the total number of nodes in the batch of graphs
            * M1 is the input node feature size, which equals in_feats in initialization
        perturb : FloatTensor of shape (N, M1) or None
            Optional perturbation added to the input node features before the
            first GCN layer. Defaults to None.

        Returns
        -------
        feats : FloatTensor of shape (N, M2)
            * N is the total number of nodes in the batch of graphs
            * M2 is the output node representation size, which equals
              hidden_feats[-1] in initialization.
        """
        feats = g.ndata.pop('h').float()
        # Optionally perturb the input features before the first GCN layer.
        if perturb is not None:
            feats = feats + perturb
        for gnn in self.gnn_layers:
            feats = gnn(g, feats)
        return feats
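

if __name__ == '__main__':
    # Minimal smoke test; the random graphs, feature size and perturbation
    # scale below are illustrative assumptions, not part of the module.
    import dgl
    import torch

    g = dgl.batch([dgl.rand_graph(5, 10), dgl.rand_graph(6, 12)])
    g.ndata['h'] = torch.randn(g.num_nodes(), 8)
    model = GCN(in_feats=8, hidden_feats=[16, 16])

    # The optional perturbation is added to the input node features before the
    # first layer, e.g. for FLAG-style adversarial feature augmentation.
    perturb = 0.01 * torch.randn(g.num_nodes(), 8)
    out = model(g, perturb=perturb)
    print(out.shape)  # torch.Size([11, 16])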