from torch import cat, nn
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import GINConv, global_add_pool


def _gin_mlp(in_dim: int, dim: int) -> Sequential:
    """Build the ``Linear -> ReLU -> Linear`` MLP wrapped by each GINConv layer."""
    return Sequential(Linear(in_dim, dim), ReLU(), Linear(dim, dim))


class GIN(nn.Module):
    r"""Graph Isomorphism Network graph encoder.

    From `GraphDTA <https://github.com/thinng/GraphDTA>`_ (Nguyen et al., 2020),
    based on `Graph Isomorphism Network <https://arxiv.org/abs/1810.00826>`_
    (Xu et al., 2019).

    Five ``GINConv`` layers, each followed by ReLU and batch normalization,
    produce node embeddings of width ``dim``. These are sum-pooled per graph
    with ``global_add_pool`` and projected to ``out_channels`` by a linear
    layer followed by ReLU and dropout.

    Args:
        num_features: Size of each input node feature vector.
        out_channels: Size of the per-graph output embedding.
        dropout: Dropout probability applied to the output (active only in
            training mode).
        dim: Hidden width of every GIN layer. Defaults to 32, the value
            previously hard-coded here.
    """

    def __init__(
        self,
        num_features: int,
        out_channels: int,
        dropout: float,
        dim: int = 32,
    ):
        super().__init__()
        self.dropout = dropout
        # NOTE: forward() uses F.relu directly and never calls this module; it
        # is kept so the model's public attributes stay unchanged.
        self.relu = nn.ReLU()

        # Layers are stored as individually named attributes (not a ModuleList)
        # so state_dict keys remain compatible with existing checkpoints.
        self.conv1 = GINConv(_gin_mlp(num_features, dim))
        self.bn1 = nn.BatchNorm1d(dim)

        self.conv2 = GINConv(_gin_mlp(dim, dim))
        self.bn2 = nn.BatchNorm1d(dim)

        self.conv3 = GINConv(_gin_mlp(dim, dim))
        self.bn3 = nn.BatchNorm1d(dim)

        self.conv4 = GINConv(_gin_mlp(dim, dim))
        self.bn4 = nn.BatchNorm1d(dim)

        self.conv5 = GINConv(_gin_mlp(dim, dim))
        self.bn5 = nn.BatchNorm1d(dim)

        self.fc1_xd = Linear(dim, out_channels)

    def forward(self, data):
        """Encode a batch of graphs into one embedding per graph.

        Args:
            data: Object exposing ``x`` (node features), ``edge_index`` and
                ``batch`` (node-to-graph assignment). This is presumably a
                torch_geometric ``Batch``; confirm against callers.

        Returns:
            Tensor with one row of width ``out_channels`` per graph in ``batch``.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        layers = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
            (self.conv5, self.bn5),
        )
        for conv, bn in layers:
            # GraphDTA ordering: activation before batch norm.
            x = bn(F.relu(conv(x, edge_index)))

        x = global_add_pool(x, batch)
        x = F.relu(self.fc1_xd(x))
        return F.dropout(x, p=self.dropout, training=self.training)