# NOTE(review): the lines "Spaces: / Sleeping / Sleeping" were web-scrape
# artifacts (a Hugging Face Spaces page header), not code; converted to this
# comment so the module parses.
import torch
import torch.nn as nn
from torch_geometric.nn.aggr import (
    AttentionalAggregation,
    GraphMultisetTransformer,
    MaxAggregation,
    MeanAggregation,
    SetTransformerAggregation,
)
class CatAggregation(nn.Module):
    """Concatenation "aggregation": flattens the node axis into the feature axis.

    Unlike the torch_geometric aggregators used alongside it, this module
    ignores ``index`` entirely — it simply merges dims 1 and 2, turning a
    dense ``(batch, num_nodes, dim)`` input into ``(batch, num_nodes * dim)``.
    It therefore only makes sense when every sample contributes the same
    number of nodes (see the ``cat``/``last`` note in
    ``HeterogeneousAggregator.forward``).
    """

    def __init__(self):
        super().__init__()
        # Flatten dims 1..2 (node and feature axes); dim 0 (batch) is kept.
        self.flatten = nn.Flatten(1, 2)

    def forward(self, x, index=None):
        # ``index`` is accepted only for API compatibility with the other
        # aggregators; it is intentionally unused.
        return self.flatten(x)
class HeterogeneousAggregator(nn.Module):
    """Pool per-node features of padded, heterogeneously-sized graphs.

    ``forward`` expects a dense node-feature tensor ``x`` of shape
    ``(batch, num_nodes, dim)`` (padded across the batch), plus a per-sample
    ``layer_layouts`` list giving the number of nodes in each layer of every
    sample's network. A layer-selection strategy (``pooling_layer_idx``)
    picks which nodes participate, and a pooling module (``pooling_method``)
    reduces them to one vector per sample.

    Args:
        input_dim: Feature dimension of the incoming node features.
        hidden_dim: Hidden width of the gate/value MLPs for
            ``attentional_aggregation``.
        output_dim: Output width of the value MLP for
            ``attentional_aggregation``.
        pooling_method: One of ``"mean"``, ``"max"``, ``"cat"``,
            ``"attentional_aggregation"``, ``"set_transformer"``,
            ``"graph_multiset_transformer"``.
        pooling_layer_idx: ``"all"``, ``"last"``, or an int layer index.
        input_channels: Stored for callers; not used in this class's methods.
        num_classes: Number of output-layer nodes selected by the
            ``"last"`` strategy (presumably the classifier head width of the
            encoded networks — TODO confirm with callers).

    Raises:
        ValueError: If ``pooling_layer_idx`` or ``pooling_method`` is unknown.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        output_dim,
        pooling_method,
        pooling_layer_idx,
        input_channels,
        num_classes,
    ):
        super().__init__()
        self.pooling_method = pooling_method
        self.pooling_layer_idx = pooling_layer_idx
        self.input_channels = input_channels
        self.num_classes = num_classes

        # Resolve the node-selection strategy once, up front.
        if pooling_layer_idx == "all":
            self._pool_layer_idx_fn = self.get_all_layer_indices
        elif pooling_layer_idx == "last":
            self._pool_layer_idx_fn = self.get_last_layer_indices
        elif isinstance(pooling_layer_idx, int):
            self._pool_layer_idx_fn = self.get_nth_layer_indices
        else:
            raise ValueError(f"Unknown pooling layer index {pooling_layer_idx}")

        # Resolve the pooling operator.
        if pooling_method == "mean":
            self.pool = MeanAggregation()
        elif pooling_method == "max":
            self.pool = MaxAggregation()
        elif pooling_method == "cat":
            self.pool = CatAggregation()
        elif pooling_method == "attentional_aggregation":
            self.pool = AttentionalAggregation(
                gate_nn=nn.Sequential(
                    nn.Linear(input_dim, hidden_dim),
                    nn.SiLU(),
                    nn.Linear(hidden_dim, 1),
                ),
                nn=nn.Sequential(
                    nn.Linear(input_dim, hidden_dim),
                    nn.SiLU(),
                    nn.Linear(hidden_dim, output_dim),
                ),
            )
        elif pooling_method == "set_transformer":
            self.pool = SetTransformerAggregation(
                input_dim, heads=8, num_encoder_blocks=4, num_decoder_blocks=4
            )
        elif pooling_method == "graph_multiset_transformer":
            self.pool = GraphMultisetTransformer(input_dim, k=8, heads=8)
        else:
            raise ValueError(f"Unknown pooling method {pooling_method}")

    def get_last_layer_indices(
        self, x, layer_layouts, node_mask=None, return_dense=False
    ):
        """Select the last ``num_classes`` valid node positions per sample.

        Returns ``(batch_indices, node_indices)``: flat index pairs when
        ``return_dense`` is False, or broadcastable ``(batch, 1)`` /
        ``(batch, num_classes)`` index tensors when True (used by ``cat``).
        """
        batch_size = x.shape[0]
        device = x.device
        # NOTE: node_mask needs to exist in the heterogeneous case only;
        # without one, every padded position counts as valid.
        if node_mask is None:
            node_mask = torch.ones_like(x[..., 0], dtype=torch.bool, device=device)
        # Zero out invalid positions so topk picks the largest *valid* indices.
        valid_layer_indices = (
            torch.arange(node_mask.shape[1], device=device)[None, :] * node_mask
        )
        # topk returns descending order; fliplr restores ascending node order.
        last_layer_indices = valid_layer_indices.topk(
            k=self.num_classes, dim=1
        ).values.fliplr()
        if return_dense:
            return torch.arange(batch_size, device=device)[:, None], last_layer_indices
        batch_indices = torch.arange(batch_size, device=device).repeat_interleave(
            self.num_classes
        )
        return batch_indices, last_layer_indices.flatten()

    def get_nth_layer_indices(
        self, x, layer_layouts, node_mask=None, return_dense=False
    ):
        """Select every node of layer ``self.pooling_layer_idx`` per sample.

        Layer boundaries come from the cumulative sum of each sample's
        ``layer_layout``; samples may contribute different node counts.
        """
        batch_size = x.shape[0]
        device = x.device
        # Per-sample cumulative node offsets, e.g. [2, 3, 1] -> [0, 2, 5, 6].
        cum_layer_layout = [
            torch.cumsum(torch.tensor([0] + layer_layout), dim=0)
            for layer_layout in layer_layouts
        ]
        # How many nodes the selected layer has in each sample.
        layer_sizes = torch.tensor(
            [layer_layout[self.pooling_layer_idx] for layer_layout in layer_layouts],
            dtype=torch.long,
            device=device,
        )
        batch_indices = torch.arange(batch_size, device=device).repeat_interleave(
            layer_sizes
        )
        layer_indices = torch.cat(
            [
                torch.arange(
                    layout[self.pooling_layer_idx],
                    layout[self.pooling_layer_idx + 1],
                    device=device,
                )
                for layout in cum_layer_layout
            ]
        )
        return batch_indices, layer_indices

    def get_all_layer_indices(
        self, x, layer_layouts, node_mask=None, return_dense=False
    ):
        """Imitate flattening with indexing: select every node position."""
        batch_size, num_nodes = x.shape[:2]
        device = x.device
        batch_indices = torch.arange(batch_size, device=device).repeat_interleave(
            num_nodes
        )
        layer_indices = torch.arange(num_nodes, device=device).repeat(batch_size)
        return batch_indices, layer_indices

    def forward(self, x, layer_layouts, node_mask=None):
        """Select nodes per the configured strategy, then pool per sample."""
        # NOTE: `cat` only works with `pooling_layer_idx == "last"` — it needs
        # a fixed node count per sample, which dense "last" indexing provides.
        return_dense = self.pooling_method == "cat" and self.pooling_layer_idx == "last"
        batch_indices, layer_indices = self._pool_layer_idx_fn(
            x, layer_layouts, node_mask=node_mask, return_dense=return_dense
        )
        # Flat: (total_selected, dim); dense ("cat"): (batch, num_classes, dim).
        flat_x = x[batch_indices, layer_indices]
        return self.pool(flat_x, index=batch_indices)
class HomogeneousAggregator(nn.Module):
    """Pool node or edge features of fixed-layout (homogeneous) networks.

    All samples share one ``layer_layout`` (nodes per layer), so pooling is
    plain dense tensor reduction — no index-based scatter is needed.

    Args:
        pooling_method: ``"mean"``, ``"max"``, ``"cat"``, ``"cat_mean"``,
            ``"mean_edge"``, or ``"max_edge"``.
        pooling_layer_idx: ``"all"``, ``"last"``, or an int layer index.
        layer_layout: Sequence of per-layer node counts shared by the batch.
    """

    def __init__(
        self,
        pooling_method,
        pooling_layer_idx,
        layer_layout,
    ):
        super().__init__()
        self.pooling_method = pooling_method
        self.pooling_layer_idx = pooling_layer_idx
        self.layer_layout = layer_layout
        # Cumulative node offsets, e.g. [2, 3] -> [0, 2, 5].
        # BUG FIX: forward() referenced `self.layer_idx`, which was never
        # defined, so the int-index and "cat_mean" branches raised
        # AttributeError. Define it here.
        self.layer_idx = [0]
        for size in layer_layout:
            self.layer_idx.append(self.layer_idx[-1] + size)

    def forward(self, node_features, edge_features):
        """Reduce dense features to one vector per sample.

        Args:
            node_features: ``(batch, num_nodes, dim)`` tensor
                (nodes ordered layer by layer per ``layer_layout``).
            edge_features: ``(batch, num_nodes, num_nodes, dim)`` tensor;
                only read by the ``*_edge`` methods — TODO confirm layout
                with callers.

        Raises:
            ValueError: If the (method, layer_idx) combination is unsupported.
                (Previously this fell through and raised UnboundLocalError.)
        """
        if self.pooling_method == "mean" and self.pooling_layer_idx == "all":
            graph_features = node_features.mean(dim=1)
        elif self.pooling_method == "max" and self.pooling_layer_idx == "all":
            graph_features = node_features.max(dim=1).values
        elif self.pooling_method == "mean" and self.pooling_layer_idx == "last":
            graph_features = node_features[:, -self.layer_layout[-1] :].mean(dim=1)
        elif self.pooling_method == "cat" and self.pooling_layer_idx == "last":
            graph_features = node_features[:, -self.layer_layout[-1] :].flatten(1, 2)
        elif self.pooling_method == "mean" and isinstance(self.pooling_layer_idx, int):
            graph_features = node_features[
                :,
                self.layer_idx[self.pooling_layer_idx] : self.layer_idx[
                    self.pooling_layer_idx + 1
                ],
            ].mean(dim=1)
        elif self.pooling_method == "cat_mean" and self.pooling_layer_idx == "all":
            # Mean within each layer, then concatenate the per-layer means.
            graph_features = torch.cat(
                [
                    node_features[:, self.layer_idx[i] : self.layer_idx[i + 1]].mean(
                        dim=1
                    )
                    for i in range(len(self.layer_layout))
                ],
                dim=1,
            )
        elif self.pooling_method == "mean_edge" and self.pooling_layer_idx == "all":
            graph_features = edge_features.mean(dim=(1, 2))
        elif self.pooling_method == "max_edge" and self.pooling_layer_idx == "all":
            graph_features = edge_features.flatten(1, 2).max(dim=1).values
        elif self.pooling_method == "mean_edge" and self.pooling_layer_idx == "last":
            graph_features = edge_features[:, :, -self.layer_layout[-1] :].mean(
                dim=(1, 2)
            )
        else:
            raise ValueError(
                f"Unsupported pooling combination: method={self.pooling_method!r}, "
                f"layer_idx={self.pooling_layer_idx!r}"
            )
        return graph_features