Spaces:
Running
Running
"""
This code was adapted from https://github.com/sarpaykent/GotenNet
Copyright (c) 2025 Sarp Aykent
MIT License
GotenNet: Rethinking Efficient 3D Equivariant Graph Neural Networks
Sarp Aykent and Tian Xia
https://openreview.net/pdf?id=5wxCQDtbMo
"""
from __future__ import absolute_import, division, print_function | |
import inspect | |
import math | |
from functools import partial | |
from typing import List | |
import torch | |
import torch.nn.functional as F | |
from torch import Tensor | |
from torch import nn as nn | |
from torch.nn.init import constant_, xavier_uniform_ | |
from torch_geometric.nn import MessagePassing | |
from torch_geometric.nn.inits import glorot_orthogonal | |
from torch_geometric.nn.models.schnet import ShiftedSoftplus | |
#from torch_scatter import scatter | |
# Fills a tensor with 0.0 in place; used as the default ``bias_init`` of ``Dense``.
zeros_initializer = partial(torch.nn.init.constant_, val=0.0)
def centralize(
    batch,
    key: str,
    batch_index: torch.Tensor,
):  # note: cannot make assumptions on output shape
    """Center ``batch[key]`` on the mean of the batch element each row belongs to.

    The segment mean is computed with plain torch ops. The original relied on
    ``torch_scatter.scatter(..., reduce="mean")``, whose import is commented out in
    this file. Semantics match that call: the number of segments is
    ``batch_index.max() + 1``, and segments with no members get a zero centroid.

    Args:
        batch: mapping that holds the per-entity values under ``key``.
        key: entry of ``batch`` to center; rows (dim 0) are entities.
        batch_index: 1-D integer tensor giving the batch-element id of every row.

    Returns:
        tuple ``(entities_centroid, entities_centered)``:
        ``entities_centroid`` has one row per batch element (e.g. ``[batch_size, 3]``),
        and ``entities_centered`` has the same shape as ``batch[key]``.
    """
    values = batch[key]
    num_segments = int(batch_index.max()) + 1 if batch_index.numel() > 0 else 0

    # Per-segment sums along dim 0.
    sums = values.new_zeros((num_segments, *values.shape[1:]))
    sums.index_add_(0, batch_index, values)

    # Per-segment member counts; clamp so empty segments yield 0 instead of NaN.
    counts = torch.bincount(batch_index, minlength=num_segments).clamp(min=1)
    counts = counts.to(values.dtype).view(-1, *([1] * (values.dim() - 1)))

    entities_centroid = sums / counts
    entities_centered = values - entities_centroid[batch_index]
    return entities_centroid, entities_centered
def decentralize(
    positions: torch.Tensor,
    batch_index: torch.Tensor,
    entities_centroid: torch.Tensor,
) -> torch.Tensor:  # note: cannot make assumptions on output shape
    """Shift per-entity values back by the centroid of their batch element.

    Args:
        positions: centered per-entity values; rows align with ``batch_index``.
        batch_index: batch-element id of every row of ``positions``.
        entities_centroid: per-batch-element centroids, indexed by ``batch_index``.

    Returns:
        ``positions + entities_centroid[batch_index]``.
    """
    per_entity_shift = entities_centroid[batch_index]
    return positions + per_entity_shift
def parse_update_info(edge_updates):
    """Parse an underscore-separated edge-update spec into a flag dictionary.

    Args:
        edge_updates: e.g. ``"gated_norej_linw"``. Any non-string value (such as
            None) means "no options" and yields the defaults.

    Returns:
        dict with keys ``gated`` (False, ``"gated"``, ``"gatedt"`` or ``"act"``;
        when several are given the later one in that list wins), ``rej``,
        ``vec_norm`` (always False here), ``mlp``, ``mlpa``, ``lin_w``
        (0, 1 for ``linw``, 2 for ``linwa``; ``linwa`` wins) and ``drej``.

    Raises:
        ValueError: if any part is not in the allowed list.
    """
    allowed_parts = ["gated", "gatedt", "norej", "mlp", "mlpa", "act", "linw", "linwa", "drej"]
    parts = edge_updates.split("_") if isinstance(edge_updates, str) else []
    if any(part not in allowed_parts for part in parts):
        raise ValueError(f"Invalid edge update parts. Allowed parts are {allowed_parts}")

    present = set(parts)

    # Precedence: gated < gatedt < act (last match wins).
    gated = False
    for gate_kind in ("gated", "gatedt", "act"):
        if gate_kind in present:
            gated = gate_kind

    # Precedence: linw < linwa.
    lin_w = 0
    if "linw" in present:
        lin_w = 1
    if "linwa" in present:
        lin_w = 2

    return {
        "gated": gated,
        "rej": "norej" not in present,
        "vec_norm": False,
        "mlp": "mlp" in present,
        "mlpa": "mlpa" in present,
        "lin_w": lin_w,
        "drej": "drej" in present,
    }
class SmoothLeakyReLU(torch.nn.Module):
    """Smooth variant of LeakyReLU.

    ``f(x) = ((1 + a) / 2) * x + ((1 - a) / 2) * x * (2 * sigmoid(x) - 1)`` with
    ``a = negative_slope``. This tends to ``x`` for large positive inputs and to
    ``a * x`` for large negative inputs.
    """

    def __init__(self, negative_slope=0.2):
        super().__init__()
        self.alpha = negative_slope

    def forward(self, x):
        linear_coef = (1 + self.alpha) / 2
        gated_coef = (1 - self.alpha) / 2
        gate = 2 * torch.sigmoid(x) - 1
        return linear_coef * x + gated_coef * x * gate

    def extra_repr(self):
        return "negative_slope={}".format(self.alpha)
def shifted_softplus(x: torch.Tensor):
    """Softplus shifted down by log(2), so the output at 0 is 0 up to rounding."""
    log_two = math.log(2.0)
    return F.softplus(x) - log_two
