import torch
from torch import nn, einsum, broadcast_tensors
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helper functions

def exists(val):
    return val is not None

def safe_div(num, den, eps = 1e-8):
    res = num.div(den.clamp(min = eps))
    res.masked_fill_(den == 0, 0.)
    return res

def batched_index_select(values, indices, dim = 1):
    value_dims = values.shape[(dim + 1):]
    values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
    indices = indices[(..., *((None,) * len(value_dims)))]
    indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
    value_expand_len = len(indices_shape) - (dim + 1)
    values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]

    value_expand_shape = [-1] * len(values.shape)
    expand_slice = slice(dim, (dim + value_expand_len))
    value_expand_shape[expand_slice] = indices.shape[expand_slice]
    values = values.expand(*value_expand_shape)

    dim += value_expand_len
    return values.gather(dim, indices)

def fourier_encode_dist(x, num_encodings = 4, include_self = True):
    x = x.unsqueeze(-1)
    device, dtype, orig_x = x.device, x.dtype, x
    scales = 2 ** torch.arange(num_encodings, device = device, dtype = dtype)
    x = x / scales
    x = torch.cat([x.sin(), x.cos()], dim = -1)
    x = torch.cat((x, orig_x), dim = -1) if include_self else x
    return x

def embedd_token(x, dims, layers):
    stop_concat = -len(dims)
    to_embedd = x[:, stop_concat:].long()
    for i, emb_layer in enumerate(layers):
        # the portion corresponding to `to_embedd` gets dropped and replaced with its embedding
        x = torch.cat([x[:, :stop_concat], emb_layer(to_embedd[:, i])], dim = -1)
        stop_concat = x.shape[-1]
    return x

# swish activation fallback

class Swish_(nn.Module):
    def forward(self, x):
        return x * x.sigmoid()

SiLU = nn.SiLU if hasattr(nn, 'SiLU') else Swish_

# helper classes

# this follows the same strategy for normalization as done in SE3 Transformers
# https://github.com/lucidrains/se3-transformer-pytorch/blob/main/se3_transformer_pytorch/se3_transformer_pytorch.py#L95

class CoorsNorm(nn.Module):
    def __init__(self, eps = 1e-8, scale_init = 1.):
        super().__init__()
        self.eps = eps
        scale = torch.zeros(1).fill_(scale_init)
        self.scale = nn.Parameter(scale)

    def forward(self, coors):
        norm = coors.norm(dim = -1, keepdim = True)
        normed_coors = coors / norm.clamp(min = self.eps)
        return normed_coors * self.scale

# global linear attention

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x, context, mask = None):
        h = self.heads

        q = self.to_q(x)
        kv = self.to_kv(context).chunk(2, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, *kv))
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        if exists(mask):
            mask_value = -torch.finfo(dots.dtype).max
            mask = rearrange(mask, 'b n -> b () () n')
            dots.masked_fill_(~mask, mask_value)

        attn = dots.softmax(dim = -1)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        return self.to_out(out)
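# usage sketch (illustrative only, not part of the module's API surface):
# `Attention` is plain multi-head cross attention from `x` (the queries) to
# `context` (the keys / values), with an optional boolean mask over context positions.
#
#   attn = Attention(dim = 64, heads = 8, dim_head = 64)
#   x = torch.randn(2, 4, 64)          # e.g. a small set of global tokens
#   context = torch.randn(2, 16, 64)   # the sequence being attended over
#   out = attn(x, context)             # (2, 4, 64)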
class GlobalLinearAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64
    ):
        super().__init__()
        self.norm_seq = nn.LayerNorm(dim)
        self.norm_queries = nn.LayerNorm(dim)
        self.attn1 = Attention(dim, heads, dim_head)
        self.attn2 = Attention(dim, heads, dim_head)

        self.ff = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, x, queries, mask = None):
        res_x, res_queries = x, queries
        x, queries = self.norm_seq(x), self.norm_queries(queries)

        induced = self.attn1(queries, x, mask = mask)
        out = self.attn2(x, induced)

        x = out + res_x
        queries = induced + res_queries

        x = self.ff(x) + x
        return x, queries
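# usage sketch (illustrative only): `GlobalLinearAttention` routes information
# through a small set of induced global tokens - the queries attend to the full
# sequence, the sequence then attends back to the induced summaries, followed by
# a feedforward on the sequence. Both the updated sequence and queries are returned.
#
#   global_attn = GlobalLinearAttention(dim = 64)
#   x = torch.randn(2, 128, 64)        # per-node features
#   queries = torch.randn(2, 4, 64)    # global tokens
#   mask = torch.ones(2, 128).bool()
#   x, queries = global_attn(x, queries, mask = mask)   # (2, 128, 64), (2, 4, 64)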
# classes

class EGNN(nn.Module):
    def __init__(
        self,
        dim,
        edge_dim = 0,
        m_dim = 16,
        fourier_features = 0,
        num_nearest_neighbors = 0,
        dropout = 0.0,
        init_eps = 1e-3,
        norm_feats = False,
        norm_coors = False,
        norm_coors_scale_init = 1e-2,
        update_feats = True,
        update_coors = True,
        only_sparse_neighbors = False,
        valid_radius = float('inf'),
        m_pool_method = 'sum',
        soft_edges = False,
        coor_weights_clamp_value = None
    ):
        super().__init__()
        assert m_pool_method in {'sum', 'mean'}, 'pool method must be either sum or mean'
        assert update_feats or update_coors, 'you must update either features, coordinates, or both'

        self.fourier_features = fourier_features

        edge_input_dim = (fourier_features * 2) + (dim * 2) + edge_dim + 1
        dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        self.edge_mlp = nn.Sequential(
            nn.Linear(edge_input_dim, edge_input_dim * 2),
            dropout,
            SiLU(),
            nn.Linear(edge_input_dim * 2, m_dim),
            SiLU()
        )

        self.edge_gate = nn.Sequential(
            nn.Linear(m_dim, 1),
            nn.Sigmoid()
        ) if soft_edges else None

        self.node_norm = nn.LayerNorm(dim) if norm_feats else nn.Identity()
        self.coors_norm = CoorsNorm(scale_init = norm_coors_scale_init) if norm_coors else nn.Identity()

        self.m_pool_method = m_pool_method

        self.node_mlp = nn.Sequential(
            nn.Linear(dim + m_dim, dim * 2),
            dropout,
            SiLU(),
            nn.Linear(dim * 2, dim),
        ) if update_feats else None

        self.coors_mlp = nn.Sequential(
            nn.Linear(m_dim, m_dim * 4),
            dropout,
            SiLU(),
            nn.Linear(m_dim * 4, 1)
        ) if update_coors else None

        self.num_nearest_neighbors = num_nearest_neighbors
        self.only_sparse_neighbors = only_sparse_neighbors
        self.valid_radius = valid_radius

        self.coor_weights_clamp_value = coor_weights_clamp_value

        self.init_eps = init_eps
        self.apply(self.init_)

    def init_(self, module):
        if type(module) in {nn.Linear}:
            # seems to be needed to keep the network from exploding to NaN with greater depths
            nn.init.normal_(module.weight, std = self.init_eps)

    def forward(self, feats, coors, edges = None, mask = None, adj_mat = None):
        b, n, d, device, fourier_features, num_nearest, valid_radius, only_sparse_neighbors = *feats.shape, feats.device, self.fourier_features, self.num_nearest_neighbors, self.valid_radius, self.only_sparse_neighbors

        if exists(mask):
            num_nodes = mask.sum(dim = -1)

        use_nearest = num_nearest > 0 or only_sparse_neighbors

        rel_coors = rearrange(coors, 'b i d -> b i () d') - rearrange(coors, 'b j d -> b () j d')
        rel_dist = (rel_coors ** 2).sum(dim = -1, keepdim = True)

        i = j = n

        if use_nearest:
            ranking = rel_dist[..., 0].clone()

            if exists(mask):
                rank_mask = mask[:, :, None] * mask[:, None, :]
                ranking.masked_fill_(~rank_mask, 1e5)

            if exists(adj_mat):
                if len(adj_mat.shape) == 2:
                    adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b = b)

                if only_sparse_neighbors:
                    num_nearest = int(adj_mat.float().sum(dim = -1).max().item())
                    valid_radius = 0

                self_mask = rearrange(torch.eye(n, device = device, dtype = torch.bool), 'i j -> () i j')

                adj_mat = adj_mat.masked_fill(self_mask, False)
                ranking.masked_fill_(self_mask, -1.)
                ranking.masked_fill_(adj_mat, 0.)

            nbhd_ranking, nbhd_indices = ranking.topk(num_nearest, dim = -1, largest = False)

            nbhd_mask = nbhd_ranking <= valid_radius

            rel_coors = batched_index_select(rel_coors, nbhd_indices, dim = 2)
            rel_dist = batched_index_select(rel_dist, nbhd_indices, dim = 2)

            if exists(edges):
                edges = batched_index_select(edges, nbhd_indices, dim = 2)

            j = num_nearest

        if fourier_features > 0:
            rel_dist = fourier_encode_dist(rel_dist, num_encodings = fourier_features)
            rel_dist = rearrange(rel_dist, 'b i j () d -> b i j d')

        if use_nearest:
            feats_j = batched_index_select(feats, nbhd_indices, dim = 1)
        else:
            feats_j = rearrange(feats, 'b j d -> b () j d')

        feats_i = rearrange(feats, 'b i d -> b i () d')
        feats_i, feats_j = broadcast_tensors(feats_i, feats_j)

        edge_input = torch.cat((feats_i, feats_j, rel_dist), dim = -1)

        if exists(edges):
            edge_input = torch.cat((edge_input, edges), dim = -1)

        m_ij = self.edge_mlp(edge_input)

        if exists(self.edge_gate):
            m_ij = m_ij * self.edge_gate(m_ij)

        if exists(mask):
            mask_i = rearrange(mask, 'b i -> b i ()')

            if use_nearest:
                mask_j = batched_index_select(mask, nbhd_indices, dim = 1)
                mask = (mask_i * mask_j) & nbhd_mask
            else:
                mask_j = rearrange(mask, 'b j -> b () j')
                mask = mask_i * mask_j

        if exists(self.coors_mlp):
            coor_weights = self.coors_mlp(m_ij)
            coor_weights = rearrange(coor_weights, 'b i j () -> b i j')

            rel_coors = self.coors_norm(rel_coors)

            if exists(mask):
                coor_weights.masked_fill_(~mask, 0.)

            if exists(self.coor_weights_clamp_value):
                clamp_value = self.coor_weights_clamp_value
                coor_weights.clamp_(min = -clamp_value, max = clamp_value)

            coors_out = einsum('b i j, b i j c -> b i c', coor_weights, rel_coors) + coors
        else:
            coors_out = coors

        if exists(self.node_mlp):
            if exists(mask):
                m_ij_mask = rearrange(mask, '... -> ... ()')
                m_ij = m_ij.masked_fill(~m_ij_mask, 0.)

            if self.m_pool_method == 'mean':
                if exists(mask):
                    # masked mean
                    mask_sum = m_ij_mask.sum(dim = -2)
                    m_i = safe_div(m_ij.sum(dim = -2), mask_sum)
                else:
                    m_i = m_ij.mean(dim = -2)

            elif self.m_pool_method == 'sum':
                m_i = m_ij.sum(dim = -2)

            normed_feats = self.node_norm(feats)
            node_mlp_input = torch.cat((normed_feats, m_i), dim = -1)
            node_out = self.node_mlp(node_mlp_input) + feats
        else:
            node_out = feats

        return node_out, coors_out
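# usage sketch for a single layer (shapes mirror the forward pass above; the
# hyperparameters are illustrative assumptions):
#
#   layer = EGNN(dim = 512)
#   feats = torch.randn(1, 16, 512)      # node features
#   coors = torch.randn(1, 16, 3)        # node coordinates
#   feats, coors = layer(feats, coors)   # (1, 16, 512), (1, 16, 3)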
class EGNN_Network(nn.Module):
    def __init__(
        self,
        *,
        depth,
        dim,
        num_tokens = None,
        num_edge_tokens = None,
        num_positions = None,
        edge_dim = 0,
        num_adj_degrees = None,
        adj_dim = 0,
        global_linear_attn_every = 0,
        global_linear_attn_heads = 8,
        global_linear_attn_dim_head = 64,
        num_global_tokens = 4,
        **kwargs
    ):
        super().__init__()
        assert not (exists(num_adj_degrees) and num_adj_degrees < 1), 'make sure adjacent degrees is at least 1'
        self.num_positions = num_positions

        self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None
        self.pos_emb = nn.Embedding(num_positions, dim) if exists(num_positions) else None
        self.edge_emb = nn.Embedding(num_edge_tokens, edge_dim) if exists(num_edge_tokens) else None
        self.has_edges = edge_dim > 0

        self.num_adj_degrees = num_adj_degrees
        self.adj_emb = nn.Embedding(num_adj_degrees + 1, adj_dim) if exists(num_adj_degrees) and adj_dim > 0 else None

        edge_dim = edge_dim if self.has_edges else 0
        adj_dim = adj_dim if exists(num_adj_degrees) else 0

        has_global_attn = global_linear_attn_every > 0
        self.global_tokens = None
        if has_global_attn:
            self.global_tokens = nn.Parameter(torch.randn(num_global_tokens, dim))

        self.layers = nn.ModuleList([])
        for ind in range(depth):
            is_global_layer = has_global_attn and (ind % global_linear_attn_every) == 0

            self.layers.append(nn.ModuleList([
                GlobalLinearAttention(dim = dim, heads = global_linear_attn_heads, dim_head = global_linear_attn_dim_head) if is_global_layer else None,
                EGNN(dim = dim, edge_dim = (edge_dim + adj_dim), norm_feats = True, **kwargs),
            ]))
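    # expected inputs to `forward` below (descriptive note, inferred from the code):
    #   feats   - (b, n) long token ids when `num_tokens` is set, else (b, n, dim) float features
    #   coors   - (b, n, 3) float coordinates
    #   adj_mat - (n, n) or (b, n, n) bool adjacency matrix, required when `num_adj_degrees` is set
    #   edges   - (b, n, n) long edge token ids when `num_edge_tokens` is set, else (b, n, n, edge_dim) floats
    #   mask    - (b, n) bool mask over valid nodes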
    def forward(
        self,
        feats,
        coors,
        adj_mat = None,
        edges = None,
        mask = None,
        return_coor_changes = False
    ):
        b, device = feats.shape[0], feats.device

        if exists(self.token_emb):
            feats = self.token_emb(feats)

        if exists(self.pos_emb):
            n = feats.shape[1]
            assert n <= self.num_positions, f'given sequence length {n} must be less than or equal to the number of positions {self.num_positions} set at init'
            pos_emb = self.pos_emb(torch.arange(n, device = device))
            feats += rearrange(pos_emb, 'n d -> () n d')

        if exists(edges) and exists(self.edge_emb):
            edges = self.edge_emb(edges)

        # create the N-degree adjacency matrix from the 1st-degree connections
        if exists(self.num_adj_degrees):
            assert exists(adj_mat), 'adjacency matrix must be passed in (keyword argument adj_mat)'

            if len(adj_mat.shape) == 2:
                adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b = b)

            adj_indices = adj_mat.clone().long()

            for ind in range(self.num_adj_degrees - 1):
                degree = ind + 2

                next_degree_adj_mat = (adj_mat.float() @ adj_mat.float()) > 0
                next_degree_mask = (next_degree_adj_mat.float() - adj_mat.float()).bool()
                adj_indices.masked_fill_(next_degree_mask, degree)
                adj_mat = next_degree_adj_mat.clone()

            if exists(self.adj_emb):
                adj_emb = self.adj_emb(adj_indices)
                edges = torch.cat((edges, adj_emb), dim = -1) if exists(edges) else adj_emb

        # setup global attention

        global_tokens = None
        if exists(self.global_tokens):
            global_tokens = repeat(self.global_tokens, 'n d -> b n d', b = b)

        # go through layers

        coor_changes = [coors]

        for global_attn, egnn in self.layers:
            if exists(global_attn):
                feats, global_tokens = global_attn(feats, global_tokens, mask = mask)

            feats, coors = egnn(feats, coors, adj_mat = adj_mat, edges = edges, mask = mask)
            coor_changes.append(coors)

        if return_coor_changes:
            return feats, coors, coor_changes

        return feats, coors
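# minimal smoke test, kept behind the __main__ guard so it does not run on import.
# the hyperparameters and tensor sizes below are assumptions chosen for illustration.

if __name__ == '__main__':
    net = EGNN_Network(
        num_tokens = 21,
        num_positions = 1024,
        dim = 32,
        depth = 3,
        num_nearest_neighbors = 8,
        coor_weights_clamp_value = 2.
    )

    feats = torch.randint(0, 21, (1, 1024))   # token ids per node
    coors = torch.randn(1, 1024, 3)           # node coordinates
    mask = torch.ones(1, 1024).bool()         # all nodes valid

    feats_out, coors_out = net(feats, coors, mask = mask)
    print(feats_out.shape, coors_out.shape)   # (1, 1024, 32), (1, 1024, 3)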