import torch
import torch.nn as nn
from torch_geometric.nn.aggr import (
    AttentionalAggregation,
    GraphMultisetTransformer,
    MaxAggregation,
    MeanAggregation,
    SetTransformerAggregation,
)


class CatAggregation(nn.Module):
    """Concatenate node features along the node dimension instead of reducing them."""

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten(1, 2)

    def forward(self, x, index=None):
        return self.flatten(x)


class HeterogeneousAggregator(nn.Module):
    def __init__(
        self,
        input_dim,
        hidden_dim,
        output_dim,
        pooling_method,
        pooling_layer_idx,
        input_channels,
        num_classes,
    ):
        super().__init__()
        self.pooling_method = pooling_method
        self.pooling_layer_idx = pooling_layer_idx
        self.input_channels = input_channels
        self.num_classes = num_classes

        # Select which node indices to pool over: every layer, the last layer,
        # or a specific layer given by an integer index.
        if pooling_layer_idx == "all":
            self._pool_layer_idx_fn = self.get_all_layer_indices
        elif pooling_layer_idx == "last":
            self._pool_layer_idx_fn = self.get_last_layer_indices
        elif isinstance(pooling_layer_idx, int):
            self._pool_layer_idx_fn = self.get_nth_layer_indices
        else:
            raise ValueError(f"Unknown pooling layer index {pooling_layer_idx}")

        if pooling_method == "mean":
            self.pool = MeanAggregation()
        elif pooling_method == "max":
            self.pool = MaxAggregation()
        elif pooling_method == "cat":
            self.pool = CatAggregation()
        elif pooling_method == "attentional_aggregation":
            self.pool = AttentionalAggregation(
                gate_nn=nn.Sequential(
                    nn.Linear(input_dim, hidden_dim),
                    nn.SiLU(),
                    nn.Linear(hidden_dim, 1),
                ),
                nn=nn.Sequential(
                    nn.Linear(input_dim, hidden_dim),
                    nn.SiLU(),
                    nn.Linear(hidden_dim, output_dim),
                ),
            )
        elif pooling_method == "set_transformer":
            self.pool = SetTransformerAggregation(
                input_dim, heads=8, num_encoder_blocks=4, num_decoder_blocks=4
            )
        elif pooling_method == "graph_multiset_transformer":
            self.pool = GraphMultisetTransformer(input_dim, k=8, heads=8)
        else:
            raise ValueError(f"Unknown pooling method {pooling_method}")

    def get_last_layer_indices(
        self, x, layer_layouts, node_mask=None, return_dense=False
    ):
        batch_size = x.shape[0]
        device = x.device
        # NOTE: node_mask needs to exist in the heterogeneous case only
        if node_mask is None:
            node_mask = torch.ones_like(x[..., 0], dtype=torch.bool, device=device)

        valid_layer_indices = (
            torch.arange(node_mask.shape[1], device=device)[None, :] * node_mask
        )
        # The last `num_classes` valid node indices per graph correspond to the
        # output-layer nodes.
        last_layer_indices = valid_layer_indices.topk(
            k=self.num_classes, dim=1
        ).values.fliplr()
        if return_dense:
            return torch.arange(batch_size, device=device)[:, None], last_layer_indices

        batch_indices = torch.arange(batch_size, device=device).repeat_interleave(
            self.num_classes
        )
        return batch_indices, last_layer_indices.flatten()

    def get_nth_layer_indices(
        self, x, layer_layouts, node_mask=None, return_dense=False
    ):
        batch_size = x.shape[0]
        device = x.device
        cum_layer_layout = [
            torch.cumsum(torch.tensor([0] + layer_layout), dim=0)
            for layer_layout in layer_layouts
        ]
        layer_sizes = torch.tensor(
            [layer_layout[self.pooling_layer_idx] for layer_layout in layer_layouts],
            dtype=torch.long,
            device=device,
        )

        batch_indices = torch.arange(batch_size, device=device).repeat_interleave(
            layer_sizes
        )
        layer_indices = torch.cat(
            [
                torch.arange(
                    layout[self.pooling_layer_idx],
                    layout[self.pooling_layer_idx + 1],
                    device=device,
                )
                for layout in cum_layer_layout
            ]
        )
        return batch_indices, layer_indices

    def get_all_layer_indices(
        self, x, layer_layouts, node_mask=None, return_dense=False
    ):
        """Imitate flattening with indexing"""
        batch_size, num_nodes = x.shape[:2]
        device = x.device
        batch_indices = torch.arange(batch_size, device=device).repeat_interleave(
            num_nodes
        )
        layer_indices = torch.arange(num_nodes, device=device).repeat(batch_size)
        return batch_indices, layer_indices
    def forward(self, x, layer_layouts, node_mask=None):
        # NOTE: `cat` only works with `pooling_layer_idx == "last"`
        return_dense = self.pooling_method == "cat" and self.pooling_layer_idx == "last"
        batch_indices, layer_indices = self._pool_layer_idx_fn(
            x, layer_layouts, node_mask=node_mask, return_dense=return_dense
        )
        flat_x = x[batch_indices, layer_indices]
        return self.pool(flat_x, index=batch_indices)


class HomogeneousAggregator(nn.Module):
    def __init__(
        self,
        pooling_method,
        pooling_layer_idx,
        layer_layout,
    ):
        super().__init__()
        self.pooling_method = pooling_method
        self.pooling_layer_idx = pooling_layer_idx
        self.layer_layout = layer_layout
        # Cumulative node offsets per layer (e.g. [5, 10, 5] -> [0, 5, 15, 20]),
        # used below to slice individual layers out of the node dimension.
        self.layer_idx = torch.cumsum(torch.tensor([0] + layer_layout), dim=0).tolist()

    def forward(self, node_features, edge_features):
        if self.pooling_method == "mean" and self.pooling_layer_idx == "all":
            graph_features = node_features.mean(dim=1)
        elif self.pooling_method == "max" and self.pooling_layer_idx == "all":
            graph_features = node_features.max(dim=1).values
        elif self.pooling_method == "mean" and self.pooling_layer_idx == "last":
            graph_features = node_features[:, -self.layer_layout[-1] :].mean(dim=1)
        elif self.pooling_method == "cat" and self.pooling_layer_idx == "last":
            graph_features = node_features[:, -self.layer_layout[-1] :].flatten(1, 2)
        elif self.pooling_method == "mean" and isinstance(self.pooling_layer_idx, int):
            graph_features = node_features[
                :,
                self.layer_idx[self.pooling_layer_idx] : self.layer_idx[
                    self.pooling_layer_idx + 1
                ],
            ].mean(dim=1)
        elif self.pooling_method == "cat_mean" and self.pooling_layer_idx == "all":
            graph_features = torch.cat(
                [
                    node_features[:, self.layer_idx[i] : self.layer_idx[i + 1]].mean(
                        dim=1
                    )
                    for i in range(len(self.layer_layout))
                ],
                dim=1,
            )
        elif self.pooling_method == "mean_edge" and self.pooling_layer_idx == "all":
            graph_features = edge_features.mean(dim=(1, 2))
        elif self.pooling_method == "max_edge" and self.pooling_layer_idx == "all":
            graph_features = edge_features.flatten(1, 2).max(dim=1).values
        elif self.pooling_method == "mean_edge" and self.pooling_layer_idx == "last":
            graph_features = edge_features[:, :, -self.layer_layout[-1] :].mean(
                dim=(1, 2)
            )
        else:
            raise ValueError(
                f"Unknown pooling combination: method={self.pooling_method}, "
                f"layer_idx={self.pooling_layer_idx}"
            )
        return graph_features
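

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the classes above).
# All shapes and hyper-parameters here are assumptions chosen for the example:
# a batch of 4 input networks, each with a [5, 10, 5] layer layout (20 nodes
# total), 32-dim node features, and 5 output classes.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, feat_dim, num_classes = 4, 32, 5
    layer_layout = [5, 10, 5]
    num_nodes = sum(layer_layout)

    aggregator = HeterogeneousAggregator(
        input_dim=feat_dim,
        hidden_dim=64,
        output_dim=feat_dim,
        pooling_method="mean",
        pooling_layer_idx="last",
        input_channels=3,
        num_classes=num_classes,
    )

    x = torch.randn(batch_size, num_nodes, feat_dim)
    layer_layouts = [layer_layout] * batch_size

    # With `pooling_method="mean"` and `pooling_layer_idx="last"`, the forward
    # pass gathers the last `num_classes` valid nodes of each graph and
    # averages their features per graph.
    out = aggregator(x, layer_layouts)
    print(out.shape)  # expected: torch.Size([4, 32])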