class PolynomialCutoff(nn.Module):
    """Smooth polynomial envelope that is exactly zero at and beyond ``cutoff``.

    Args:
        cutoff: radius at which the envelope reaches zero.
        p: polynomial exponent; must be >= 2 (checked when the envelope is evaluated).
    """

    def __init__(self, cutoff, p: int = 6):
        super(PolynomialCutoff, self).__init__()
        self.cutoff = cutoff
        self.p = p

    # staticmethod: called as ``self.polynomial_cutoff(r=..., ...)``. Without the
    # decorator, ``self`` would be bound to ``r`` and the call would raise TypeError.
    @staticmethod
    def polynomial_cutoff(r: Tensor, rcut: float, p: float = 6.0) -> Tensor:
        """
        Polynomial cutoff, as proposed in DimeNet: https://arxiv.org/abs/2003.03123

        Args:
            r: distances.
            rcut: cutoff radius.
            p: polynomial exponent, must be >= 2.

        Returns:
            Envelope values with the shape of ``r``; zero wherever ``r / rcut >= 1``.

        Raises:
            ValueError: if ``p`` is not >= 2 (this also catches NaN).
        """
        if not p >= 2.0:
            raise ValueError(f"Exponent p={p} has to be >= 2.")
        rscaled = r / rcut
        out = 1.0
        out = out - (((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(rscaled, p))
        out = out + (p * (p + 2.0) * torch.pow(rscaled, p + 1.0))
        out = out - ((p * (p + 1.0) / 2) * torch.pow(rscaled, p + 2.0))
        # Hard zero outside the cutoff sphere.
        return out * (rscaled < 1.0).float()

    def forward(self, r):
        """Apply the envelope to distances ``r`` using this module's cutoff and exponent."""
        return self.polynomial_cutoff(r=r, rcut=self.cutoff, p=self.p)

    def __repr__(self):
        return f"{self.__class__.__name__}(cutoff={self.cutoff}, p={self.p})"
class CosineCutoff(nn.Module):
    """Cosine envelope ``0.5 * (cos(pi * d / cutoff) + 1)``, zeroed where ``d >= cutoff``.

    Args:
        cutoff: cutoff radius. A tensor is converted to a Python number via ``.item()``.
    """

    def __init__(self, cutoff):
        super(CosineCutoff, self).__init__()
        self.cutoff = cutoff.item() if isinstance(cutoff, torch.Tensor) else cutoff

    def forward(self, distances):
        inside = (distances < self.cutoff).float()
        envelope = 0.5 * (torch.cos(distances * math.pi / self.cutoff) + 1.0)
        return envelope * inside
class ScaleShift(nn.Module):
    r"""Scale and shift layer for standardization.

    .. math::
        y = x \times \sigma + \mu

    Args:
        mean (torch.Tensor or float): mean value :math:`\mu`.
        stddev (torch.Tensor or float): standard deviation value :math:`\sigma`.
    """

    def __init__(self, mean, stddev):
        super(ScaleShift, self).__init__()

        def _as_tensor(value):
            # Python floats become 1-element float32 tensors so they can be buffers.
            return torch.FloatTensor([value]) if isinstance(value, float) else value

        self.register_buffer("mean", _as_tensor(mean))
        self.register_buffer("stddev", _as_tensor(stddev))

    def forward(self, input):
        """Return ``input * stddev + mean``.

        Args:
            input (torch.Tensor): input data.

        Returns:
            torch.Tensor: layer output.
        """
        scaled = input * self.stddev
        return scaled + self.mean
class GetItem(nn.Module):
    """Extraction layer that returns one entry of a SchNetPack-style input dictionary.

    Args:
        key (str): Property to be extracted from the input dictionary.
    """

    def __init__(self, key):
        super(GetItem, self).__init__()
        self.key = key

    def forward(self, inputs):
        """Return ``inputs[self.key]``.

        Args:
            inputs (dict of torch.Tensor): SchNetPack dictionary of input tensors.

        Returns:
            torch.Tensor: the selected entry.
        """
        selected = inputs[self.key]
        return selected
class SchnetMLP(nn.Module):
    """Multiple layer fully connected perceptron neural network.

    Args:
        n_in (int): number of input nodes.
        n_out (int): number of output nodes.
        n_hidden (list of int or int, optional): number of hidden layer nodes.
            If an integer, the same number of nodes is used for all hidden layers,
            resulting in a rectangular network.
            If None, the number of neurons is halved (floor division) after each
            layer starting from n_in, resulting in a pyramidal network.
        n_layers (int, optional): number of layers.
        activation (callable, optional): activation function. All hidden layers use
            the same activation function; the output layer applies none.
    """

    def __init__(self, n_in, n_out, n_hidden=None, n_layers=2, activation=shifted_softplus):
        super(SchnetMLP, self).__init__()
        if n_hidden is None:
            # Pyramid widths: n_in, n_in // 2, n_in // 4, ... (n_layers entries), then n_out.
            self.n_neurons = [n_in // (2**depth) for depth in range(n_layers)] + [n_out]
        else:
            if type(n_hidden) is int:
                n_hidden = [n_hidden] * (n_layers - 1)
            self.n_neurons = [n_in] + n_hidden + [n_out]

        widths = self.n_neurons
        layers = []
        for idx in range(n_layers - 1):
            # Hidden layers carry the activation.
            layers.append(Dense(widths[idx], widths[idx + 1], activation=activation))
        # Output layer is purely linear.
        layers.append(Dense(widths[-2], widths[-1], activation=None))
        self.out_net = nn.Sequential(*layers)

    def forward(self, inputs):
        """Run the stacked Dense layers.

        Args:
            inputs (torch.Tensor): network input.

        Returns:
            torch.Tensor: network output.
        """
        return self.out_net(inputs)
def scaled_silu(x, scale=0.6):
    """Return ``silu(x)`` multiplied by ``scale`` (0.6 by default)."""
    activated = F.silu(x)
    return activated * scale
def gaussian_rbf(inputs: torch.Tensor, offsets: torch.Tensor, widths: torch.Tensor):
    """Evaluate Gaussians ``exp(-(x - mu)^2 / (2 * sigma^2))`` on a new trailing axis.

    Args:
        inputs: values ``x``; a trailing axis is appended for the basis index.
        offsets: Gaussian centers ``mu`` (broadcast against the new axis).
        widths: Gaussian widths ``sigma`` (broadcast like ``offsets``).
    """
    displacement = inputs.unsqueeze(-1) - offsets
    exponent_scale = -0.5 / widths.pow(2)
    return torch.exp(exponent_scale * displacement.pow(2))
class GaussianRBF(nn.Module):
    r"""Gaussian radial basis functions with evenly spaced centers and a shared width."""

    def __init__(self, n_rbf: int, cutoff: float, start: float = 0.0, trainable: bool = False):
        r"""
        Args:
            n_rbf: total number of Gaussian functions, :math:`N_g`. Must be >= 2,
                because the common width is the spacing between adjacent centers.
            cutoff: center of last Gaussian function, :math:`\mu_{N_g}`
            start: center of first Gaussian function, :math:`\mu_0`.
            trainable: If True, widths and offset of Gaussian functions
                are adjusted during training process.

        Raises:
            ValueError: if ``n_rbf < 2``.
        """
        super(GaussianRBF, self).__init__()
        if n_rbf < 2:
            raise ValueError(f"n_rbf must be >= 2 to define the Gaussian width, got {n_rbf}")
        self.n_rbf = n_rbf
        # Evenly spaced centers; every Gaussian's width is the center spacing.
        offset = torch.linspace(start, cutoff, n_rbf)
        # Plain tensor arithmetic instead of the legacy ``torch.FloatTensor(tensor)``
        # constructor, which rejects non-float32 input (e.g. float64 default dtype).
        widths = torch.abs(offset[1] - offset[0]) * torch.ones_like(offset)
        if trainable:
            self.widths = nn.Parameter(widths)
            self.offsets = nn.Parameter(offset)
        else:
            self.register_buffer("widths", widths)
            self.register_buffer("offsets", offset)

    def forward(self, inputs: torch.Tensor):
        """Expand ``inputs`` into ``n_rbf`` Gaussian features on a new trailing axis."""
        return gaussian_rbf(inputs, self.offsets, self.widths)
class BesselBasis(nn.Module):
    """
    Sine for radial basis expansion with coulomb decay. (0th order Bessel from DimeNet)

    Each input distance ``r`` is expanded into ``sin(f_n * r) / r`` for the
    frequencies ``f_n = n * pi / cutoff``, ``n = 1..n_rbf``.
    """

    def __init__(self, cutoff=5.0, n_rbf=None, trainable=False):
        """
        Args:
            cutoff: radial cutoff
            n_rbf: number of basis functions (required despite the ``None`` default).
            trainable: accepted for interface compatibility but currently unused;
                the frequencies are always stored as a non-trainable buffer.

        Raises:
            ValueError: if ``n_rbf`` is None.
        """
        super(BesselBasis, self).__init__()
        if n_rbf is None:
            raise ValueError("BesselBasis requires n_rbf (number of basis functions)")
        self.n_rbf = n_rbf
        # Frequencies n * pi / cutoff for n = 1..n_rbf.
        freqs = torch.arange(1, n_rbf + 1) * math.pi / cutoff
        self.register_buffer("freqs", freqs)
        # Replacement denominator used where the input distance is exactly zero.
        self.register_buffer("norm1", torch.tensor(1.0))

    def forward(self, inputs):
        """Return ``sin(freqs * r) / r`` with the basis index on a new trailing axis."""
        a = self.freqs[None, :]
        inputs = inputs[..., None]
        ax = inputs * a
        sinax = torch.sin(ax)
        # Avoid division by zero: at r == 0 divide by 1, giving sin(0) / 1 == 0.
        norm = torch.where(inputs == 0, self.norm1, inputs)
        return sinax / norm
def glorot_orthogonal_wrapper_(tensor, scale=2.0):
    """Call torch_geometric's ``glorot_orthogonal`` with a default ``scale`` of 2.0.

    Forwards ``glorot_orthogonal``'s return value unchanged.
    """
    result = glorot_orthogonal(tensor, scale=scale)
    return result
def _standardize(kernel): | |
""" | |
Makes sure that Var(W) = 1 and E[W] = 0 | |
""" | |
eps = 1e-6 | |
if len(kernel.shape) == 3: | |
axis = [0, 1] # last dimension is output dimension | |
else: | |
axis = 1 | |
var, mean = torch.var_mean(kernel, dim=axis, unbiased=True, keepdim=True) | |
kernel = (kernel - mean) / (var + eps) ** 0.5 | |
return kernel | |
def he_orthogonal_init(tensor):
    """
    Generate a weight matrix with variance according to He initialization.
    Based on a random (semi-)orthogonal matrix neural networks
    are expected to learn better when features are decorrelated
    (stated by eg. "Reducing overfitting in deep networks by decorrelating representations",
    "Dropout: a simple way to prevent neural networks from overfitting",
    "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks")

    Steps: orthogonal init, standardize to zero mean / unit variance, then scale
    so the variance becomes ``1 / fan_in``. Here ``fan_in`` is the product of all
    but the last dimension for 3-D tensors, and ``shape[1]`` otherwise.
    """
    tensor = torch.nn.init.orthogonal_(tensor)
    fan_in = tensor.shape[:-1].numel() if len(tensor.shape) == 3 else tensor.shape[1]
    with torch.no_grad():
        tensor.data = _standardize(tensor.data)
        tensor.data *= (1 / fan_in) ** 0.5
    return tensor
def get_weight_init_by_string(init_str):
    """Map an initializer name to an initializer callable.

    Supported names: ``""`` (identity / no-op), ``"zeros"``, ``"xavier_uniform"``,
    ``"glo_orthogonal"`` and ``"he_orthogonal"``.

    Raises:
        ValueError: for any other name.
    """
    if init_str == "":
        # Noop: return the tensor untouched.
        return lambda x: x
    if init_str == "zeros":
        return torch.nn.init.zeros_
    if init_str == "xavier_uniform":
        return torch.nn.init.xavier_uniform_
    if init_str == "glo_orthogonal":
        return glorot_orthogonal_wrapper_
    if init_str == "he_orthogonal":
        return he_orthogonal_init
    raise ValueError(f"Unknown initialization {init_str}")
class Dense(nn.Linear):
    r"""Fully connected linear layer with optional normalization and activation.

    Borrowed from https://github.com/atomistic-machine-learning/schnetpack/blob/master/src/schnetpack/nn/base.py

    .. math::
        y = activation(norm(xW^T + b))

    Args:
        in_features (int): number of input feature :math:`x`.
        out_features (int): number of output features :math:`y`.
        bias (bool, optional): if False, the layer will not adapt bias :math:`b`.
        activation (callable or class, optional): activation applied last. A class
            (e.g. ``nn.SiLU``) is instantiated with no arguments; any other value is
            used as-is; if None, no activation function is used.
        weight_init (callable, optional): weight initializer from current weight.
        bias_init (callable, optional): bias initializer from current bias.
        norm (str, optional): ``"layer"``, ``"batch"`` or ``"instance"`` selects the
            normalization applied between the linear map and the activation; any
            other value disables normalization.
        gain (optional): when truthy, passed as ``gain=`` to ``weight_init``.
    """

    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        activation=None,
        weight_init=xavier_uniform_,
        bias_init=zeros_initializer,
        norm=None,
        gain=None,
    ):
        # These must exist before nn.Linear.__init__, which calls reset_parameters().
        self.weight_init = weight_init
        self.bias_init = bias_init
        self.gain = gain
        super(Dense, self).__init__(in_features, out_features, bias)
        # Instantiate activation classes. Previously the class object unconditionally
        # overwrote the instance, which raises TypeError in nn.Module.__setattr__.
        if inspect.isclass(activation):
            self.activation = activation()
        else:
            self.activation = activation
        if norm == "layer":
            self.norm = nn.LayerNorm(out_features)
        elif norm == "batch":
            self.norm = nn.BatchNorm1d(out_features)
        elif norm == "instance":
            self.norm = nn.InstanceNorm1d(out_features)
        else:
            self.norm = None

    def reset_parameters(self):
        """Reinitialize model weight and bias values."""
        if self.gain:
            self.weight_init(self.weight, gain=self.gain)
        else:
            self.weight_init(self.weight)
        if self.bias is not None:
            self.bias_init(self.bias)

    def forward(self, inputs):
        """Compute layer output.

        Args:
            inputs (torch.Tensor): batch of input values.

        Returns:
            torch.Tensor: layer output.
        """
        # compute linear layer y = xW^T + b
        y = super(Dense, self).forward(inputs)
        if self.norm is not None:
            y = self.norm(y)
        # add activation function (a falsy activation is skipped, as before)
        if self.activation:
            y = self.activation(y)
        return y
class _VDropout(nn.Module): | |
""" | |
Vector channel dropout where the elements of each | |
vector channel are dropped together. | |
""" | |
def __init__(self, drop_rate, scale=True): | |
super(_VDropout, self).__init__() | |
self.drop_rate = drop_rate | |
self.scale = scale | |
def forward(self, x, dim=-1): | |
""" | |
:param x: `torch.Tensor` corresponding to vector channels | |
""" | |
if self.drop_rate == 0: | |
return x | |
device = x.device | |
if not self.training: | |
return x | |
shape = list(x.shape) | |
assert shape[dim] == 3, "The dimension must be vector" | |
shape[dim] = 1 | |
mask = torch.bernoulli((1 - self.drop_rate) * torch.ones(shape, device=device)) | |
x = mask * x | |
if self.scale: | |
# scale the output to keep the expected output distribution | |
# same as input distribution. However, this might be harmfuk | |
# for vector space. | |
x = x / (1 - self.drop_rate) | |
return x | |
class Dropout(nn.Module):
    """
    Combined dropout for tuples (s, V).
    Takes tuples (s, V) as input and as output.

    Args:
        drop_rate: dropout probability for both scalar and vector channels.
        vector_dropout: if False, vector channels are passed through untouched.
    """

    def __init__(self, drop_rate, vector_dropout=True):
        super(Dropout, self).__init__()
        self.sdropout = nn.Dropout(drop_rate)
        self.vdropout = _VDropout(drop_rate) if vector_dropout else (lambda x, dim: x)

    def forward(self, x):
        """
        :param x: tuple (s, V) of `torch.Tensor`,
                  or single `torch.Tensor`
                  (will be assumed to be scalar channels)
        """
        if type(x) is not torch.Tensor:
            s, v = x
            # Vector channels are assumed to sit on axis 1.
            return self.sdropout(s), self.vdropout(v, dim=1)
        return self.sdropout(x)
class TensorInit(nn.Module):
    """Spherical-harmonic features of 3-D vectors for degrees 1..l (degree 0 omitted).

    ``forward`` reads components 0, 1, 2 of the last dimension of ``edge_vec`` as
    x, y, z (no normalization is applied here). It returns a tensor whose last
    dimension holds ``tensor_size() == (l + 1) ** 2 - 1`` components, ordered by
    degree (3 for l=1, then 5 for l=2, ...).

    Args:
        l: maximum degree; supported values are 1 through 8.
    """

    def __init__(self, l=2):  # noqa: E741
        super(TensorInit, self).__init__()
        self.l = l

    def forward(self, edge_vec):
        edge_sh = self._calculate_components(self.l, edge_vec[..., 0], edge_vec[..., 1], edge_vec[..., 2])
        return edge_sh

    def tensor_size(self):
        return ((self.l + 1) ** 2) - 1

    # staticmethod: invoked as ``self._calculate_components(self.l, x, y, z)``. Without
    # the decorator, ``self`` is bound to ``lmax`` and the call raises TypeError.
    @staticmethod
    def _calculate_components(lmax: int, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Stack the harmonic components of degrees 1..lmax along a new last axis.

        Raises:
            ValueError: if ``lmax`` is not in 1..8 (previously this silently returned None).
        """
        if lmax not in range(1, 9):
            raise ValueError(f"lmax must be an integer between 1 and 8, got {lmax}")

        # Components are appended degree by degree; each early return stacks the prefix.
        sh_1_0, sh_1_1, sh_1_2 = x, y, z
        components = [sh_1_0, sh_1_1, sh_1_2]
        if lmax == 1:
            return torch.stack(components, dim=-1)

        # Degree 2: quadratic monomials.
        sh_2_0 = math.sqrt(3.0) * x * z
        sh_2_1 = math.sqrt(3.0) * x * y
        y2 = y.pow(2)
        x2z2 = x.pow(2) + z.pow(2)
        sh_2_2 = y2 - 0.5 * x2z2
        sh_2_3 = math.sqrt(3.0) * y * z
        sh_2_4 = math.sqrt(3.0) / 2.0 * (z.pow(2) - x.pow(2))
        components += [sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4]
        if lmax == 2:
            return torch.stack(components, dim=-1)

        # Borrowed from e3nn: https://github.com/e3nn/e3nn/blob/main/e3nn/o3/_spherical_harmonics.py#L188
        # NOTE(review): for unit inputs the degree-2 block above has unit norm, but
        # this degree-3 block's squared norm is 7 at (0, 1, 0) and 3.5 at (1, 0, 0).
        # The recursion coefficients may assume a different normalization of the
        # degree-2 terms; verify equivariance before relying on lmax >= 3.
        sh_3_0 = (1 / 6) * math.sqrt(42) * (sh_2_0 * z + sh_2_4 * x)
        sh_3_1 = math.sqrt(7) * sh_2_0 * y
        sh_3_2 = (1 / 8) * math.sqrt(168) * (4.0 * y2 - x2z2) * x
        sh_3_3 = (1 / 2) * math.sqrt(7) * y * (2.0 * y2 - 3.0 * x2z2)
        sh_3_4 = (1 / 8) * math.sqrt(168) * z * (4.0 * y2 - x2z2)
        sh_3_5 = math.sqrt(7) * sh_2_4 * y
        sh_3_6 = (1 / 6) * math.sqrt(42) * (sh_2_4 * z - sh_2_0 * x)
        components += [sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6]
        if lmax == 3:
            return torch.stack(components, dim=-1)

        sh_4_0 = (3 / 4) * math.sqrt(2) * (sh_3_0 * z + sh_3_6 * x)
        sh_4_1 = (3 / 4) * sh_3_0 * y + (3 / 8) * math.sqrt(6) * sh_3_1 * z + (3 / 8) * math.sqrt(6) * sh_3_5 * x
        sh_4_2 = (
            -3 / 56 * math.sqrt(14) * sh_3_0 * z
            + (3 / 14) * math.sqrt(21) * sh_3_1 * y
            + (3 / 56) * math.sqrt(210) * sh_3_2 * z
            + (3 / 56) * math.sqrt(210) * sh_3_4 * x
            + (3 / 56) * math.sqrt(14) * sh_3_6 * x
        )
        sh_4_3 = (
            -3 / 56 * math.sqrt(42) * sh_3_1 * z
            + (3 / 28) * math.sqrt(105) * sh_3_2 * y
            + (3 / 28) * math.sqrt(70) * sh_3_3 * x
            + (3 / 56) * math.sqrt(42) * sh_3_5 * x
        )
        sh_4_4 = -3 / 28 * math.sqrt(42) * sh_3_2 * x + (3 / 7) * math.sqrt(7) * sh_3_3 * y - 3 / 28 * math.sqrt(42) * sh_3_4 * z
        sh_4_5 = (
            -3 / 56 * math.sqrt(42) * sh_3_1 * x
            + (3 / 28) * math.sqrt(70) * sh_3_3 * z
            + (3 / 28) * math.sqrt(105) * sh_3_4 * y
            - 3 / 56 * math.sqrt(42) * sh_3_5 * z
        )
        sh_4_6 = (
            -3 / 56 * math.sqrt(14) * sh_3_0 * x
            - 3 / 56 * math.sqrt(210) * sh_3_2 * x
            + (3 / 56) * math.sqrt(210) * sh_3_4 * z
            + (3 / 14) * math.sqrt(21) * sh_3_5 * y
            - 3 / 56 * math.sqrt(14) * sh_3_6 * z
        )
        sh_4_7 = -3 / 8 * math.sqrt(6) * sh_3_1 * x + (3 / 8) * math.sqrt(6) * sh_3_5 * z + (3 / 4) * sh_3_6 * y
        sh_4_8 = (3 / 4) * math.sqrt(2) * (-sh_3_0 * x + sh_3_6 * z)
        components += [sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8]
        if lmax == 4:
            return torch.stack(components, dim=-1)

        sh_5_0 = (1 / 10) * math.sqrt(110) * (sh_4_0 * z + sh_4_8 * x)
        sh_5_1 = (1 / 5) * math.sqrt(11) * sh_4_0 * y + (1 / 5) * math.sqrt(22) * sh_4_1 * z + (1 / 5) * math.sqrt(22) * sh_4_7 * x
        sh_5_2 = (
            -1 / 30 * math.sqrt(22) * sh_4_0 * z
            + (4 / 15) * math.sqrt(11) * sh_4_1 * y
            + (1 / 15) * math.sqrt(154) * sh_4_2 * z
            + (1 / 15) * math.sqrt(154) * sh_4_6 * x
            + (1 / 30) * math.sqrt(22) * sh_4_8 * x
        )
        sh_5_3 = (
            -1 / 30 * math.sqrt(66) * sh_4_1 * z
            + (1 / 15) * math.sqrt(231) * sh_4_2 * y
            + (1 / 30) * math.sqrt(462) * sh_4_3 * z
            + (1 / 30) * math.sqrt(462) * sh_4_5 * x
            + (1 / 30) * math.sqrt(66) * sh_4_7 * x
        )
        sh_5_4 = (
            -1 / 15 * math.sqrt(33) * sh_4_2 * z
            + (2 / 15) * math.sqrt(66) * sh_4_3 * y
            + (1 / 15) * math.sqrt(165) * sh_4_4 * x
            + (1 / 15) * math.sqrt(33) * sh_4_6 * x
        )
        sh_5_5 = -1 / 15 * math.sqrt(110) * sh_4_3 * x + (1 / 3) * math.sqrt(11) * sh_4_4 * y - 1 / 15 * math.sqrt(110) * sh_4_5 * z
        sh_5_6 = (
            -1 / 15 * math.sqrt(33) * sh_4_2 * x
            + (1 / 15) * math.sqrt(165) * sh_4_4 * z
            + (2 / 15) * math.sqrt(66) * sh_4_5 * y
            - 1 / 15 * math.sqrt(33) * sh_4_6 * z
        )
        sh_5_7 = (
            -1 / 30 * math.sqrt(66) * sh_4_1 * x
            - 1 / 30 * math.sqrt(462) * sh_4_3 * x
            + (1 / 30) * math.sqrt(462) * sh_4_5 * z
            + (1 / 15) * math.sqrt(231) * sh_4_6 * y
            - 1 / 30 * math.sqrt(66) * sh_4_7 * z
        )
        sh_5_8 = (
            -1 / 30 * math.sqrt(22) * sh_4_0 * x
            - 1 / 15 * math.sqrt(154) * sh_4_2 * x
            + (1 / 15) * math.sqrt(154) * sh_4_6 * z
            + (4 / 15) * math.sqrt(11) * sh_4_7 * y
            - 1 / 30 * math.sqrt(22) * sh_4_8 * z
        )
        sh_5_9 = -1 / 5 * math.sqrt(22) * sh_4_1 * x + (1 / 5) * math.sqrt(22) * sh_4_7 * z + (1 / 5) * math.sqrt(11) * sh_4_8 * y
        sh_5_10 = (1 / 10) * math.sqrt(110) * (-sh_4_0 * x + sh_4_8 * z)
        components += [sh_5_0, sh_5_1, sh_5_2, sh_5_3, sh_5_4, sh_5_5, sh_5_6, sh_5_7, sh_5_8, sh_5_9, sh_5_10]
        if lmax == 5:
            return torch.stack(components, dim=-1)

        sh_6_0 = (1 / 6) * math.sqrt(39) * (sh_5_0 * z + sh_5_10 * x)
        sh_6_1 = (1 / 6) * math.sqrt(13) * sh_5_0 * y + (1 / 12) * math.sqrt(130) * sh_5_1 * z + (1 / 12) * math.sqrt(130) * sh_5_9 * x
        sh_6_2 = (
            -1 / 132 * math.sqrt(286) * sh_5_0 * z
            + (1 / 33) * math.sqrt(715) * sh_5_1 * y
            + (1 / 132) * math.sqrt(286) * sh_5_10 * x
            + (1 / 44) * math.sqrt(1430) * sh_5_2 * z
            + (1 / 44) * math.sqrt(1430) * sh_5_8 * x
        )
        sh_6_3 = (
            -1 / 132 * math.sqrt(858) * sh_5_1 * z
            + (1 / 22) * math.sqrt(429) * sh_5_2 * y
            + (1 / 22) * math.sqrt(286) * sh_5_3 * z
            + (1 / 22) * math.sqrt(286) * sh_5_7 * x
            + (1 / 132) * math.sqrt(858) * sh_5_9 * x
        )
        sh_6_4 = (
            -1 / 66 * math.sqrt(429) * sh_5_2 * z
            + (2 / 33) * math.sqrt(286) * sh_5_3 * y
            + (1 / 66) * math.sqrt(2002) * sh_5_4 * z
            + (1 / 66) * math.sqrt(2002) * sh_5_6 * x
            + (1 / 66) * math.sqrt(429) * sh_5_8 * x
        )
        sh_6_5 = (
            -1 / 66 * math.sqrt(715) * sh_5_3 * z
            + (1 / 66) * math.sqrt(5005) * sh_5_4 * y
            + (1 / 66) * math.sqrt(3003) * sh_5_5 * x
            + (1 / 66) * math.sqrt(715) * sh_5_7 * x
        )
        sh_6_6 = -1 / 66 * math.sqrt(2145) * sh_5_4 * x + (1 / 11) * math.sqrt(143) * sh_5_5 * y - 1 / 66 * math.sqrt(2145) * sh_5_6 * z
        sh_6_7 = (
            -1 / 66 * math.sqrt(715) * sh_5_3 * x
            + (1 / 66) * math.sqrt(3003) * sh_5_5 * z
            + (1 / 66) * math.sqrt(5005) * sh_5_6 * y
            - 1 / 66 * math.sqrt(715) * sh_5_7 * z
        )
        sh_6_8 = (
            -1 / 66 * math.sqrt(429) * sh_5_2 * x
            - 1 / 66 * math.sqrt(2002) * sh_5_4 * x
            + (1 / 66) * math.sqrt(2002) * sh_5_6 * z
            + (2 / 33) * math.sqrt(286) * sh_5_7 * y
            - 1 / 66 * math.sqrt(429) * sh_5_8 * z
        )
        sh_6_9 = (
            -1 / 132 * math.sqrt(858) * sh_5_1 * x
            - 1 / 22 * math.sqrt(286) * sh_5_3 * x
            + (1 / 22) * math.sqrt(286) * sh_5_7 * z
            + (1 / 22) * math.sqrt(429) * sh_5_8 * y
            - 1 / 132 * math.sqrt(858) * sh_5_9 * z
        )
        sh_6_10 = (
            -1 / 132 * math.sqrt(286) * sh_5_0 * x
            - 1 / 132 * math.sqrt(286) * sh_5_10 * z
            - 1 / 44 * math.sqrt(1430) * sh_5_2 * x
            + (1 / 44) * math.sqrt(1430) * sh_5_8 * z
            + (1 / 33) * math.sqrt(715) * sh_5_9 * y
        )
        sh_6_11 = -1 / 12 * math.sqrt(130) * sh_5_1 * x + (1 / 6) * math.sqrt(13) * sh_5_10 * y + (1 / 12) * math.sqrt(130) * sh_5_9 * z
        sh_6_12 = (1 / 6) * math.sqrt(39) * (-sh_5_0 * x + sh_5_10 * z)
        components += [
            sh_6_0, sh_6_1, sh_6_2, sh_6_3, sh_6_4, sh_6_5, sh_6_6,
            sh_6_7, sh_6_8, sh_6_9, sh_6_10, sh_6_11, sh_6_12,
        ]
        if lmax == 6:
            return torch.stack(components, dim=-1)

        sh_7_0 = (1 / 14) * math.sqrt(210) * (sh_6_0 * z + sh_6_12 * x)
        sh_7_1 = (1 / 7) * math.sqrt(15) * sh_6_0 * y + (3 / 7) * math.sqrt(5) * sh_6_1 * z + (3 / 7) * math.sqrt(5) * sh_6_11 * x
        sh_7_2 = (
            -1 / 182 * math.sqrt(390) * sh_6_0 * z
            + (6 / 91) * math.sqrt(130) * sh_6_1 * y
            + (3 / 91) * math.sqrt(715) * sh_6_10 * x
            + (1 / 182) * math.sqrt(390) * sh_6_12 * x
            + (3 / 91) * math.sqrt(715) * sh_6_2 * z
        )
        sh_7_3 = (
            -3 / 182 * math.sqrt(130) * sh_6_1 * z
            + (3 / 182) * math.sqrt(130) * sh_6_11 * x
            + (3 / 91) * math.sqrt(715) * sh_6_2 * y
            + (5 / 182) * math.sqrt(858) * sh_6_3 * z
            + (5 / 182) * math.sqrt(858) * sh_6_9 * x
        )
        sh_7_4 = (
            (3 / 91) * math.sqrt(65) * sh_6_10 * x
            - 3 / 91 * math.sqrt(65) * sh_6_2 * z
            + (10 / 91) * math.sqrt(78) * sh_6_3 * y
            + (15 / 182) * math.sqrt(78) * sh_6_4 * z
            + (15 / 182) * math.sqrt(78) * sh_6_8 * x
        )
        sh_7_5 = (
            -5 / 91 * math.sqrt(39) * sh_6_3 * z
            + (15 / 91) * math.sqrt(39) * sh_6_4 * y
            + (3 / 91) * math.sqrt(390) * sh_6_5 * z
            + (3 / 91) * math.sqrt(390) * sh_6_7 * x
            + (5 / 91) * math.sqrt(39) * sh_6_9 * x
        )
        sh_7_6 = (
            -15 / 182 * math.sqrt(26) * sh_6_4 * z
            + (12 / 91) * math.sqrt(65) * sh_6_5 * y
            + (2 / 91) * math.sqrt(1365) * sh_6_6 * x
            + (15 / 182) * math.sqrt(26) * sh_6_8 * x
        )
        sh_7_7 = -3 / 91 * math.sqrt(455) * sh_6_5 * x + (1 / 13) * math.sqrt(195) * sh_6_6 * y - 3 / 91 * math.sqrt(455) * sh_6_7 * z
        sh_7_8 = (
            -15 / 182 * math.sqrt(26) * sh_6_4 * x
            + (2 / 91) * math.sqrt(1365) * sh_6_6 * z
            + (12 / 91) * math.sqrt(65) * sh_6_7 * y
            - 15 / 182 * math.sqrt(26) * sh_6_8 * z
        )
        sh_7_9 = (
            -5 / 91 * math.sqrt(39) * sh_6_3 * x
            - 3 / 91 * math.sqrt(390) * sh_6_5 * x
            + (3 / 91) * math.sqrt(390) * sh_6_7 * z
            + (15 / 91) * math.sqrt(39) * sh_6_8 * y
            - 5 / 91 * math.sqrt(39) * sh_6_9 * z
        )
        sh_7_10 = (
            -3 / 91 * math.sqrt(65) * sh_6_10 * z
            - 3 / 91 * math.sqrt(65) * sh_6_2 * x
            - 15 / 182 * math.sqrt(78) * sh_6_4 * x
            + (15 / 182) * math.sqrt(78) * sh_6_8 * z
            + (10 / 91) * math.sqrt(78) * sh_6_9 * y
        )
        sh_7_11 = (
            -3 / 182 * math.sqrt(130) * sh_6_1 * x
            + (3 / 91) * math.sqrt(715) * sh_6_10 * y
            - 3 / 182 * math.sqrt(130) * sh_6_11 * z
            - 5 / 182 * math.sqrt(858) * sh_6_3 * x
            + (5 / 182) * math.sqrt(858) * sh_6_9 * z
        )
        sh_7_12 = (
            -1 / 182 * math.sqrt(390) * sh_6_0 * x
            + (3 / 91) * math.sqrt(715) * sh_6_10 * z
            + (6 / 91) * math.sqrt(130) * sh_6_11 * y
            - 1 / 182 * math.sqrt(390) * sh_6_12 * z
            - 3 / 91 * math.sqrt(715) * sh_6_2 * x
        )
        sh_7_13 = -3 / 7 * math.sqrt(5) * sh_6_1 * x + (3 / 7) * math.sqrt(5) * sh_6_11 * z + (1 / 7) * math.sqrt(15) * sh_6_12 * y
        sh_7_14 = (1 / 14) * math.sqrt(210) * (-sh_6_0 * x + sh_6_12 * z)
        components += [
            sh_7_0, sh_7_1, sh_7_2, sh_7_3, sh_7_4, sh_7_5, sh_7_6, sh_7_7,
            sh_7_8, sh_7_9, sh_7_10, sh_7_11, sh_7_12, sh_7_13, sh_7_14,
        ]
        if lmax == 7:
            return torch.stack(components, dim=-1)

        sh_8_0 = (1 / 4) * math.sqrt(17) * (sh_7_0 * z + sh_7_14 * x)
        sh_8_1 = (1 / 8) * math.sqrt(17) * sh_7_0 * y + (1 / 16) * math.sqrt(238) * sh_7_1 * z + (1 / 16) * math.sqrt(238) * sh_7_13 * x
        sh_8_2 = (
            -1 / 240 * math.sqrt(510) * sh_7_0 * z
            + (1 / 60) * math.sqrt(1785) * sh_7_1 * y
            + (1 / 240) * math.sqrt(46410) * sh_7_12 * x
            + (1 / 240) * math.sqrt(510) * sh_7_14 * x
            + (1 / 240) * math.sqrt(46410) * sh_7_2 * z
        )
        sh_8_3 = (
            (1 / 80)
            * math.sqrt(2)
            * (
                -math.sqrt(85) * sh_7_1 * z
                + math.sqrt(2210) * sh_7_11 * x
                + math.sqrt(85) * sh_7_13 * x
                + math.sqrt(2210) * sh_7_2 * y
                + math.sqrt(2210) * sh_7_3 * z
            )
        )
        sh_8_4 = (
            (1 / 40) * math.sqrt(935) * sh_7_10 * x
            + (1 / 40) * math.sqrt(85) * sh_7_12 * x
            - 1 / 40 * math.sqrt(85) * sh_7_2 * z
            + (1 / 10) * math.sqrt(85) * sh_7_3 * y
            + (1 / 40) * math.sqrt(935) * sh_7_4 * z
        )
        sh_8_5 = (
            (1 / 48)
            * math.sqrt(2)
            * (
                math.sqrt(102) * sh_7_11 * x
                - math.sqrt(102) * sh_7_3 * z
                + math.sqrt(1122) * sh_7_4 * y
                + math.sqrt(561) * sh_7_5 * z
                + math.sqrt(561) * sh_7_9 * x
            )
        )
        sh_8_6 = (
            (1 / 16) * math.sqrt(34) * sh_7_10 * x
            - 1 / 16 * math.sqrt(34) * sh_7_4 * z
            + (1 / 4) * math.sqrt(17) * sh_7_5 * y
            + (1 / 16) * math.sqrt(102) * sh_7_6 * z
            + (1 / 16) * math.sqrt(102) * sh_7_8 * x
        )
        sh_8_7 = (
            -1 / 80 * math.sqrt(1190) * sh_7_5 * z
            + (1 / 40) * math.sqrt(1785) * sh_7_6 * y
            + (1 / 20) * math.sqrt(255) * sh_7_7 * x
            + (1 / 80) * math.sqrt(1190) * sh_7_9 * x
        )
        sh_8_8 = -1 / 60 * math.sqrt(1785) * sh_7_6 * x + (1 / 15) * math.sqrt(255) * sh_7_7 * y - 1 / 60 * math.sqrt(1785) * sh_7_8 * z
        sh_8_9 = (
            -1 / 80 * math.sqrt(1190) * sh_7_5 * x
            + (1 / 20) * math.sqrt(255) * sh_7_7 * z
            + (1 / 40) * math.sqrt(1785) * sh_7_8 * y
            - 1 / 80 * math.sqrt(1190) * sh_7_9 * z
        )
        sh_8_10 = (
            -1 / 16 * math.sqrt(34) * sh_7_10 * z
            - 1 / 16 * math.sqrt(34) * sh_7_4 * x
            - 1 / 16 * math.sqrt(102) * sh_7_6 * x
            + (1 / 16) * math.sqrt(102) * sh_7_8 * z
            + (1 / 4) * math.sqrt(17) * sh_7_9 * y
        )
        sh_8_11 = (
            (1 / 48)
            * math.sqrt(2)
            * (
                math.sqrt(1122) * sh_7_10 * y
                - math.sqrt(102) * sh_7_11 * z
                - math.sqrt(102) * sh_7_3 * x
                - math.sqrt(561) * sh_7_5 * x
                + math.sqrt(561) * sh_7_9 * z
            )
        )
        sh_8_12 = (
            (1 / 40) * math.sqrt(935) * sh_7_10 * z
            + (1 / 10) * math.sqrt(85) * sh_7_11 * y
            - 1 / 40 * math.sqrt(85) * sh_7_12 * z
            - 1 / 40 * math.sqrt(85) * sh_7_2 * x
            - 1 / 40 * math.sqrt(935) * sh_7_4 * x
        )
        sh_8_13 = (
            (1 / 80)
            * math.sqrt(2)
            * (
                -math.sqrt(85) * sh_7_1 * x
                + math.sqrt(2210) * sh_7_11 * z
                + math.sqrt(2210) * sh_7_12 * y
                - math.sqrt(85) * sh_7_13 * z
                - math.sqrt(2210) * sh_7_3 * x
            )
        )
        sh_8_14 = (
            -1 / 240 * math.sqrt(510) * sh_7_0 * x
            + (1 / 240) * math.sqrt(46410) * sh_7_12 * z
            + (1 / 60) * math.sqrt(1785) * sh_7_13 * y
            - 1 / 240 * math.sqrt(510) * sh_7_14 * z
            - 1 / 240 * math.sqrt(46410) * sh_7_2 * x
        )
        sh_8_15 = -1 / 16 * math.sqrt(238) * sh_7_1 * x + (1 / 16) * math.sqrt(238) * sh_7_13 * z + (1 / 8) * math.sqrt(17) * sh_7_14 * y
        sh_8_16 = (1 / 4) * math.sqrt(17) * (-sh_7_0 * x + sh_7_14 * z)
        components += [
            sh_8_0, sh_8_1, sh_8_2, sh_8_3, sh_8_4, sh_8_5, sh_8_6, sh_8_7, sh_8_8,
            sh_8_9, sh_8_10, sh_8_11, sh_8_12, sh_8_13, sh_8_14, sh_8_15, sh_8_16,
        ]
        # lmax == 8 here (range validated above).
        return torch.stack(components, dim=-1)
def lmax_tensor_size(lmax):
    """Return the number of spherical-harmonic components for degrees 1..lmax.

    Degree l contributes 2l + 1 components and the l = 0 scalar is excluded,
    so the total is (lmax + 1) ** 2 - 1.
    """
    count_with_scalar = (lmax + 1) ** 2
    return count_with_scalar - 1
def get_split_sizes_from_dim(feature_dim):
    """Return ``torch.split`` sizes for a packed spherical-harmonic feature axis.

    The axis is expected to hold the components of degrees l = 1..lmax back to
    back (2l + 1 entries each, no l = 0 scalar). Its length must therefore be
    (lmax + 1) ** 2 - 1 for some lmax >= 1.

    Args:
        feature_dim: length of that axis (``tensor.shape[1]`` at the call site).

    Returns:
        The list of block sizes ``[3, 5, ..., 2 * lmax + 1]``.

    Raises:
        ValueError: if ``feature_dim`` is not (lmax + 1) ** 2 - 1 for any lmax >= 1.
    """
    degree = 1
    sizes = [2 * degree + 1]
    covered = sizes[0]
    # Add one degree at a time until the running total reaches feature_dim.
    while covered < feature_dim:
        degree += 1
        sizes.append(2 * degree + 1)
        covered += sizes[-1]
    if covered != feature_dim:
        raise ValueError(f"Feature dimension {feature_dim} does not correspond to a valid lmax value")
    return sizes
class TensorLayerNorm(nn.Module):
    """Max-min normalization of packed spherical-tensor features.

    ``forward`` splits its input along dim 1 into blocks of size 2l + 1
    (l = 1..lmax, via ``get_split_sizes_from_dim``). It normalizes each block
    independently with ``max_min_norm``, concatenates the blocks back together
    and scales the result by a per-channel ``weight``.

    Args:
        hidden_channels: size of the last dimension; also the length of ``weight``.
        trainable: if True ``weight`` is an ``nn.Parameter``, otherwise a buffer.
    """

    def __init__(self, hidden_channels, trainable):
        super(TensorLayerNorm, self).__init__()
        self.hidden_channels = hidden_channels
        # Lower bound on norms before dividing by them (avoids division by zero).
        self.eps = 1e-12
        weight = torch.ones(self.hidden_channels)
        if trainable:
            self.register_parameter("weight", nn.Parameter(weight))
        else:
            self.register_buffer("weight", weight)
        self.reset_parameters()

    def reset_parameters(self):
        """Reset ``weight`` to all ones in place."""
        weight = torch.ones(self.hidden_channels)
        self.weight.data.copy_(weight)

    def max_min_norm(self, tensor):
        """Rescale the per-channel norms of one degree block to [0, 1].

        The norm is taken over dim 1. Each vector keeps its direction, and its
        length becomes (norm - min) / (max - min), where min and max are taken
        over the last (channel) dimension of each row. Written for 3-D input,
        e.g. (num_nodes, 2l + 1, hidden_channels); the ``view(-1, 1, 1)``
        reshapes below assume that rank.

        Returns an all-zero tensor when every norm is exactly zero.
        """
        # Based on VisNet (https://www.nature.com/articles/s41467-023-43720-2)
        dist = torch.norm(tensor, dim=1, keepdim=True)
        if (dist == 0).all():
            return torch.zeros_like(tensor)
        dist = dist.clamp(min=self.eps)
        direct = tensor / dist
        max_val, _ = torch.max(dist, dim=-1)
        min_val, _ = torch.min(dist, dim=-1)
        delta = (max_val - min_val).view(-1)
        # Rows whose norms are all equal would divide by zero; use 1 instead.
        delta = torch.where(delta == 0, torch.ones_like(delta), delta)
        dist = (dist - min_val.view(-1, 1, 1)) / delta.view(-1, 1, 1)
        return F.relu(dist) * direct

    def forward(self, tensor):
        """Normalize each degree block of ``tensor`` and apply ``weight``.

        Raises:
            ValueError: if ``tensor.shape[1]`` is not (lmax + 1) ** 2 - 1 for any lmax >= 1.
        """
        # vec: (num_atoms, feature_dim, hidden_channels)
        feature_dim = tensor.shape[1]
        try:
            split_sizes = get_split_sizes_from_dim(feature_dim)
        except ValueError as e:
            raise ValueError(f"TensorLayerNorm received unsupported feature dimension {feature_dim}: {e}") from e
        vec_parts = torch.split(tensor, split_sizes, dim=1)
        normalized_parts = [self.max_min_norm(part) for part in vec_parts]
        normalized_vec = torch.cat(normalized_parts, dim=1)
        # weight has shape (hidden_channels,); broadcast it over the first two dims.
        return normalized_vec * self.weight.unsqueeze(0).unsqueeze(0)
def normalize_string(s: str) -> str:
    """Lower-case ``s`` and drop every '-', '_' and ' ' character."""
    return "".join(ch for ch in s.lower() if ch not in "-_ ")
class Swish(nn.Module):
    """Swish activation, ``x * sigmoid(x)``; holds no parameters."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
# Short activation name -> activation class. Here "swish" maps to the Swish
# module defined in this file (get_activations maps "swish" to torch.nn.SiLU).
act_class_mapping = dict(
    ssp=ShiftedSoftplus,
    silu=nn.SiLU,
    tanh=nn.Tanh,
    sigmoid=nn.Sigmoid,
    swish=Swish,
)
# https://github.com/sunglasses-ai/classy/blob/3e74cba1fdf1b9f9f2ba1cfcfa6c2017aa59fc04/classy/optim/factories.py#L14 | |
def get_activations(optional=False, *args, **kwargs):
    """Build a name -> activation mapping.

    Every ``torch.nn.Module`` subclass found in ``torch.nn.modules.activation``
    is keyed by its ``normalize_string``-ed class name. A fixed alias table is
    then applied on top; it overrides existing keys in place and appends new
    ones. ``scaled_silu``, ``shifted_softplus`` and ``SmoothLeakyReLU`` are not
    defined in this excerpt (presumably elsewhere in the file).

    Args:
        optional: if True, also map "" to None.
        *args, **kwargs: accepted and ignored.

    Returns:
        A dict of activation classes or callables. Insertion order matters,
        because ``dictionary_to_option`` lists the keys in this order in its
        error message.
    """
    activations = {}
    for candidate in vars(torch.nn.modules.activation).values():
        if isinstance(candidate, type) and issubclass(candidate, torch.nn.Module):
            activations[normalize_string(candidate.__name__)] = candidate

    aliases = {
        "relu": torch.nn.ReLU,
        "elu": torch.nn.ELU,
        "sigmoid": torch.nn.Sigmoid,
        "silu": torch.nn.SiLU,
        "mish": torch.nn.Mish,
        "swish": torch.nn.SiLU,
        "selu": torch.nn.SELU,
        "scaled_swish": scaled_silu,
        "softplus": shifted_softplus,
        "slrelu": SmoothLeakyReLU,
    }
    for alias_name, alias_value in aliases.items():
        activations[alias_name] = alias_value

    if optional:
        activations[""] = None
    return activations
def get_activations_none(optional=False, *args, **kwargs):
    """Build a name -> activation-class mapping with a small alias table.

    Every ``torch.nn.Module`` subclass in ``torch.nn.modules.activation`` is
    keyed by its ``normalize_string``-ed class name. The aliases relu, elu,
    sigmoid, silu and selu are then set (overriding in place).

    Args:
        optional: if True, map both "" and None to None.
        *args, **kwargs: accepted and ignored.

    Returns:
        A dict of activation classes (and None values when ``optional``).
    """
    activations = dict(
        (normalize_string(cls.__name__), cls)
        for cls in vars(torch.nn.modules.activation).values()
        if isinstance(cls, type) and issubclass(cls, torch.nn.Module)
    )
    activations.update(
        relu=torch.nn.ReLU,
        elu=torch.nn.ELU,
        sigmoid=torch.nn.Sigmoid,
        silu=torch.nn.SiLU,
        selu=torch.nn.SELU,
    )
    if optional:
        activations.update({"": None, None: None})
    return activations
def dictionary_to_option(options, selected):
    """Look up ``selected`` in ``options``, instantiating it if it is a class.

    Args:
        options: mapping from choice key to a class, callable or None.
        selected: the key to look up.

    Returns:
        ``options[selected]``, called with no arguments if it is a class
        (``inspect.isclass``). Other values (functions, None) are returned as-is.

    Raises:
        ValueError: if ``selected`` is not a key of ``options``. The message
            lists the available keys. Non-string keys, such as the None key that
            ``get_activations_none(optional=True)`` adds, are shown via ``str()``.
    """
    if selected not in options:
        # str() each key: a plain ", ".join would raise TypeError on a None key
        # and hide the intended ValueError.
        choices = ", ".join(str(key) for key in options)
        raise ValueError(f'Invalid choice "{selected}", choose one from {choices} ')
    activation = options[selected]
    if inspect.isclass(activation):
        activation = activation()
    return activation
def str2act(input_str, *args, **kwargs):
    """Resolve an activation name to an activation instance or callable.

    Args:
        input_str: a key of ``get_activations(optional=True)``; "" returns None.
        *args, **kwargs: forwarded to ``get_activations``, which ignores them.

    Returns:
        None for "". Otherwise the matching entry, instantiated if it is a class.

    Raises:
        ValueError: if ``input_str`` is not a known activation name.
    """
    if input_str == "":
        return None
    # Pass ``optional`` positionally. With ``optional=True, *args`` the first
    # element of ``args`` would also bind to ``optional`` and raise TypeError
    # ("got multiple values for argument 'optional'").
    act = get_activations(True, *args, **kwargs)
    return dictionary_to_option(act, input_str)
class ExpNormalSmearing(nn.Module):
    """Exponential-normal radial basis expansion of distances.

    A distance d is expanded into ``n_rbf`` features::

        cutoff_fn(d) * exp(-beta_k * (exp(-alpha * d) - mu_k) ** 2)

    where ``alpha = 5 / cutoff``, the centres ``mu_k`` are spaced linearly
    from exp(-cutoff) to 1, and every ``beta_k`` shares one value.
    ``cutoff_fn`` is ``CosineCutoff(cutoff)``, which is defined outside this
    excerpt.

    NOTE(review): at d == cutoff, exp(-alpha * d) == exp(-5), while the lowest
    centre is exp(-cutoff). The two coincide only when cutoff == 5. Confirm
    this is intended before changing either, since trained weights depend on it.

    Args:
        cutoff: cutoff radius; a tensor is converted with ``.item()``.
        n_rbf: number of basis functions.
        trainable: if True, ``means``/``betas`` are Parameters, otherwise buffers.
    """

    def __init__(self, cutoff=5.0, n_rbf=50, trainable=False):
        super().__init__()
        cutoff = cutoff.item() if isinstance(cutoff, torch.Tensor) else cutoff
        self.cutoff = cutoff
        self.n_rbf = n_rbf
        self.trainable = trainable
        self.cutoff_fn = CosineCutoff(cutoff)
        self.alpha = 5.0 / cutoff

        means, betas = self._initial_params()
        if not trainable:
            self.register_buffer("means", means)
            self.register_buffer("betas", betas)
        else:
            self.register_parameter("means", nn.Parameter(means))
            self.register_parameter("betas", nn.Parameter(betas))

    def _initial_params(self):
        """Return (means, betas), each of length ``n_rbf``, from ``cutoff``."""
        lowest_mean = torch.exp(torch.scalar_tensor(-self.cutoff))
        means = torch.linspace(lowest_mean, 1, self.n_rbf)
        shared_beta = (2 / self.n_rbf * (1 - lowest_mean)) ** -2
        betas = torch.tensor([shared_beta] * self.n_rbf)
        return means, betas

    def reset_parameters(self):
        """Overwrite ``means`` and ``betas`` in place with their initial values."""
        fresh_means, fresh_betas = self._initial_params()
        self.means.data.copy_(fresh_means)
        self.betas.data.copy_(fresh_betas)

    def forward(self, dist):
        """Expand ``dist`` into a trailing axis of ``n_rbf`` basis values.

        The output shape is presumably ``dist.shape + (n_rbf,)``, assuming
        ``cutoff_fn`` preserves its input shape.
        """
        d = dist.unsqueeze(-1)
        envelope = self.cutoff_fn(d)
        offset = torch.exp(self.alpha * (-d)) - self.means
        return envelope * torch.exp(-self.betas * offset ** 2)
def str2basis(input_str):
    """Map a radial-basis name to its class; non-``str`` inputs pass through.

    "BesselBasis" and "GaussianRBF" must match exactly. "expnorm" is matched
    case-insensitively. Only exact ``str`` instances are looked up, so
    ``str`` subclasses are returned unchanged.

    Raises:
        ValueError: for an unrecognized name.
    """
    if type(input_str) is not str:
        return input_str
    if input_str == "BesselBasis":
        return BesselBasis
    if input_str == "GaussianRBF":
        return GaussianRBF
    if input_str.lower() == "expnorm":
        return ExpNormalSmearing
    raise ValueError("Unknown radial basis: {}".format(input_str))
class MLP(nn.Module):
    """Multi-layer perceptron built from ``Dense`` layers.

    ``hidden_dims`` lists every width, including input and output; for example
    ``[in, hidden, out]`` builds two Dense layers. Every layer except the last
    receives ``activation`` and ``norm``. The last layer receives
    ``last_activation`` and no ``norm`` argument.

    Args:
        hidden_dims: layer widths; needs at least two entries (input, output).
        bias: whether the Dense layers use a bias.
        activation: activation for all but the last layer.
        last_activation: activation for the last layer.
        weight_init: weight initializer forwarded to Dense.
        bias_init: bias initializer forwarded to Dense.
        norm: normalization option forwarded to all but the last Dense layer.

    Raises:
        ValueError: if ``hidden_dims`` has fewer than two entries.
    """

    def __init__(
        self,
        hidden_dims: List[int],
        bias=True,
        activation=None,
        last_activation=None,
        weight_init=xavier_uniform_,
        bias_init=zeros_initializer,
        norm="",
    ):
        super().__init__()
        dims = list(hidden_dims)
        if len(dims) < 2:
            # Previously this surfaced as an opaque IndexError on dims[-2].
            raise ValueError(f"MLP requires at least two dimensions (input and output), got {dims}")
        DenseMLP = partial(Dense, bias=bias, weight_init=weight_init, bias_init=bias_init)
        self.dense_layers = nn.ModuleList(
            [DenseMLP(dims[i], dims[i + 1], activation=activation, norm=norm) for i in range(len(dims) - 2)]
            + [DenseMLP(dims[-2], dims[-1], activation=last_activation)]
        )
        # The same Dense modules are registered under both `dense_layers` and
        # `layers`. Kept as-is: removing either would presumably change the
        # state_dict keys that existing checkpoints rely on.
        self.layers = nn.Sequential(*self.dense_layers)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize every Dense layer."""
        for m in self.dense_layers:
            m.reset_parameters()

    def forward(self, x):
        """Apply the Dense layers in order."""
        return self.layers(x)
class NodeInit(MessagePassing):
    """Initial node features from atomic numbers plus one round of messages.

    Two modes, selected by ``concat``:

    * ``concat=False``: each neighbour embedding ``x_j`` is multiplied by
      ``distance_proj(edge_attr) * CosineCutoff(edge_weight)``. The products are
      summed per node (``aggr="add"``), and ``combine`` then mixes
      ``cat([x, summed])``.
    * ``concat=True``: each message is ``distance_proj(cat([edge_attr, x_j, s_i]))``,
      where ``s_i`` comes from a separate ``embedding_src`` table. The summed
      messages are added to ``x`` residually. No cutoff envelope is applied to
      the messages in this mode.

    Self-loops are removed from ``edge_index`` before message passing.

    Args:
        hidden_channels: int or list of ints; the last entry is the output width.
        num_rbf: width of ``edge_attr``.
        cutoff: radius passed to ``CosineCutoff``.
        max_z: number of rows in the atomic-number embedding tables.
        activation: activation used inside the MLPs.
        proj_ln: ``norm`` option forwarded to the MLPs.
        last_activation: if truthy, also apply ``activation`` after each MLP's last layer.
        weight_init: weight initializer forwarded to the MLPs.
        bias_init: bias initializer forwarded to the MLPs.
        concat: selects the message form described above.
    """

    def __init__(
        self,
        hidden_channels,
        num_rbf,
        cutoff,
        max_z=100,
        activation=F.silu,
        proj_ln="",
        last_activation=False,
        weight_init=nn.init.xavier_uniform_,
        bias_init=nn.init.zeros_,
        concat=False,
    ):
        super(NodeInit, self).__init__(aggr="add")
        if type(hidden_channels) == int:  # noqa: E721
            hidden_channels = [hidden_channels]
        first_channel = hidden_channels[0]
        last_channel = hidden_channels[-1]
        self.concat = concat
        # Per-atom embedding used as the neighbour feature x_j (last_channel wide).
        self.embedding = nn.Embedding(max_z, last_channel)
        if self.concat:
            # Separate embedding for the other endpoint (s_i in `message`).
            self.embedding_src = nn.Embedding(max_z, first_channel)
            # `message` feeds cat([edge_attr, x_j, s_i]) to distance_proj. Its
            # input width is therefore num_rbf + last_channel (x_j, from
            # `embedding`) + first_channel (s_i, from `embedding_src`). This
            # equals num_rbf + 2 * first_channel whenever the two widths agree.
            self.distance_proj = MLP(
                [num_rbf + first_channel + last_channel] + hidden_channels,
                activation=activation,
                norm=proj_ln,
                weight_init=weight_init,
                bias_init=bias_init,
                last_activation=activation if last_activation else None,
            )
        else:
            # Single Dense layer (two widths) with no activation: edge_attr -> per-channel weights.
            self.distance_proj = MLP(
                [num_rbf] + [last_channel], activation=None, norm="", weight_init=weight_init, bias_init=bias_init, last_activation=None
            )
            # Mixes cat([x, aggregated neighbours]), i.e. 2 * last_channel inputs.
            self.combine = MLP(
                [2 * last_channel] + hidden_channels,
                activation=activation,
                norm=proj_ln,
                weight_init=weight_init,
                bias_init=bias_init,
                last_activation=activation if last_activation else None,
            )
        self.cutoff = CosineCutoff(cutoff)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the embeddings and MLPs that exist in the current mode."""
        self.embedding.reset_parameters()
        if self.concat:
            self.embedding_src.reset_parameters()
        self.distance_proj.reset_parameters()
        if not self.concat:
            self.combine.reset_parameters()

    def forward(self, z, x, edge_index, edge_weight, edge_attr):
        """Return updated node features for atomic numbers ``z`` and features ``x``.

        ``edge_weight`` (used for the cutoff) and ``edge_attr`` are indexed per
        edge, in parallel with the columns of ``edge_index``.
        """
        # Drop self-loops so a node does not message itself.
        mask = edge_index[0] != edge_index[1]
        if not mask.all():
            edge_index = edge_index[:, mask]
            edge_weight = edge_weight[mask]
            edge_attr = edge_attr[mask]
        x_neighbors = self.embedding(z)
        if not self.concat:
            C = self.cutoff(edge_weight)
            W = self.distance_proj(edge_attr) * C.view(-1, 1)
            x_src = x_neighbors
        else:
            x_src = self.embedding_src(z)
            W = edge_attr
        # propagate_type: (x: Tensor, s:Tensor, W: Tensor)
        x_neighbors = self.propagate(edge_index, x=x_neighbors, s=x_src, W=W, size=None)
        if self.concat:
            x_neighbors = x + x_neighbors
        else:
            x_neighbors = self.combine(torch.cat([x, x_neighbors], dim=1))
        return x_neighbors

    def message(self, s_i, x_j, W):
        """Build one message per edge (see the class docstring for the two modes)."""
        if self.concat:
            return self.distance_proj(torch.cat([W, x_j, s_i], dim=1))
        return x_j * W
class EdgeInit(MessagePassing):
    """Per-edge features from the two endpoint features and the edge attributes.

    For every edge the message is ``(x_i + x_j) * edge_up(edge_attr)``.
    ``aggregate`` is overridden to return the messages unchanged (no
    reduction), so ``forward`` presumably yields one row per edge.

    Args:
        num_rbf: width of ``edge_attr``.
        hidden_channels: int or list of ints; widths of ``edge_up`` after the input.
        activation: activation used inside ``edge_up``.
        proj_ln: ``norm`` option forwarded to ``edge_up``.
        last_activation: if truthy, also apply ``activation`` after the last layer.
        weight_init: weight initializer forwarded to ``edge_up``.
        bias_init: bias initializer forwarded to ``edge_up``.
    """

    def __init__(
        self,
        num_rbf,
        hidden_channels,
        activation=F.silu,
        proj_ln="",
        last_activation=False,
        weight_init=nn.init.xavier_uniform_,
        bias_init=nn.init.zeros_,
    ):
        super().__init__(aggr=None)
        # Stored on the module; not read by any method of this class.
        self.activation = activation
        widths = [hidden_channels] if type(hidden_channels) is int else hidden_channels
        final_act = activation if last_activation else None
        self.edge_up = MLP(
            [num_rbf] + widths,
            activation=activation,
            norm=proj_ln,
            weight_init=weight_init,
            bias_init=bias_init,
            last_activation=final_act,
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize ``edge_up``."""
        self.edge_up.reset_parameters()

    def forward(self, edge_index, edge_attr, x):
        """Return the per-edge features for ``edge_index``."""
        # propagate_type: (x: Tensor, edge_attr: Tensor)
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        """Gate the summed endpoint features by the projected edge attributes."""
        endpoint_sum = x_i + x_j
        return endpoint_sum * self.edge_up(edge_attr)

    def aggregate(self, features, index):
        """Skip the reduction and keep one message per edge."""
        return features