import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import numpy as np
import networkx as nx
from torch_geometric.nn import GCNConv, GATConv, SGConv, APPNP

from CatGCN.gnn_layers import BatchAGC, BatchFiGNN, BatchGAT
from CatGCN.pna_layer import PNAConv
from CatGCN.gcnii_layer import GCNIIConv


class StackedGNN(nn.Module):
    """
    Multi-layer GNN model.

    Combines categorical feature interaction modeling on the user-field
    graph (a global "graph refining" branch and a local "bi-interaction"
    branch, merged by a weighted sum) with a user-user graph layer
    (GCN / GAT / SGC / APPNP / GCNII / PNA / cross variants) that emits
    log-softmax class scores.
    """

    def __init__(self, args, field_count, field_size, output_channels):
        """
        :param args: Arguments object.
        :param field_count: Number of fields.
        :param field_size: Number of sampled fields for each user.
        :param output_channels: Number of target classes.
        """
        super(StackedGNN, self).__init__()
        self.args = args

        # Each unit list is [input_dim, *hidden_dims, output_dim];
        # a spec of 'none' means no hidden layers.
        self.grn_units = self._parse_units(args.grn_units, args.field_dim, output_channels)
        self.nfm_units = self._parse_units(args.nfm_units, args.field_dim, output_channels)

        self.input_channels = args.field_dim
        self.output_channels = output_channels

        # For baseline GNN layers.
        self.gnn_units = self._parse_units(args.gnn_units, self.input_channels, self.output_channels)

        self.field_count = field_count
        self.field_size = field_size

        # Trainable embedding table, one row per categorical field.
        self.field_embedding = nn.Embedding(field_count, args.field_dim)
        self.field_embedding.weight.requires_grad = True

        self._setup_layers()

    @staticmethod
    def _parse_units(spec, input_dim, output_dim):
        """Parse a comma-separated hidden-unit spec into a full unit list.

        :param spec: String such as "64,32", or 'none' for no hidden layers.
        :param input_dim: Dimension prepended to the list.
        :param output_dim: Dimension appended to the list.
        :return: [input_dim, *hidden_dims, output_dim]
        """
        hidden = [int(x) for x in spec.strip().split(",")] if spec != 'none' else []
        return [input_dim] + hidden + [output_dim]

    @staticmethod
    def _linear_stack(units):
        """Build a ModuleList of Linear layers mapping units[i] -> units[i+1]."""
        return nn.ModuleList(
            nn.Linear(units[i], units[i + 1], bias=True)
            for i in range(len(units) - 1))

    def _setup_layers(self):
        """Creating the layers based on the args."""
        # Categorical feature interaction modeling.
        # --- Global interaction modeling ---
        if self.args.graph_refining == 'agc':
            self.grn = BatchAGC(self.args.field_dim, self.args.field_dim)
            self.num_grn_layer = len(self.grn_units) - 1
            self.grn_layer_stack = self._linear_stack(self.grn_units)
        elif self.args.graph_refining == 'gat':
            n_heads = [int(x) for x in self.args.multi_heads.strip().split(",")]
            attn_dropout = 0.  # attn_dropout = self.args.dropout
            self.gat_units = [int(x) for x in self.args.gat_units.strip().split(",")]
            self.num_gat_layer = len(self.gat_units) - 1
            self.gat_layer_stack = nn.ModuleList()
            for i in range(self.num_gat_layer):
                # Hidden GAT layers concatenate their heads, so the fan-in
                # of layer i > 0 is units[i] * heads of the previous layer.
                f_in = self.gat_units[i] * n_heads[i - 1] if i else self.gat_units[i]
                self.gat_layer_stack.append(
                    BatchGAT(n_heads[i], f_in=f_in,
                             f_out=self.gat_units[i + 1],
                             attn_dropout=attn_dropout))
            self.num_grn_layer = len(self.grn_units) - 1
            self.grn_layer_stack = self._linear_stack(self.grn_units)
        elif self.args.graph_refining == 'cosimi':
            self.num_grn_layer = len(self.grn_units) - 1
            self.grn_layer_stack = self._linear_stack(self.grn_units)

        # --- Local interaction modeling ---
        if self.args.bi_interaction == 'nfm':
            self.num_nfm_layer = len(self.nfm_units) - 1
            self.nfm_layer_stack = self._linear_stack(self.nfm_units)

        # GNN layer over the user-user graph.
        if self.args.graph_layer == 'gcn':
            self.gnn_layers = nn.ModuleList(
                GCNConv(self.gnn_units[i], self.gnn_units[i + 1])
                for i in range(len(self.gnn_units) - 1))
        elif self.args.graph_layer == 'gat_1':
            self.gnn_layers = GATConv(
                self.input_channels, self.output_channels, heads=1,
                concat=True, negative_slope=0.2,
                dropout=self.args.dropout, bias=True)
        elif self.args.graph_layer == 'gat_2':
            n_heads = 8
            self.gnn_layers_1 = GATConv(
                self.input_channels, self.gnn_units[1], heads=n_heads,
                concat=True, negative_slope=0.2, dropout=0, bias=True)
            # First layer concatenates n_heads heads, hence the widened fan-in.
            self.gnn_layers_2 = GATConv(
                self.gnn_units[1] * n_heads, self.output_channels, heads=1,
                concat=True, negative_slope=0.2, dropout=0, bias=True)
        elif self.args.graph_layer == 'sgc':
            self.gnn_layers = SGConv(
                self.input_channels, self.output_channels,
                K=self.args.gnn_hops, cached=False)
        elif self.args.graph_layer == 'appnp':
            # MLP feature transform followed by APPNP propagation.
            self.num_mlp_layer = len(self.gnn_units) - 1
            self.mlp_layer_stack = self._linear_stack(self.gnn_units)
            self.gnn_layers = APPNP(K=10, alpha=0.1, bias=True)
        elif self.args.graph_layer == 'cat-appnp':
            self.gnn_layers = APPNP(K=10, alpha=0.1, bias=True)
        elif self.args.graph_layer in ('gcnii_F', 'gcnii_T'):
            # gcnii_T shares weights across the two GCNII branches,
            # gcnii_F keeps them separate; otherwise identical.
            shared = self.args.graph_layer == 'gcnii_T'
            self.num_gnn_layer = self.args.gnn_hops
            self.lin_layer_1 = nn.Linear(self.input_channels, self.gnn_units[1], bias=True)
            self.gnn_layers = nn.ModuleList(
                GCNIIConv(self.gnn_units[1], alpha=self.args.alpha,
                          theta=self.args.theta, layer=layer + 1,
                          shared_weights=shared)
                for layer in range(self.num_gnn_layer))
            self.lin_layer_2 = nn.Linear(self.gnn_units[1], self.output_channels, bias=True)
        elif self.args.graph_layer == 'cross_1':
            self.mlp_layers_1 = nn.Linear(self.input_channels, self.output_channels, bias=False)
            self.mlp_layers_2 = nn.Linear(self.input_channels, self.output_channels, bias=False)
            self.gnn_layers = PNAConv(K=1, cached=False)
        elif self.args.graph_layer == 'cross_2':
            self.mlp_layers_11 = nn.Linear(self.input_channels, self.gnn_units[1], bias=False)
            self.mlp_layers_12 = nn.Linear(self.input_channels, self.gnn_units[1], bias=False)
            self.gnn_layers_1 = PNAConv(K=1, cached=False)
            self.mlp_layers_21 = nn.Linear(self.gnn_units[1], self.output_channels, bias=False)
            self.mlp_layers_22 = nn.Linear(self.gnn_units[1], self.output_channels, bias=False)
            self.gnn_layers_2 = PNAConv(K=1, cached=False)
        elif self.args.graph_layer == 'fignn':
            self.fi_layers = BatchFiGNN(self.input_channels, self.gnn_units[1], self.output_channels)
            self.gnn_layers = PNAConv(K=self.args.gnn_hops, cached=False)
        elif self.args.graph_layer == 'pna':
            self.gnn_layers = PNAConv(K=self.args.gnn_hops, cached=False)

    def _apply_grn_stack(self, user_feature):
        """Run the pooled user feature through the GRN linear stack.

        ReLU + dropout are applied after every layer except the last.
        """
        for i, grn_layer in enumerate(self.grn_layer_stack):
            user_feature = grn_layer(user_feature)
            if i + 1 < self.num_grn_layer:
                user_feature = F.relu(user_feature)
                user_feature = F.dropout(user_feature, self.args.dropout,
                                         training=self.training)
        return user_feature

    def forward(self, edges, field_index, field_adjs):
        """
        Making a forward pass.
        :param edges: Edge list LongTensor.
        :param field_index: User-field index matrix.
        :param field_adjs: Normalized adjacency matrix with probe coefficient.
        :return predictions: Prediction matrix output FloatTensor (log-softmax).
        """
        raw_field_feature = self.field_embedding(field_index)

        # Categorical feature interaction modeling.
        # --- Global interaction modeling ---
        # NOTE(review): user_gnn_feature (and user_feature) are only bound
        # when graph_refining != 'none' and aggr_pooling == 'mean'; other
        # configurations rely on the aggregation fallback below — confirm
        # against the supported argument combinations.
        field_feature = raw_field_feature
        if self.args.graph_refining == 'agc':
            field_feature = self.grn(field_feature, field_adjs.float())
            field_feature = F.relu(field_feature)
            field_feature = F.dropout(field_feature, self.args.dropout,
                                      training=self.training)
            if self.args.aggr_pooling == 'mean':
                user_feature = torch.mean(field_feature, dim=-2)
            user_gnn_feature = self._apply_grn_stack(user_feature)
        elif self.args.graph_refining == 'gat':
            bs, n = field_adjs.size()[:2]
            for i, gat_layer in enumerate(self.gat_layer_stack):
                field_feature = gat_layer(field_feature, field_adjs.byte())
                if i + 1 == self.num_gat_layer:
                    # Last layer: average the attention heads.
                    field_feature = field_feature.mean(dim=1)
                else:
                    # Hidden layer: concatenate heads, then ELU + dropout.
                    field_feature = F.elu(
                        field_feature.transpose(1, 2).contiguous().view(bs, n, -1))
                    field_feature = F.dropout(field_feature, self.args.dropout,
                                              training=self.training)
            if self.args.aggr_pooling == 'mean':
                user_feature = torch.mean(field_feature, dim=-2)
            user_gnn_feature = self._apply_grn_stack(user_feature)
        elif self.args.graph_refining == 'cosimi':
            # Propagate field features weighted by pairwise cosine similarity.
            similarity_mat = torch.bmm(field_feature, field_feature.permute(0, 2, 1))
            feature_norm = torch.sqrt(
                torch.sum(torch.mul(field_feature, field_feature), dim=-1)).unsqueeze(2)
            cosine_distance = torch.div(
                similarity_mat, torch.mul(feature_norm, feature_norm.permute(0, 2, 1)))
            field_feature = torch.bmm(cosine_distance, field_feature)
            if self.args.aggr_pooling == 'mean':
                user_feature = torch.mean(field_feature, dim=-2)
            user_gnn_feature = self._apply_grn_stack(user_feature)

        # --- Local interaction modeling ---
        field_feature = raw_field_feature
        if self.args.bi_interaction == 'nfm':
            # Second-order interaction via the FM identity:
            # sum_i sum_j<i x_i*x_j = 0.5 * ((sum x)^2 - sum x^2).
            summed_field_feature = torch.sum(field_feature, 1)       # sum-square part
            square_summed_field_feature = summed_field_feature ** 2
            squared_field_feature = field_feature ** 2               # square-sum part
            sum_squared_field_feature = torch.sum(squared_field_feature, 1)
            user_feature = 0.5 * (square_summed_field_feature - sum_squared_field_feature)
            # Deep part.
            for i, nfm_layer in enumerate(self.nfm_layer_stack):
                user_feature = nfm_layer(user_feature)
                if i + 1 < self.num_nfm_layer:
                    user_feature = F.relu(user_feature)
                    user_feature = F.dropout(user_feature, self.args.dropout,
                                             training=self.training)
            user_nfm_feature = user_feature

        # Aggregation: convex combination of the global and local branches.
        if self.args.aggr_style == 'sum':
            user_feature = self.args.balance_ratio * user_gnn_feature + \
                (1 - self.args.balance_ratio) * user_nfm_feature
        if self.args.graph_refining == 'none' and self.args.bi_interaction == 'none':
            user_feature = torch.mean(raw_field_feature, dim=-2)

        # GNN layer over the user-user graph.
        if self.args.graph_layer == 'gcn':
            for i in range(len(self.gnn_layers) - 1):
                user_feature = F.relu(self.gnn_layers[i](user_feature, edges))
                if i > 1:
                    user_feature = F.dropout(user_feature, p=self.args.dropout,
                                             training=self.training)
            # BUG FIX: the final layer previously used self.gnn_layers[i+1],
            # which raised NameError when gnn_units had no hidden layer
            # (loop body never ran, so i was unbound). [-1] is the same
            # index whenever the loop runs.
            user_feature = self.gnn_layers[-1](user_feature, edges)
            predictions = F.log_softmax(user_feature, dim=1)
        elif self.args.graph_layer == 'gat_1':
            user_feature = self.gnn_layers(user_feature, edges)
            predictions = F.log_softmax(user_feature, dim=1)
        elif self.args.graph_layer == 'gat_2':
            user_feature = F.elu(self.gnn_layers_1(user_feature, edges))
            user_feature = F.dropout(user_feature, p=self.args.dropout,
                                     training=self.training)
            user_feature = self.gnn_layers_2(user_feature, edges)
            predictions = F.log_softmax(user_feature, dim=1)
        elif self.args.graph_layer == 'sgc':
            user_feature = self.gnn_layers(user_feature, edges)
            predictions = F.log_softmax(user_feature, dim=1)
        elif self.args.graph_layer == 'appnp':
            for i, mlp_layer in enumerate(self.mlp_layer_stack):
                user_feature = mlp_layer(user_feature)
                if i + 1 < self.num_mlp_layer:
                    user_feature = F.relu(user_feature)
                    user_feature = F.dropout(user_feature, self.args.dropout,
                                             training=self.training)
            user_feature = self.gnn_layers(user_feature, edges)
            predictions = F.log_softmax(user_feature, dim=1)
        elif self.args.graph_layer == 'cat-appnp':
            user_feature = self.gnn_layers(user_feature, edges)
            predictions = F.log_softmax(user_feature, dim=1)
        elif self.args.graph_layer in ('gcnii_F', 'gcnii_T'):
            user_feature = self.lin_layer_1(user_feature)
            user_feature = F.relu(user_feature)
            # GCNII needs the initial residual h0 at every layer.
            user_feature = user_feature_0 = F.dropout(
                user_feature, self.args.dropout, training=self.training)
            for i, gnn_layer in enumerate(self.gnn_layers):
                user_feature = gnn_layer(user_feature, user_feature_0, edges)
                if i + 1 < self.num_gnn_layer:
                    user_feature = F.relu(user_feature)
                    user_feature = F.dropout(user_feature, self.args.dropout,
                                             training=self.training)
            user_feature = self.lin_layer_2(user_feature)
            predictions = F.log_softmax(user_feature, dim=1)
        elif self.args.graph_layer == 'cross_1':
            alpha = 1
            user_feature = F.dropout(user_feature, self.args.dropout,
                                     training=self.training)
            x_1 = self.mlp_layers_1(user_feature)
            # BUG FIX: x_2 previously reused mlp_layers_1, so x_1 == x_2 and
            # mlp_layers_2 (created in _setup_layers) was never trained.
            x_2 = self.mlp_layers_2(user_feature)
            x_sec_ord = torch.mul(x_1, x_2) * alpha
            user_feature = x_1 + x_2 + x_sec_ord
            user_feature = self.gnn_layers(user_feature, edges)
            predictions = F.log_softmax(user_feature, dim=1)
        elif self.args.graph_layer == 'cross_2':
            alpha = 1
            user_feature = F.dropout(user_feature, self.args.dropout,
                                     training=self.training)
            x_11 = self.mlp_layers_11(user_feature)
            x_12 = self.mlp_layers_12(user_feature)
            x_sec_ord_1 = torch.mul(x_11, x_12) * alpha
            user_feature = x_11 + x_12 + x_sec_ord_1
            user_feature = self.gnn_layers_1(user_feature, edges)
            user_feature = F.dropout(user_feature, self.args.dropout,
                                     training=self.training)
            x_21 = self.mlp_layers_21(user_feature)
            x_22 = self.mlp_layers_22(user_feature)
            x_sec_ord_2 = torch.mul(x_21, x_22) * alpha
            user_feature = x_21 + x_22 + x_sec_ord_2
            user_feature = self.gnn_layers_2(user_feature, edges)
            predictions = F.log_softmax(user_feature, dim=1)
        elif self.args.graph_layer == 'fignn':
            user_feature = self.fi_layers(raw_field_feature, field_adjs.float(),
                                          self.args.num_steps)
            user_feature = self.gnn_layers(user_feature, edges)
            predictions = F.log_softmax(user_feature, dim=1)
        elif self.args.graph_layer == 'pna':
            user_feature = self.gnn_layers(user_feature, edges)
            predictions = F.log_softmax(user_feature, dim=1)
        elif self.args.graph_layer == 'none':
            predictions = F.log_softmax(user_feature, dim=1)
        return predictions