import torch
from torch import nn, einsum, broadcast_tensors
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helper functions

def exists(val):
    return val is not None

def safe_div(num, den, eps = 1e-8):
    res = num.div(den.clamp(min = eps))
    res.masked_fill_(den == 0, 0.)
    return res

def batched_index_select(values, indices, dim = 1):
    value_dims = values.shape[(dim + 1):]
    values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
    indices = indices[(..., *((None,) * len(value_dims)))]
    indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
    value_expand_len = len(indices_shape) - (dim + 1)
    values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]

    value_expand_shape = [-1] * len(values.shape)
    expand_slice = slice(dim, (dim + value_expand_len))
    value_expand_shape[expand_slice] = indices.shape[expand_slice]
    values = values.expand(*value_expand_shape)

    dim += value_expand_len
    return values.gather(dim, indices)
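# a minimal shape sketch of how `batched_index_select` is used below
# (the shapes are illustrative assumptions, not taken from the original file):
#   values  : (b, n, d)   per-node features
#   indices : (b, n, k)   k neighbor indices per node
#   batched_index_select(values, indices, dim = 1) -> (b, n, k, d)
# i.e. it gathers, for every node, the feature vectors of its k neighbors
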
def fourier_encode_dist(x, num_encodings = 4, include_self = True):
    x = x.unsqueeze(-1)
    device, dtype, orig_x = x.device, x.dtype, x
    scales = 2 ** torch.arange(num_encodings, device = device, dtype = dtype)
    x = x / scales
    x = torch.cat([x.sin(), x.cos()], dim = -1)
    x = torch.cat((x, orig_x), dim = -1) if include_self else x
    return x
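# a quick sketch of the encoding shape (illustrative, not from the original file):
# for an input of shape (b, i, j, 1) and num_encodings = 4, the output is
# (b, i, j, 1, 2 * 4 + 1) - sine and cosine at 4 scales, plus the raw distance
# when include_self = True
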
def embedd_token(x, dims, layers):
    stop_concat = -len(dims)
    to_embedd = x[:, stop_concat:].long()
    for i, emb_layer in enumerate(layers):
        # on the first pass the trailing integer-token columns (`to_embedd`) are dropped,
        # and each embedding is then concatenated in their place
        x = torch.cat([ x[:, :stop_concat],
                        emb_layer( to_embedd[:, i] )
                      ], dim = -1)
        stop_concat = x.shape[-1]
    return x
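# a shape sketch for `embedd_token` (illustrative assumptions, not from the original file):
# if x has shape (n, d) whose last len(dims) columns hold integer token ids, and
# layers[i] is an nn.Embedding with output dim e_i, the result has shape
# (n, d - len(dims) + sum(e_i)) - token columns replaced by their embeddings
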
# swish activation fallback

class Swish_(nn.Module):
    def forward(self, x):
        return x * x.sigmoid()

SiLU = nn.SiLU if hasattr(nn, 'SiLU') else Swish_

# helper classes

# this follows the same strategy for normalization as done in SE3 Transformers
# https://github.com/lucidrains/se3-transformer-pytorch/blob/main/se3_transformer_pytorch/se3_transformer_pytorch.py#L95

class CoorsNorm(nn.Module):
    def __init__(self, eps = 1e-8, scale_init = 1.):
        super().__init__()
        self.eps = eps
        scale = torch.zeros(1).fill_(scale_init)
        self.scale = nn.Parameter(scale)

    def forward(self, coors):
        norm = coors.norm(dim = -1, keepdim = True)
        normed_coors = coors / norm.clamp(min = self.eps)
        return normed_coors * self.scale
# global linear attention

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x, context, mask = None):
        h = self.heads

        q = self.to_q(x)
        kv = self.to_kv(context).chunk(2, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, *kv))
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        if exists(mask):
            mask_value = -torch.finfo(dots.dtype).max
            mask = rearrange(mask, 'b n -> b () () n')
            dots.masked_fill_(~mask, mask_value)

        attn = dots.softmax(dim = -1)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        return self.to_out(out)
class GlobalLinearAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64
    ):
        super().__init__()
        self.norm_seq = nn.LayerNorm(dim)
        self.norm_queries = nn.LayerNorm(dim)
        self.attn1 = Attention(dim, heads, dim_head)
        self.attn2 = Attention(dim, heads, dim_head)

        self.ff = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, x, queries, mask = None):
        res_x, res_queries = x, queries
        x, queries = self.norm_seq(x), self.norm_queries(queries)

        induced = self.attn1(queries, x, mask = mask)
        out = self.attn2(x, induced)

        x = out + res_x
        queries = induced + res_queries

        x = self.ff(x) + x
        return x, queries
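# a minimal shape sketch for the global attention block above (illustrative
# values, not taken from the original file):
#   x       : (b, n, dim)  per-node features
#   queries : (b, m, dim)  a small set of global tokens (m << n)
# the global tokens first attend over all nodes, then the nodes attend back over
# the updated global tokens, giving every node access to global context at
# O(n * m) cost per layer instead of the O(n ** 2) of full self-attention
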
# classes

class EGNN(nn.Module):
    def __init__(
        self,
        dim,
        edge_dim = 0,
        m_dim = 16,
        fourier_features = 0,
        num_nearest_neighbors = 0,
        dropout = 0.0,
        init_eps = 1e-3,
        norm_feats = False,
        norm_coors = False,
        norm_coors_scale_init = 1e-2,
        update_feats = True,
        update_coors = True,
        only_sparse_neighbors = False,
        valid_radius = float('inf'),
        m_pool_method = 'sum',
        soft_edges = False,
        coor_weights_clamp_value = None
    ):
        super().__init__()
        assert m_pool_method in {'sum', 'mean'}, 'pool method must be either sum or mean'
        assert update_feats or update_coors, 'you must update either features, coordinates, or both'

        self.fourier_features = fourier_features

        edge_input_dim = (fourier_features * 2) + (dim * 2) + edge_dim + 1
        dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        self.edge_mlp = nn.Sequential(
            nn.Linear(edge_input_dim, edge_input_dim * 2),
            dropout,
            SiLU(),
            nn.Linear(edge_input_dim * 2, m_dim),
            SiLU()
        )

        self.edge_gate = nn.Sequential(
            nn.Linear(m_dim, 1),
            nn.Sigmoid()
        ) if soft_edges else None

        self.node_norm = nn.LayerNorm(dim) if norm_feats else nn.Identity()
        self.coors_norm = CoorsNorm(scale_init = norm_coors_scale_init) if norm_coors else nn.Identity()

        self.m_pool_method = m_pool_method

        self.node_mlp = nn.Sequential(
            nn.Linear(dim + m_dim, dim * 2),
            dropout,
            SiLU(),
            nn.Linear(dim * 2, dim),
        ) if update_feats else None

        self.coors_mlp = nn.Sequential(
            nn.Linear(m_dim, m_dim * 4),
            dropout,
            SiLU(),
            nn.Linear(m_dim * 4, 1)
        ) if update_coors else None

        self.num_nearest_neighbors = num_nearest_neighbors
        self.only_sparse_neighbors = only_sparse_neighbors
        self.valid_radius = valid_radius

        self.coor_weights_clamp_value = coor_weights_clamp_value

        self.init_eps = init_eps
        self.apply(self.init_)

    def init_(self, module):
        if type(module) in {nn.Linear}:
            # seems to be needed to keep the network from exploding to NaN with greater depths
            nn.init.normal_(module.weight, std = self.init_eps)
    def forward(self, feats, coors, edges = None, mask = None, adj_mat = None):
        b, n, d, device = *feats.shape, feats.device
        fourier_features, num_nearest, valid_radius, only_sparse_neighbors = self.fourier_features, self.num_nearest_neighbors, self.valid_radius, self.only_sparse_neighbors

        if exists(mask):
            num_nodes = mask.sum(dim = -1)

        use_nearest = num_nearest > 0 or only_sparse_neighbors

        rel_coors = rearrange(coors, 'b i d -> b i () d') - rearrange(coors, 'b j d -> b () j d')
        rel_dist = (rel_coors ** 2).sum(dim = -1, keepdim = True)

        i = j = n

        if use_nearest:
            ranking = rel_dist[..., 0].clone()

            if exists(mask):
                rank_mask = mask[:, :, None] * mask[:, None, :]
                ranking.masked_fill_(~rank_mask, 1e5)

            if exists(adj_mat):
                if len(adj_mat.shape) == 2:
                    adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b = b)

                if only_sparse_neighbors:
                    num_nearest = int(adj_mat.float().sum(dim = -1).max().item())
                    valid_radius = 0

                self_mask = rearrange(torch.eye(n, device = device, dtype = torch.bool), 'i j -> () i j')

                adj_mat = adj_mat.masked_fill(self_mask, False)
                ranking.masked_fill_(self_mask, -1.)
                ranking.masked_fill_(adj_mat, 0.)

            nbhd_ranking, nbhd_indices = ranking.topk(num_nearest, dim = -1, largest = False)

            nbhd_mask = nbhd_ranking <= valid_radius

            rel_coors = batched_index_select(rel_coors, nbhd_indices, dim = 2)
            rel_dist = batched_index_select(rel_dist, nbhd_indices, dim = 2)

            if exists(edges):
                edges = batched_index_select(edges, nbhd_indices, dim = 2)

            j = num_nearest

        if fourier_features > 0:
            rel_dist = fourier_encode_dist(rel_dist, num_encodings = fourier_features)
            rel_dist = rearrange(rel_dist, 'b i j () d -> b i j d')

        if use_nearest:
            feats_j = batched_index_select(feats, nbhd_indices, dim = 1)
        else:
            feats_j = rearrange(feats, 'b j d -> b () j d')

        feats_i = rearrange(feats, 'b i d -> b i () d')
        feats_i, feats_j = broadcast_tensors(feats_i, feats_j)

        edge_input = torch.cat((feats_i, feats_j, rel_dist), dim = -1)

        if exists(edges):
            edge_input = torch.cat((edge_input, edges), dim = -1)

        m_ij = self.edge_mlp(edge_input)

        if exists(self.edge_gate):
            m_ij = m_ij * self.edge_gate(m_ij)

        if exists(mask):
            mask_i = rearrange(mask, 'b i -> b i ()')

            if use_nearest:
                mask_j = batched_index_select(mask, nbhd_indices, dim = 1)
                mask = (mask_i * mask_j) & nbhd_mask
            else:
                mask_j = rearrange(mask, 'b j -> b () j')
                mask = mask_i * mask_j

        if exists(self.coors_mlp):
            coor_weights = self.coors_mlp(m_ij)
            coor_weights = rearrange(coor_weights, 'b i j () -> b i j')

            rel_coors = self.coors_norm(rel_coors)

            if exists(mask):
                coor_weights.masked_fill_(~mask, 0.)

            if exists(self.coor_weights_clamp_value):
                clamp_value = self.coor_weights_clamp_value
                coor_weights.clamp_(min = -clamp_value, max = clamp_value)

            coors_out = einsum('b i j, b i j c -> b i c', coor_weights, rel_coors) + coors
        else:
            coors_out = coors

        if exists(self.node_mlp):
            if exists(mask):
                m_ij_mask = rearrange(mask, '... -> ... ()')
                m_ij = m_ij.masked_fill(~m_ij_mask, 0.)

            if self.m_pool_method == 'mean':
                if exists(mask):
                    # masked mean
                    mask_sum = m_ij_mask.sum(dim = -2)
                    m_i = safe_div(m_ij.sum(dim = -2), mask_sum)
                else:
                    m_i = m_ij.mean(dim = -2)

            elif self.m_pool_method == 'sum':
                m_i = m_ij.sum(dim = -2)

            normed_feats = self.node_norm(feats)
            node_mlp_input = torch.cat((normed_feats, m_i), dim = -1)
            node_out = self.node_mlp(node_mlp_input) + feats
        else:
            node_out = feats

        return node_out, coors_out
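# example usage of a single EGNN layer (a minimal sketch; the dimensions and
# tensor values below are illustrative assumptions, not part of the original file):
#
#   layer = EGNN(dim = 512, num_nearest_neighbors = 8)
#   feats = torch.randn(1, 16, 512)   # (batch, nodes, feature dim)
#   coors = torch.randn(1, 16, 3)     # (batch, nodes, xyz)
#   feats, coors = layer(feats, coors)
#   # feats -> (1, 16, 512), coors -> (1, 16, 3); coordinate updates are E(n)-equivariant
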
class EGNN_Network(nn.Module):
    def __init__(
        self,
        *,
        depth,
        dim,
        num_tokens = None,
        num_edge_tokens = None,
        num_positions = None,
        edge_dim = 0,
        num_adj_degrees = None,
        adj_dim = 0,
        global_linear_attn_every = 0,
        global_linear_attn_heads = 8,
        global_linear_attn_dim_head = 64,
        num_global_tokens = 4,
        **kwargs
    ):
        super().__init__()
        assert not (exists(num_adj_degrees) and num_adj_degrees < 1), 'make sure the number of adjacent degrees is at least 1'

        self.num_positions = num_positions

        self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None
        self.pos_emb = nn.Embedding(num_positions, dim) if exists(num_positions) else None
        self.edge_emb = nn.Embedding(num_edge_tokens, edge_dim) if exists(num_edge_tokens) else None
        self.has_edges = edge_dim > 0

        self.num_adj_degrees = num_adj_degrees
        self.adj_emb = nn.Embedding(num_adj_degrees + 1, adj_dim) if exists(num_adj_degrees) and adj_dim > 0 else None

        edge_dim = edge_dim if self.has_edges else 0
        adj_dim = adj_dim if exists(num_adj_degrees) else 0

        has_global_attn = global_linear_attn_every > 0
        self.global_tokens = None
        if has_global_attn:
            self.global_tokens = nn.Parameter(torch.randn(num_global_tokens, dim))

        self.layers = nn.ModuleList([])
        for ind in range(depth):
            is_global_layer = has_global_attn and (ind % global_linear_attn_every) == 0

            self.layers.append(nn.ModuleList([
                GlobalLinearAttention(dim = dim, heads = global_linear_attn_heads, dim_head = global_linear_attn_dim_head) if is_global_layer else None,
                EGNN(dim = dim, edge_dim = (edge_dim + adj_dim), norm_feats = True, **kwargs),
            ]))
    def forward(
        self,
        feats,
        coors,
        adj_mat = None,
        edges = None,
        mask = None,
        return_coor_changes = False
    ):
        b, device = feats.shape[0], feats.device

        if exists(self.token_emb):
            feats = self.token_emb(feats)

        if exists(self.pos_emb):
            n = feats.shape[1]
            assert n <= self.num_positions, f'given sequence length {n} must not exceed the number of positions {self.num_positions} set at init'
            pos_emb = self.pos_emb(torch.arange(n, device = device))
            feats += rearrange(pos_emb, 'n d -> () n d')

        if exists(edges) and exists(self.edge_emb):
            edges = self.edge_emb(edges)

        # create the N-degree adjacency matrix from the 1st-degree connections
        if exists(self.num_adj_degrees):
            assert exists(adj_mat), 'adjacency matrix must be passed in (keyword argument adj_mat)'

            if len(adj_mat.shape) == 2:
                adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b = b)

            adj_indices = adj_mat.clone().long()

            for ind in range(self.num_adj_degrees - 1):
                degree = ind + 2

                next_degree_adj_mat = (adj_mat.float() @ adj_mat.float()) > 0
                next_degree_mask = (next_degree_adj_mat.float() - adj_mat.float()).bool()
                adj_indices.masked_fill_(next_degree_mask, degree)
                adj_mat = next_degree_adj_mat.clone()

            if exists(self.adj_emb):
                adj_emb = self.adj_emb(adj_indices)
                edges = torch.cat((edges, adj_emb), dim = -1) if exists(edges) else adj_emb

        # setup global attention
        global_tokens = None
        if exists(self.global_tokens):
            global_tokens = repeat(self.global_tokens, 'n d -> b n d', b = b)

        # go through layers
        coor_changes = [coors]

        for global_attn, egnn in self.layers:
            if exists(global_attn):
                feats, global_tokens = global_attn(feats, global_tokens, mask = mask)

            feats, coors = egnn(feats, coors, adj_mat = adj_mat, edges = edges, mask = mask)
            coor_changes.append(coors)

        if return_coor_changes:
            return feats, coors, coor_changes

        return feats, coors
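
# a minimal smoke test for EGNN_Network (a sketch; the token count, sequence
# length and hyperparameters below are illustrative assumptions, not part of
# the original file)

if __name__ == '__main__':
    net = EGNN_Network(
        num_tokens = 21,
        num_positions = 1024,
        dim = 32,
        depth = 3,
        num_nearest_neighbors = 8,
        coor_weights_clamp_value = 2.
    )

    feats = torch.randint(0, 21, (1, 1024))   # integer token ids per node
    coors = torch.randn(1, 1024, 3)           # xyz coordinates per node
    mask = torch.ones_like(feats).bool()      # all positions valid

    feats_out, coors_out = net(feats, coors, mask = mask)
    print(feats_out.shape, coors_out.shape)   # (1, 1024, 32) and (1, 1024, 3)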