import torch
from torch import nn
from einops import rearrange
from typing import List
import torch_geometric
from torch_geometric.nn import MessagePassing
from torch_geometric.typing import Adj, Size, OptTensor, Tensor

from .egnn_pytorch import *
# global linear attention

class Attention_Sparse(Attention):
    def __init__(self, dim, heads = 8, dim_head = 64):
        """ Wraps the attention class to operate with pytorch-geometric inputs. """
        super(Attention_Sparse, self).__init__(dim, heads = heads, dim_head = dim_head)
    def sparse_forward(self, x, context, batch=None, batch_uniques=None, mask=None):
        assert batch is not None or batch_uniques is not None, "Batch/(uniques) must be passed for block_sparse_attn"
        if batch_uniques is None:
            batch_uniques = torch.unique(batch, return_counts=True)
        # only one example in batch - do dense - faster
        if batch_uniques[0].shape[0] == 1:
            x, context = map(lambda t: rearrange(t, 'h d -> () h d'), (x, context))
            return self.forward(x, context, mask=None).squeeze()  # get rid of batch dim
        # multiple examples in batch - do block-sparse by dense loop
        else:
            x_list = []
            aux_count = 0
            for bi, n_idxs in zip(*batch_uniques):
                x_list.append(
                    self.sparse_forward(
                        x[aux_count:aux_count + n_idxs],
                        context[aux_count:aux_count + n_idxs],
                        batch_uniques = (bi.unsqueeze(-1), n_idxs.unsqueeze(-1))
                    )
                )
                # advance the offset to the next block of nodes
                aux_count += n_idxs
            return torch.cat(x_list, dim=0)
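
# Illustrative sketch (not part of the original module): how `sparse_forward`
# is expected to be driven with flat pytorch-geometric tensors and a `batch`
# vector. All sizes and the `dim` value below are arbitrary assumptions.
def _example_attention_sparse():
    attn = Attention_Sparse(dim = 32, heads = 4, dim_head = 16)
    feats = torch.randn(7, 32)                    # 7 nodes spread over 2 graphs
    batch = torch.tensor([0, 0, 0, 1, 1, 1, 1])   # graph id per node
    # attention is computed block-wise, one dense block per graph in the batch
    out = attn.sparse_forward(feats, feats, batch = batch)
    return out.shape                              # expected: (7, 32)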
class GlobalLinearAttention_Sparse(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64
    ):
        super().__init__()
        self.norm_seq = torch_geometric.nn.norm.LayerNorm(dim)
        self.norm_queries = torch_geometric.nn.norm.LayerNorm(dim)
        self.attn1 = Attention_Sparse(dim, heads, dim_head)
        self.attn2 = Attention_Sparse(dim, heads, dim_head)

        # can't concat pyg norms with torch sequentials
        self.ff_norm = torch_geometric.nn.norm.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, x, queries, batch=None, batch_uniques=None, mask = None):
        res_x, res_queries = x, queries
        x, queries = self.norm_seq(x, batch=batch), self.norm_queries(queries, batch=batch)

        induced = self.attn1.sparse_forward(queries, x, batch=batch, batch_uniques=batch_uniques, mask = mask)
        out = self.attn2.sparse_forward(x, induced, batch=batch, batch_uniques=batch_uniques)

        x = out + res_x
        queries = induced + res_queries

        x_norm = self.ff_norm(x, batch=batch)
        x = self.ff(x_norm) + x_norm
        return x, queries
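
# Illustrative sketch (an assumption, not prescribed by the original code): the
# block returns updated (feats, queries). In the network below it is driven with
# the node features used as both sequence and queries; sizes are arbitrary.
def _example_global_linear_attention_sparse():
    attn_block = GlobalLinearAttention_Sparse(dim = 32, heads = 4, dim_head = 16)
    feats = torch.randn(7, 32)
    batch = torch.tensor([0, 0, 0, 1, 1, 1, 1])
    feats_out, queries_out = attn_block(feats, feats, batch = batch)
    return feats_out.shape, queries_out.shape     # expected: both (7, 32)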
# define pytorch-geometric equivalents

class EGNN_Sparse(MessagePassing):
    """ Different from the dense version in that it separates the edge assignment
        from the computation (this allows for a great reduction in time and
        computation when the graph is only locally or sparsely connected).
        * aggr: one of ["add", "sum", "mean", "max"]
    """
    def __init__(
        self,
        feats_dim,
        pos_dim = 3,
        edge_attr_dim = 0,
        m_dim = 16,
        fourier_features = 0,
        soft_edge = 0,
        norm_feats = False,
        norm_coors = False,
        norm_coors_scale_init = 1e-2,
        update_feats = True,
        update_coors = False,
        dropout = 0.,
        coor_weights_clamp_value = None,
        aggr = "add",
        mlp_num = 2,
        **kwargs
    ):
        assert aggr in {'add', 'sum', 'max', 'mean'}, 'pool method must be a valid option'
        assert update_feats or update_coors, 'you must update either features, coordinates, or both'
        kwargs.setdefault('aggr', aggr)
        super(EGNN_Sparse, self).__init__(**kwargs)
        # model params
        self.fourier_features = fourier_features
        self.feats_dim = feats_dim
        self.pos_dim = pos_dim
        self.m_dim = m_dim
        self.soft_edge = soft_edge
        self.norm_feats = norm_feats
        self.norm_coors = norm_coors
        self.update_coors = update_coors
        self.update_feats = update_feats
        self.coor_weights_clamp_value = coor_weights_clamp_value
        self.mlp_num = mlp_num

        self.edge_input_dim = (fourier_features * 2) + edge_attr_dim + 1 + (feats_dim * 2)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        # EDGES
        # note: the edge MLP is always needed - `message()` uses it even when only
        # coordinates are updated - so it is not gated on `update_feats`.
        if self.mlp_num > 2:
            self.edge_mlp = nn.Sequential(
                nn.Linear(self.edge_input_dim, self.edge_input_dim * 8),
                self.dropout,
                SiLU(),
                nn.Linear(self.edge_input_dim * 8, self.edge_input_dim * 4),
                self.dropout,
                SiLU(),
                nn.Linear(self.edge_input_dim * 4, self.edge_input_dim * 2),
                self.dropout,
                SiLU(),
                nn.Linear(self.edge_input_dim * 2, m_dim),
                SiLU(),
            )
        else:
            self.edge_mlp = nn.Sequential(
                nn.Linear(self.edge_input_dim, self.edge_input_dim * 2),
                self.dropout,
                SiLU(),
                nn.Linear(self.edge_input_dim * 2, m_dim),
                SiLU()
            )

        self.edge_weight = nn.Sequential(nn.Linear(m_dim, 1),
                                         nn.Sigmoid()
        ) if soft_edge else None
        # NODES - can't do identity in node_norm bc pyg expects 2 inputs, but identity expects 1.
        self.node_norm = torch_geometric.nn.norm.LayerNorm(feats_dim) if norm_feats else None
        self.coors_norm = CoorsNorm(scale_init = norm_coors_scale_init) if norm_coors else nn.Identity()

        if self.mlp_num > 2:
            self.node_mlp = nn.Sequential(
                nn.Linear(feats_dim + m_dim, feats_dim * 8),
                self.dropout,
                SiLU(),
                nn.Linear(feats_dim * 8, feats_dim * 4),
                self.dropout,
                SiLU(),
                nn.Linear(feats_dim * 4, feats_dim * 2),
                self.dropout,
                SiLU(),
                nn.Linear(feats_dim * 2, feats_dim),
            ) if update_feats else None
        else:
            self.node_mlp = nn.Sequential(
                nn.Linear(feats_dim + m_dim, feats_dim * 2),
                self.dropout,
                SiLU(),
                nn.Linear(feats_dim * 2, feats_dim),
            ) if update_feats else None
        # COORS
        self.coors_mlp = nn.Sequential(
            nn.Linear(m_dim, m_dim * 4),
            self.dropout,
            SiLU(),
            nn.Linear(self.m_dim * 4, 1)
        ) if update_coors else None

        self.apply(self.init_)

    def init_(self, module):
        if type(module) in {nn.Linear}:
            # seems to be needed to keep the network from exploding to NaN with greater depths
            nn.init.xavier_normal_(module.weight)
            nn.init.zeros_(module.bias)
    def forward(self, x: Tensor, edge_index: Adj,
                edge_attr: OptTensor = None, batch: Adj = None,
                angle_data: List = None, size: Size = None) -> Tensor:
        """ Inputs:
            * x: (n_points, d) where d is pos_dim + feats_dim
            * edge_index: (2, n_edges)
            * edge_attr: tensor (n_edges, n_feats) excluding basic distance feats.
            * batch: (n_points,) long tensor. specifies the point cloud each point belongs to.
            * angle_data: list of tensors (levels, n_edges_i, n_length_path) long tensor.
            * size: None
        """
        coors, feats = x[:, :self.pos_dim], x[:, self.pos_dim:]
        rel_coors = coors[edge_index[0]] - coors[edge_index[1]]
        rel_dist = (rel_coors ** 2).sum(dim=-1, keepdim=True)

        if self.fourier_features > 0:
            rel_dist = fourier_encode_dist(rel_dist, num_encodings = self.fourier_features)
            rel_dist = rearrange(rel_dist, 'n () d -> n d')

        if exists(edge_attr):
            edge_attr_feats = torch.cat([edge_attr, rel_dist], dim=-1)
        else:
            edge_attr_feats = rel_dist

        hidden_out, coors_out = self.propagate(edge_index, x=feats, edge_attr=edge_attr_feats,
                                               coors=coors, rel_coors=rel_coors,
                                               batch=batch)
        return torch.cat([coors_out, hidden_out], dim=-1)

    def message(self, x_i, x_j, edge_attr) -> Tensor:
        m_ij = self.edge_mlp( torch.cat([x_i, x_j, edge_attr], dim=-1) )
        return m_ij
    def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
        """The initial call to start propagating messages.
            Args:
            `edge_index` holds the indices of a general (sparse)
                assignment matrix of shape :obj:`[N, M]`.
            size (tuple, optional): if None, the size will be inferred
                and assumed to be quadratic.
            **kwargs: Any additional data which is needed to construct and
                aggregate messages, and to update node embeddings.
        """
        size = self._check_input(edge_index, size)
        coll_dict = self._collect(self._user_args, edge_index, size, kwargs)
        msg_kwargs = self.inspector.collect_param_data('message', coll_dict)
        aggr_kwargs = self.inspector.collect_param_data('aggregate', coll_dict)
        update_kwargs = self.inspector.collect_param_data('update', coll_dict)
        # get messages
        m_ij = self.message(**msg_kwargs)

        # update coors if specified
        if self.update_coors:
            coor_wij = self.coors_mlp(m_ij)
            # clamp the coordinate weights if arg is set
            if self.coor_weights_clamp_value:
                clamp_value = self.coor_weights_clamp_value
                coor_wij.clamp_(min = -clamp_value, max = clamp_value)
            # normalize if needed
            kwargs["rel_coors"] = self.coors_norm(kwargs["rel_coors"])

            mhat_i = self.aggregate(coor_wij * kwargs["rel_coors"], **aggr_kwargs)
            coors_out = kwargs["coors"] + mhat_i
        else:
            coors_out = kwargs["coors"]
        # update feats if specified
        if self.update_feats:
            # weight the edges if arg is passed
            if self.soft_edge:
                m_ij = m_ij * self.edge_weight(m_ij)
            m_i = self.aggregate(m_ij, **aggr_kwargs)

            hidden_feats = self.node_norm(kwargs["x"], kwargs["batch"]) if self.node_norm else kwargs["x"]
            hidden_out = self.node_mlp( torch.cat([hidden_feats, m_i], dim = -1) )
            hidden_out = kwargs["x"] + hidden_out
        else:
            hidden_out = kwargs["x"]

        # return tuple
        return self.update((hidden_out, coors_out), **update_kwargs)

    def __repr__(self):
        return "E(n)-GNN Layer for Graphs " + str(self.__dict__)
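
# Illustrative sketch (all values below are arbitrary assumptions, not taken
# from the original repo): a single EGNN_Sparse layer consumes x = [coords | feats]
# packed along the last dim and returns the same packed layout with updated values.
def _example_egnn_sparse_layer():
    layer = EGNN_Sparse(feats_dim = 16, pos_dim = 3, m_dim = 16, update_coors = True)
    coords = torch.randn(10, 3)
    feats = torch.randn(10, 16)
    x = torch.cat([coords, feats], dim = -1)              # (10, 3 + 16)
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 4]])             # a few directed edges
    batch = torch.zeros(10, dtype = torch.long)           # single point cloud
    out = layer(x, edge_index, batch = batch)
    return out.shape                                      # expected: (10, 19) = coords + feats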
class EGNN_Sparse_Network(nn.Module):
    r"""Sample GNN model architecture that uses the EGNN-Sparse
        message passing layer to learn over point clouds.
        Main MPNN layer introduced in https://arxiv.org/abs/2102.09844v1

        Inputs will be standard GNN: x, edge_index, edge_attr, batch, ...

        Args:
        * n_layers: int. number of MPNN layers
        * ... : same interpretation as the base layer.
        * embedding_nums: list. number of unique keys to embed for points.
                          1 entry per embedding needed.
        * embedding_dims: list. number of dimensions of the resulting
                          point embedding. 1 entry per embedding needed.
        * edge_embedding_nums: list. number of unique keys to embed for edges.
                               1 entry per embedding needed.
        * edge_embedding_dims: list. number of dimensions of the resulting
                               edge embedding. 1 entry per embedding needed.
        * recalc: int. Recalculate edge feats every `recalc` MPNN layers. 0 for no recalc.
        * verbose: bool. verbosity level.
        -----
        Diff with the dense layer: preprocessing (radius/knn graph, global tokens, ...)
        has to be done beforehand by the caller.
    """
    def __init__(self, n_layers, feats_dim,
                 pos_dim = 3,
                 edge_attr_dim = 0,
                 m_dim = 16,
                 fourier_features = 0,
                 soft_edge = 0,
                 embedding_nums = [],
                 embedding_dims = [],
                 edge_embedding_nums = [],
                 edge_embedding_dims = [],
                 update_coors = True,
                 update_feats = True,
                 norm_feats = True,
                 norm_coors = False,
                 norm_coors_scale_init = 1e-2,
                 dropout = 0.,
                 coor_weights_clamp_value = None,
                 aggr = "add",
                 global_linear_attn_every = 0,
                 global_linear_attn_heads = 8,
                 global_linear_attn_dim_head = 64,
                 num_global_tokens = 4,
                 recalc = 0):
        super().__init__()

        self.n_layers = n_layers

        # Embeddings? solve here
        self.embedding_nums = embedding_nums
        self.embedding_dims = embedding_dims
        self.emb_layers = nn.ModuleList()
        self.edge_embedding_nums = edge_embedding_nums
        self.edge_embedding_dims = edge_embedding_dims
        self.edge_emb_layers = nn.ModuleList()

        # instantiate point and edge embedding layers
        for i in range(len(self.embedding_dims)):
            self.emb_layers.append(nn.Embedding(num_embeddings = embedding_nums[i],
                                                embedding_dim = embedding_dims[i]))
            feats_dim += embedding_dims[i] - 1

        for i in range(len(self.edge_embedding_dims)):
            self.edge_emb_layers.append(nn.Embedding(num_embeddings = edge_embedding_nums[i],
                                                     embedding_dim = edge_embedding_dims[i]))
            edge_attr_dim += edge_embedding_dims[i] - 1

        # rest
        self.mpnn_layers = nn.ModuleList()
        self.feats_dim = feats_dim
        self.pos_dim = pos_dim
        self.edge_attr_dim = edge_attr_dim
        self.m_dim = m_dim
        self.fourier_features = fourier_features
        self.soft_edge = soft_edge
        self.norm_feats = norm_feats
        self.norm_coors = norm_coors
        self.norm_coors_scale_init = norm_coors_scale_init
        self.update_feats = update_feats
        self.update_coors = update_coors
        self.dropout = dropout
        self.coor_weights_clamp_value = coor_weights_clamp_value
        self.recalc = recalc

        self.has_global_attn = global_linear_attn_every > 0
        self.global_tokens = None
        self.global_linear_attn_every = global_linear_attn_every
        if self.has_global_attn:
            self.global_tokens = nn.Parameter(torch.randn(num_global_tokens, self.feats_dim))
        # instantiate layers
        for i in range(n_layers):
            layer = EGNN_Sparse(feats_dim = feats_dim,
                                pos_dim = pos_dim,
                                edge_attr_dim = edge_attr_dim,
                                m_dim = m_dim,
                                fourier_features = fourier_features,
                                soft_edge = soft_edge,
                                norm_feats = norm_feats,
                                norm_coors = norm_coors,
                                norm_coors_scale_init = norm_coors_scale_init,
                                update_feats = update_feats,
                                update_coors = update_coors,
                                dropout = dropout,
                                coor_weights_clamp_value = coor_weights_clamp_value)

            # global attention case
            is_global_layer = self.has_global_attn and (i % self.global_linear_attn_every) == 0
            if is_global_layer:
                attn_layer = GlobalLinearAttention_Sparse(dim = self.feats_dim,
                                                          heads = global_linear_attn_heads,
                                                          dim_head = global_linear_attn_dim_head)
                self.mpnn_layers.append(nn.ModuleList([attn_layer, layer]))
            # normal case
            else:
                self.mpnn_layers.append(layer)
    def forward(self, x, edge_index, batch, edge_attr,
                bsize=None, recalc_edge=None, verbose=0):
        """ Recalculates edge features every `self.recalc` layers using the
            passed `recalc_edge` function (if `self.recalc` is set).
            * x: (N, pos_dim + feats_dim) will be unpacked into coors, feats.
        """
        # NODES - Embed each dim to its target dimensions:
        x = embedd_token(x, self.embedding_dims, self.emb_layers)

        # regulates whether to embed edge tokens on the next pass
        # (only needed if edge embeddings are configured)
        edges_need_embedding = len(self.edge_embedding_dims) > 0
        for i, layer in enumerate(self.mpnn_layers):

            # EDGES - Embed each dim to its target dimensions:
            if edges_need_embedding:
                edge_attr = embedd_token(edge_attr, self.edge_embedding_dims, self.edge_emb_layers)
                edges_need_embedding = False

            # attn tokens (kept as a local so the learned global tokens are not overwritten)
            global_tokens = None
            if exists(self.global_tokens):
                unique, amounts = torch.unique(batch, return_counts=True)
                num_idxs = torch.cat([torch.arange(num_idxs_i, device=self.global_tokens.device) for num_idxs_i in amounts], dim=-1)
                global_tokens = self.global_tokens[num_idxs]
            # pass layers
            is_global_layer = self.has_global_attn and (i % self.global_linear_attn_every) == 0
            if not is_global_layer:
                x = layer(x, edge_index, edge_attr, batch=batch, size=bsize)
            else:
                # only pass feats to the attn layer (node feats are used as the
                # queries here instead of the global_tokens computed above)
                x_attn = layer[0](x[:, self.pos_dim:], x[:, self.pos_dim:], batch)[0]
                # merge attn-ed feats and coords
                x = torch.cat((x[:, :self.pos_dim], x_attn), dim=-1)
                # pass the merged input to the MPNN layer
                x = layer[-1](x, edge_index, edge_attr, batch=batch, size=bsize)

            # recalculate edge info - not needed if last layer
            if self.recalc and (i % self.recalc == 0) and not (i == len(self.mpnn_layers) - 1):
                edge_index, edge_attr, _ = recalc_edge(x)  # expected to return (edge_index, edge_attr, any_other_info)
                edges_need_embedding = True

        return x
    def __repr__(self):
        return 'EGNN_Sparse_Network of: {0} layers'.format(len(self.mpnn_layers))
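
# Illustrative end-to-end sketch (assumptions: random data, no categorical
# embeddings, and a simple ring graph standing in for the radius/knn graph that,
# per the class docstring, the caller is expected to build beforehand).
if __name__ == "__main__":
    net = EGNN_Sparse_Network(n_layers = 2, feats_dim = 16, pos_dim = 3, m_dim = 16)

    coords = torch.randn(10, 3)
    feats = torch.randn(10, 16)
    x = torch.cat([coords, feats], dim = -1)      # packed [coords | feats]
    batch = torch.zeros(10, dtype = torch.long)   # single point cloud

    # toy ring graph over the 10 points
    src = torch.arange(10)
    dst = (src + 1) % 10
    edge_index = torch.stack([src, dst], dim = 0)

    out = net(x, edge_index, batch = batch, edge_attr = None)
    print(out.shape)                              # expected: (10, 19) = updated coords + feats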