# ipd's picture
# comment out torch_scatter
# 972c814
"""
This code was adapted from https://github.com/sarpaykent/GotenNet
Copyright (c) 2025 Sarp Aykent
MIT License
GotenNet: Rethinking Efficient 3D Equivariant Graph Neural Networks
Sarp Aykent and Tian Xia
https://openreview.net/pdf?id=5wxCQDtbMo
"""
from __future__ import absolute_import, division, print_function
import inspect
import math
from functools import partial
from typing import List
import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn as nn
from torch.nn.init import constant_, xavier_uniform_
from torch_geometric.nn import MessagePassing
from torch_geometric.nn.inits import glorot_orthogonal
from torch_geometric.nn.models.schnet import ShiftedSoftplus
#from torch_scatter import scatter
# In-place initializer that fills a tensor with 0.0 (used as Dense's default bias_init).
zeros_initializer = partial(torch.nn.init.constant_, val=0.0)
def centralize(
    batch,
    key: str,
    batch_index: torch.Tensor,
):  # note: cannot make assumptions on output shape
    """Center the rows of ``batch[key]`` on the mean of their group.

    Row ``i`` of ``batch[key]`` belongs to group ``batch_index[i]``. The mean of
    each group is computed and subtracted from that group's rows. The mean uses
    plain ``torch`` ops so this does not depend on ``torch_scatter``, whose
    import is commented out in this file.

    Args:
        batch: mapping holding the tensor to centralize under ``key``.
        key: entry of ``batch`` to centralize; rows are indexed along dim 0.
        batch_index: 1-D integer tensor with one non-negative group id per row.

    Returns:
        (entities_centroid, entities_centered):
        - entities_centroid: the per-group means, with ``batch_index.max() + 1``
          rows. Ids that never occur get a zero row.
        - entities_centered: ``batch[key]`` with each row's group mean subtracted.
    """
    entities = batch[key]
    num_groups = int(batch_index.max()) + 1 if batch_index.numel() > 0 else 0
    # per-group sum of rows, e.g. [batch_size, 3]
    sums = entities.new_zeros((num_groups,) + tuple(entities.shape[1:]))
    sums.index_add_(0, batch_index, entities)
    # Clamp counts so groups with no members divide by 1 and yield a zero centroid.
    counts = torch.bincount(batch_index, minlength=num_groups).clamp(min=1)
    counts = counts.to(sums.dtype).view((-1,) + (1,) * (entities.dim() - 1))
    entities_centroid = sums / counts
    entities_centered = entities - entities_centroid[batch_index]
    return entities_centroid, entities_centered
def decentralize(
    positions: torch.Tensor,
    batch_index: torch.Tensor,
    entities_centroid: torch.Tensor,
) -> torch.Tensor:  # note: cannot make assumptions on output shape
    """Add each row's group centroid back onto ``positions``.

    Args:
        positions: centered values; row ``i`` belongs to group ``batch_index[i]``.
        batch_index: group id per row, used to gather rows of ``entities_centroid``.
        entities_centroid: one centroid row per group.

    Returns:
        ``positions + entities_centroid[batch_index]``.
    """
    gathered = entities_centroid[batch_index]
    return positions + gathered
def parse_update_info(edge_updates):
    """Parse an underscore-separated edge-update spec into an option dict.

    ``edge_updates`` is something like ``"gated_norej_linw"``. Any non-string
    value is treated as an empty spec, so every option keeps its default.

    Returns:
        dict with keys "gated", "rej", "vec_norm", "mlp", "mlpa", "lin_w", "drej":
        - "gated": False, or the last of "gated", "gatedt", "act" present
          (precedence act > gatedt > gated).
        - "rej": False iff "norej" is present.
        - "vec_norm": always False.
        - "mlp", "mlpa", "drej": True iff the same-named part is present.
        - "lin_w": 2 if "linwa" is present, else 1 if "linw" is present, else 0.

    Raises:
        ValueError: if any part is not one of the allowed tokens.
    """
    allowed_parts = ["gated", "gatedt", "norej", "mlp", "mlpa", "act", "linw", "linwa", "drej"]
    parts = edge_updates.split("_") if isinstance(edge_updates, str) else []
    if any(part not in allowed_parts for part in parts):
        raise ValueError(f"Invalid edge update parts. Allowed parts are {allowed_parts}")

    # Later modes in this tuple override earlier ones when several are given.
    gated = False
    for mode in ("gated", "gatedt", "act"):
        if mode in parts:
            gated = mode

    if "linwa" in parts:
        lin_w = 2
    elif "linw" in parts:
        lin_w = 1
    else:
        lin_w = 0

    return {
        "gated": gated,
        "rej": "norej" not in parts,
        "vec_norm": False,
        "mlp": "mlp" in parts,
        "mlpa": "mlpa" in parts,
        "lin_w": lin_w,
        "drej": "drej" in parts,
    }
class SmoothLeakyReLU(torch.nn.Module):
    """Smooth leaky ReLU.

    Computes ``(1 + a) / 2 * x + (1 - a) / 2 * x * (2 * sigmoid(x) - 1)`` with
    ``a = negative_slope``. This tends to ``x`` for large positive inputs and
    to ``a * x`` for large negative inputs.
    """

    def __init__(self, negative_slope=0.2):
        super().__init__()
        self.alpha = negative_slope

    def forward(self, x):
        linear_coeff = (1 + self.alpha) / 2
        gated_coeff = (1 - self.alpha) / 2
        linear_part = linear_coeff * x
        gated_part = gated_coeff * x * (2 * torch.sigmoid(x) - 1)
        return linear_part + gated_part

    def extra_repr(self):
        return f"negative_slope={self.alpha}"
def shifted_softplus(x: torch.Tensor):
    """Softplus shifted down by log(2), so that ``shifted_softplus(0) == 0``."""
    shift = math.log(2.0)
    return F.softplus(x) - shift
class PolynomialCutoff(nn.Module):
    """Smooth polynomial cutoff envelope from DimeNet (https://arxiv.org/abs/2003.03123).

    With ``x = r / cutoff``:
    ``f = 1 - (p+1)(p+2)/2 * x^p + p(p+2) * x^(p+1) - p(p+1)/2 * x^(p+2)`` for ``x < 1``,
    and ``f = 0`` otherwise.
    """

    def __init__(self, cutoff, p: int = 6):
        super(PolynomialCutoff, self).__init__()
        self.cutoff = cutoff
        self.p = p

    @staticmethod
    def polynomial_cutoff(r: Tensor, rcut: float, p: float = 6.0) -> Tensor:
        """
        Polynomial cutoff, as proposed in DimeNet: https://arxiv.org/abs/2003.03123

        Args:
            r: distances.
            rcut: cutoff radius; outputs are zero where ``r >= rcut``.
            p: polynomial exponent, must be >= 2.

        Raises:
            ValueError: if ``p < 2`` (or ``p`` is NaN).
        """
        # Raise instead of printing and calling exit(), which terminated the whole process.
        if not p >= 2.0:
            raise ValueError(f"Exponent p={p} has to be >= 2.")
        rscaled = r / rcut
        out = 1.0
        out = out - (((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(rscaled, p))
        out = out + (p * (p + 2.0) * torch.pow(rscaled, p + 1.0))
        out = out - ((p * (p + 1.0) / 2) * torch.pow(rscaled, p + 2.0))
        return out * (rscaled < 1.0).float()

    def forward(self, r):
        return self.polynomial_cutoff(r=r, rcut=self.cutoff, p=self.p)

    def __repr__(self):
        return f"{self.__class__.__name__}(cutoff={self.cutoff}, p={self.p})"
class CosineCutoff(nn.Module):
    """Cosine cutoff envelope.

    Returns ``0.5 * (cos(pi * d / cutoff) + 1)`` where ``d < cutoff`` and 0 elsewhere.
    A tensor ``cutoff`` is converted to a Python number via ``.item()``.
    """

    def __init__(self, cutoff):
        super(CosineCutoff, self).__init__()
        self.cutoff = cutoff.item() if isinstance(cutoff, torch.Tensor) else cutoff

    def forward(self, distances):
        inside = (distances < self.cutoff).float()
        envelope = 0.5 * (torch.cos(distances * math.pi / self.cutoff) + 1.0)
        return envelope * inside
class ScaleShift(nn.Module):
    r"""Scale and shift layer for standardization.

    .. math::
        y = x \times \sigma + \mu

    Args:
        mean (torch.Tensor or number): mean value :math:`\mu`.
        stddev (torch.Tensor or number): standard deviation value :math:`\sigma`.
            Plain Python numbers (``int`` or ``float``) are wrapped in a
            one-element float32 tensor. Tensors are registered as given.
    """

    def __init__(self, mean, stddev):
        super(ScaleShift, self).__init__()
        self.register_buffer("mean", self._as_tensor(mean))
        self.register_buffer("stddev", self._as_tensor(stddev))

    @staticmethod
    def _as_tensor(value):
        """Wrap a Python int/float in a 1-element float32 tensor; return anything else unchanged."""
        # ints used to slip through the float-only check and crash register_buffer
        if isinstance(value, (int, float)):
            return torch.tensor([float(value)], dtype=torch.float32)
        return value

    def forward(self, input):
        """Compute layer output.

        Args:
            input (torch.Tensor): input data.

        Returns:
            torch.Tensor: ``input * stddev + mean``.
        """
        y = input * self.stddev + self.mean
        return y
class GetItem(nn.Module):
    """Module that returns one entry of a SchNetPack-style input dictionary.

    Args:
        key (str): dictionary key to return from ``forward``.
    """

    def __init__(self, key):
        super(GetItem, self).__init__()
        self.key = key

    def forward(self, inputs):
        """Return ``inputs[self.key]``; a missing key raises ``KeyError``.

        Args:
            inputs (dict of torch.Tensor): SchNetPack dictionary of input tensors.
        """
        selected = inputs[self.key]
        return selected
class SchnetMLP(nn.Module):
    """Multiple layer fully connected perceptron neural network.

    Args:
        n_in (int): number of input nodes.
        n_out (int): number of output nodes.
        n_hidden (list of int or int, optional): number of hidden layer nodes.
            - An integer: the same width is used for all ``n_layers - 1`` hidden
              layers, giving a rectangular network.
            - A list/tuple: the width of each hidden layer, giving
              ``len(n_hidden) + 1`` layers; ``n_layers`` is not used.
            - None: the width is halved after each layer starting from ``n_in``,
              giving a pyramidal network.
        n_layers (int, optional): number of layers.
        activation (callable, optional): activation function used by all hidden
            layers. The output layer applies no activation.
    """

    def __init__(self, n_in, n_out, n_hidden=None, n_layers=2, activation=shifted_softplus):
        super(SchnetMLP, self).__init__()
        # list of widths: input, hidden..., output
        if n_hidden is None:
            widths, width = [], n_in
            for _ in range(n_layers):
                widths.append(width)
                width = width // 2
            self.n_neurons = widths + [n_out]
        else:
            if isinstance(n_hidden, int):
                n_hidden = [n_hidden] * (n_layers - 1)
            self.n_neurons = [n_in] + list(n_hidden) + [n_out]
        # One Dense per consecutive pair of widths. The count is derived from n_neurons
        # so an explicit n_hidden list that disagrees with n_layers still yields
        # matching layer shapes. Only the last layer has no activation.
        n_linear = len(self.n_neurons) - 1
        layers = [
            Dense(
                self.n_neurons[i],
                self.n_neurons[i + 1],
                activation=activation if i < n_linear - 1 else None,
            )
            for i in range(n_linear)
        ]
        self.out_net = nn.Sequential(*layers)

    def forward(self, inputs):
        """Compute neural network output.

        Args:
            inputs (torch.Tensor): network input.

        Returns:
            torch.Tensor: network output.
        """
        return self.out_net(inputs)
def scaled_silu(x, scale=0.6):
    """SiLU activation multiplied by a constant ``scale`` (0.6 by default)."""
    activated = F.silu(x)
    return activated * scale
def gaussian_rbf(inputs: torch.Tensor, offsets: torch.Tensor, widths: torch.Tensor):
    """Evaluate Gaussians ``exp(-0.5 * ((inputs - offsets) / widths) ** 2)``.

    A new trailing axis is appended to ``inputs`` and broadcast against
    ``offsets`` and ``widths``.
    """
    delta = inputs.unsqueeze(-1) - offsets
    scale = -0.5 / torch.pow(widths, 2)
    return torch.exp(scale * torch.pow(delta, 2))
class GaussianRBF(nn.Module):
    r"""Gaussian radial basis functions."""

    def __init__(self, n_rbf: int, cutoff: float, start: float = 0.0, trainable: bool = False):
        r"""
        Args:
            n_rbf: total number of Gaussian functions, :math:`N_g` (at least 2,
                since the width is taken from the spacing of the first two centers).
            cutoff: center of last Gaussian function, :math:`\mu_{N_g}`
            start: center of first Gaussian function, :math:`\mu_0`.
            trainable: If True, widths and offset of Gaussian functions
                are adjusted during training process.
        """
        super(GaussianRBF, self).__init__()
        self.n_rbf = n_rbf
        # centers evenly spaced on [start, cutoff]; every width equals that spacing
        offset = torch.linspace(start, cutoff, n_rbf)
        # Built directly from `offset` instead of the legacy `torch.FloatTensor(...)`
        # constructor, which rejects float64 input (e.g. under a float64 default dtype).
        widths = torch.abs(offset[1] - offset[0]) * torch.ones_like(offset)
        if trainable:
            self.widths = nn.Parameter(widths)
            self.offsets = nn.Parameter(offset)
        else:
            self.register_buffer("widths", widths)
            self.register_buffer("offsets", offset)

    def forward(self, inputs: torch.Tensor):
        return gaussian_rbf(inputs, self.offsets, self.widths)
class BesselBasis(nn.Module):
    """
    Sine for radial basis expansion with coulomb decay. (0th order Bessel from DimeNet)

    Computes ``sin(n * pi * r / cutoff) / r`` for ``n = 1 .. n_rbf``. At ``r == 0``
    the denominator is replaced by 1, so the output there is 0.
    """

    def __init__(self, cutoff=5.0, n_rbf=None, trainable=False):
        """
        Args:
            cutoff: radial cutoff
            n_rbf: number of basis functions (required despite the ``None`` default).
            trainable: accepted for interface compatibility; not used by this basis.
        """
        super(BesselBasis, self).__init__()
        if n_rbf is None:
            raise TypeError("BesselBasis requires an integer n_rbf")
        self.n_rbf = n_rbf
        # frequencies n * pi / cutoff for n = 1..n_rbf
        freqs = torch.arange(1, n_rbf + 1) * math.pi / cutoff
        self.register_buffer("freqs", freqs)
        self.register_buffer("norm1", torch.tensor(1.0))

    def forward(self, inputs):
        # new trailing axis so each distance is paired with every frequency
        r = inputs[..., None]
        sinax = torch.sin(r * self.freqs[None, :])
        # avoid 0/0 at r == 0: divide by 1 there (sin(0) is 0 anyway)
        norm = torch.where(r == 0, self.norm1, r)
        return sinax / norm
def glorot_orthogonal_wrapper_(tensor, scale=2.0):
    """Apply torch_geometric's ``glorot_orthogonal`` to ``tensor``, with ``scale`` defaulting to 2.0.

    Lets the initializer be used as a single-argument callable (it is returned
    for "glo_orthogonal" by ``get_weight_init_by_string``). Returns whatever
    ``glorot_orthogonal`` returns.
    """
    result = glorot_orthogonal(tensor, scale=scale)
    return result
def _standardize(kernel):
    """
    Shift and scale ``kernel`` to zero mean and unit (unbiased) variance.

    3-D kernels are reduced over their first two axes, treating the last axis as
    the output dimension. Kernels of any other rank are reduced over axis 1. A
    small epsilon keeps the division finite for constant kernels.
    """
    reduce_dims = [0, 1] if len(kernel.shape) == 3 else 1
    variance, average = torch.var_mean(kernel, dim=reduce_dims, unbiased=True, keepdim=True)
    return (kernel - average) / (variance + 1e-6) ** 0.5
def he_orthogonal_init(tensor):
    """
    Initialize ``tensor`` in place with a (semi-)orthogonal matrix rescaled to He variance.

    The tensor is first filled by ``torch.nn.init.orthogonal_``, then standardized
    to zero mean / unit variance and multiplied by ``sqrt(1 / fan_in)``.
    Decorrelated weights are expected to help learning (see e.g. "Reducing
    overfitting in deep networks by decorrelating representations", "Dropout: a
    simple way to prevent neural networks from overfitting", "Exact solutions to
    the nonlinear dynamics of learning in deep linear neural networks").

    Returns:
        The same ``tensor`` object.
    """
    tensor = torch.nn.init.orthogonal_(tensor)
    # fan-in: product of all but the last axis for 3-D weights, otherwise the size of axis 1
    fan_in = tensor.shape[:-1].numel() if len(tensor.shape) == 3 else tensor.shape[1]
    with torch.no_grad():
        standardized = _standardize(tensor.data)
        tensor.data = standardized
        tensor.data *= (1 / fan_in) ** 0.5
    return tensor
def get_weight_init_by_string(init_str):
    """Map an initializer name to a weight-initialization callable.

    Recognized names: "" (identity no-op), "zeros", "xavier_uniform",
    "glo_orthogonal" and "he_orthogonal".

    Raises:
        ValueError: for any other name.
    """
    if init_str == "":
        # no-op initializer: hands the tensor back untouched
        return lambda x: x
    if init_str == "zeros":
        return torch.nn.init.zeros_
    if init_str == "xavier_uniform":
        return torch.nn.init.xavier_uniform_
    if init_str == "glo_orthogonal":
        return glorot_orthogonal_wrapper_
    if init_str == "he_orthogonal":
        return he_orthogonal_init
    raise ValueError(f"Unknown initialization {init_str}")
class Dense(nn.Linear):
    r"""Fully connected linear layer with optional normalization and activation.

    Borrowed from https://github.com/atomistic-machine-learning/schnetpack/blob/master/src/schnetpack/nn/base.py

    .. math::
        y = activation(norm(xW^T + b))

    Args:
        in_features (int): number of input feature :math:`x`.
        out_features (int): number of output features :math:`y`.
        bias (bool, optional): if False, the layer will not adapt bias :math:`b`.
        activation (class or callable, optional): a class is instantiated with no
            arguments; any other callable is used as-is; if falsy (e.g. None), no
            activation is applied.
        weight_init (callable, optional): in-place weight initializer.
        bias_init (callable, optional): in-place bias initializer.
        norm (str, optional): "layer", "batch" or "instance" applies LayerNorm,
            BatchNorm1d or InstanceNorm1d after the linear map; anything else disables it.
        gain (optional): if truthy, passed as ``gain=`` to ``weight_init``.
    """

    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        activation=None,
        weight_init=xavier_uniform_,
        bias_init=zeros_initializer,
        norm=None,
        gain=None,
    ):
        # nn.Linear.__init__ calls reset_parameters(), which reads these three
        # attributes, so they must be set before the super() call.
        self.weight_init = weight_init
        self.bias_init = bias_init
        self.gain = gain
        super(Dense, self).__init__(in_features, out_features, bias)
        # Instantiate activation classes; keep functions/modules as given. The previous
        # unconditional reassignment overwrote the instance with the class (and
        # nn.Module rejects replacing a child module with a class).
        if inspect.isclass(activation):
            self.activation = activation()
        else:
            self.activation = activation
        if norm == "layer":
            self.norm = nn.LayerNorm(out_features)
        elif norm == "batch":
            self.norm = nn.BatchNorm1d(out_features)
        elif norm == "instance":
            self.norm = nn.InstanceNorm1d(out_features)
        else:
            self.norm = None

    def reset_parameters(self):
        """Reinitialize model weight and bias values."""
        if self.gain:
            self.weight_init(self.weight, gain=self.gain)
        else:
            self.weight_init(self.weight)
        if self.bias is not None:
            self.bias_init(self.bias)

    def forward(self, inputs):
        """Compute layer output.

        Args:
            inputs (torch.Tensor): batch of input values.

        Returns:
            torch.Tensor: layer output.
        """
        # compute linear layer y = xW^T + b
        y = super(Dense, self).forward(inputs)
        if self.norm is not None:
            y = self.norm(y)
        # add activation function
        if self.activation:
            y = self.activation(y)
        return y
class _VDropout(nn.Module):
    """
    Vector-channel dropout: all components of one vector channel are dropped together.

    Args:
        drop_rate: probability of zeroing a vector channel.
        scale: if True, surviving channels are divided by ``1 - drop_rate``.
    """

    def __init__(self, drop_rate, scale=True):
        super(_VDropout, self).__init__()
        self.drop_rate = drop_rate
        self.scale = scale

    def forward(self, x, dim=-1):
        """
        :param x: `torch.Tensor` corresponding to vector channels; axis ``dim`` must have size 3
        """
        # identity when disabled or in eval mode
        if self.drop_rate == 0 or not self.training:
            return x
        mask_shape = list(x.shape)
        assert mask_shape[dim] == 3, "The dimension must be vector"
        # one Bernoulli draw per vector, broadcast across its 3 components
        mask_shape[dim] = 1
        keep_prob = 1 - self.drop_rate
        mask = torch.bernoulli(keep_prob * torch.ones(mask_shape, device=x.device))
        x = mask * x
        if self.scale:
            # Rescale to keep the expected output equal to the input; this may be
            # harmful for vector quantities.
            x = x / (1 - self.drop_rate)
        return x
class Dropout(nn.Module):
    """
    Combined dropout for tuples (s, V).

    Takes tuples (s, V) as input and returns a tuple; a single tensor is treated
    as scalar channels. Scalars use ``nn.Dropout``. Vectors use vector-channel
    dropout along dim 1 when ``vector_dropout`` is True; otherwise they pass
    through unchanged.
    """

    def __init__(self, drop_rate, vector_dropout=True):
        super(Dropout, self).__init__()
        self.sdropout = nn.Dropout(drop_rate)
        if vector_dropout:
            self.vdropout = _VDropout(drop_rate)
        else:
            self.vdropout = lambda x, dim: x

    def forward(self, x):
        """
        :param x: tuple (s, V) of `torch.Tensor`,
                  or single `torch.Tensor`
                  (will be assumed to be scalar channels)
        """
        # isinstance so Tensor subclasses (e.g. nn.Parameter) are not unpacked as tuples
        if isinstance(x, torch.Tensor):
            return self.sdropout(x)
        s, v = x
        return self.sdropout(s), self.vdropout(v, dim=1)
class TensorInit(nn.Module):
    """Expand vectors into Cartesian spherical-harmonic components of degrees 1..l.

    The constant l = 0 term is omitted, so the output has ``(l + 1) ** 2 - 1``
    components along a new last axis (see ``tensor_size``), ordered by degree.
    Supported degrees are 1 through 8.
    """

    def __init__(self, l=2):  # noqa: E741
        super(TensorInit, self).__init__()
        self.l = l

    def forward(self, edge_vec):
        """Compute the components for ``edge_vec``, whose last axis holds (x, y, z)."""
        edge_sh = self._calculate_components(self.l, edge_vec[..., 0], edge_vec[..., 1], edge_vec[..., 2])
        return edge_sh

    @property
    def tensor_size(self):
        """Number of components produced per vector: ``(l + 1) ** 2 - 1``."""
        return ((self.l + 1) ** 2) - 1

    @staticmethod
    def _calculate_components(lmax: int, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Stack the degree 1..lmax components of (x, y, z) along a new last axis.

        Each degree is built recursively from the previous one; formulas for l >= 3
        are borrowed from e3nn.

        Raises:
            ValueError: if ``lmax`` is not in 1..8 (previously this fell through and
                returned None).
        """
        if not 1 <= lmax <= 8:
            raise ValueError(f"lmax must be between 1 and 8, got {lmax}")
        sh_1_0, sh_1_1, sh_1_2 = x, y, z
        # components accumulated in output order; each degree appends 2l + 1 entries
        components = [sh_1_0, sh_1_1, sh_1_2]
        if lmax == 1:
            return torch.stack(components, dim=-1)
        # degree 2
        sh_2_0 = math.sqrt(3.0) * x * z
        sh_2_1 = math.sqrt(3.0) * x * y
        y2 = y.pow(2)
        x2z2 = x.pow(2) + z.pow(2)
        sh_2_2 = y2 - 0.5 * x2z2
        sh_2_3 = math.sqrt(3.0) * y * z
        sh_2_4 = math.sqrt(3.0) / 2.0 * (z.pow(2) - x.pow(2))
        components += [sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4]
        if lmax == 2:
            return torch.stack(components, dim=-1)
        # Borrowed from e3nn: https://github.com/e3nn/e3nn/blob/main/e3nn/o3/_spherical_harmonics.py#L188
        sh_3_0 = (1 / 6) * math.sqrt(42) * (sh_2_0 * z + sh_2_4 * x)
        sh_3_1 = math.sqrt(7) * sh_2_0 * y
        sh_3_2 = (1 / 8) * math.sqrt(168) * (4.0 * y2 - x2z2) * x
        sh_3_3 = (1 / 2) * math.sqrt(7) * y * (2.0 * y2 - 3.0 * x2z2)
        sh_3_4 = (1 / 8) * math.sqrt(168) * z * (4.0 * y2 - x2z2)
        sh_3_5 = math.sqrt(7) * sh_2_4 * y
        sh_3_6 = (1 / 6) * math.sqrt(42) * (sh_2_4 * z - sh_2_0 * x)
        components += [sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6]
        if lmax == 3:
            return torch.stack(components, dim=-1)
        # degree 4
        sh_4_0 = (3 / 4) * math.sqrt(2) * (sh_3_0 * z + sh_3_6 * x)
        sh_4_1 = (3 / 4) * sh_3_0 * y + (3 / 8) * math.sqrt(6) * sh_3_1 * z + (3 / 8) * math.sqrt(6) * sh_3_5 * x
        sh_4_2 = (
            -3 / 56 * math.sqrt(14) * sh_3_0 * z
            + (3 / 14) * math.sqrt(21) * sh_3_1 * y
            + (3 / 56) * math.sqrt(210) * sh_3_2 * z
            + (3 / 56) * math.sqrt(210) * sh_3_4 * x
            + (3 / 56) * math.sqrt(14) * sh_3_6 * x
        )
        sh_4_3 = (
            -3 / 56 * math.sqrt(42) * sh_3_1 * z
            + (3 / 28) * math.sqrt(105) * sh_3_2 * y
            + (3 / 28) * math.sqrt(70) * sh_3_3 * x
            + (3 / 56) * math.sqrt(42) * sh_3_5 * x
        )
        sh_4_4 = -3 / 28 * math.sqrt(42) * sh_3_2 * x + (3 / 7) * math.sqrt(7) * sh_3_3 * y - 3 / 28 * math.sqrt(42) * sh_3_4 * z
        sh_4_5 = (
            -3 / 56 * math.sqrt(42) * sh_3_1 * x
            + (3 / 28) * math.sqrt(70) * sh_3_3 * z
            + (3 / 28) * math.sqrt(105) * sh_3_4 * y
            - 3 / 56 * math.sqrt(42) * sh_3_5 * z
        )
        sh_4_6 = (
            -3 / 56 * math.sqrt(14) * sh_3_0 * x
            - 3 / 56 * math.sqrt(210) * sh_3_2 * x
            + (3 / 56) * math.sqrt(210) * sh_3_4 * z
            + (3 / 14) * math.sqrt(21) * sh_3_5 * y
            - 3 / 56 * math.sqrt(14) * sh_3_6 * z
        )
        sh_4_7 = -3 / 8 * math.sqrt(6) * sh_3_1 * x + (3 / 8) * math.sqrt(6) * sh_3_5 * z + (3 / 4) * sh_3_6 * y
        sh_4_8 = (3 / 4) * math.sqrt(2) * (-sh_3_0 * x + sh_3_6 * z)
        components += [sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8]
        if lmax == 4:
            return torch.stack(components, dim=-1)
        # degree 5
        sh_5_0 = (1 / 10) * math.sqrt(110) * (sh_4_0 * z + sh_4_8 * x)
        sh_5_1 = (1 / 5) * math.sqrt(11) * sh_4_0 * y + (1 / 5) * math.sqrt(22) * sh_4_1 * z + (1 / 5) * math.sqrt(22) * sh_4_7 * x
        sh_5_2 = (
            -1 / 30 * math.sqrt(22) * sh_4_0 * z
            + (4 / 15) * math.sqrt(11) * sh_4_1 * y
            + (1 / 15) * math.sqrt(154) * sh_4_2 * z
            + (1 / 15) * math.sqrt(154) * sh_4_6 * x
            + (1 / 30) * math.sqrt(22) * sh_4_8 * x
        )
        sh_5_3 = (
            -1 / 30 * math.sqrt(66) * sh_4_1 * z
            + (1 / 15) * math.sqrt(231) * sh_4_2 * y
            + (1 / 30) * math.sqrt(462) * sh_4_3 * z
            + (1 / 30) * math.sqrt(462) * sh_4_5 * x
            + (1 / 30) * math.sqrt(66) * sh_4_7 * x
        )
        sh_5_4 = (
            -1 / 15 * math.sqrt(33) * sh_4_2 * z
            + (2 / 15) * math.sqrt(66) * sh_4_3 * y
            + (1 / 15) * math.sqrt(165) * sh_4_4 * x
            + (1 / 15) * math.sqrt(33) * sh_4_6 * x
        )
        sh_5_5 = -1 / 15 * math.sqrt(110) * sh_4_3 * x + (1 / 3) * math.sqrt(11) * sh_4_4 * y - 1 / 15 * math.sqrt(110) * sh_4_5 * z
        sh_5_6 = (
            -1 / 15 * math.sqrt(33) * sh_4_2 * x
            + (1 / 15) * math.sqrt(165) * sh_4_4 * z
            + (2 / 15) * math.sqrt(66) * sh_4_5 * y
            - 1 / 15 * math.sqrt(33) * sh_4_6 * z
        )
        sh_5_7 = (
            -1 / 30 * math.sqrt(66) * sh_4_1 * x
            - 1 / 30 * math.sqrt(462) * sh_4_3 * x
            + (1 / 30) * math.sqrt(462) * sh_4_5 * z
            + (1 / 15) * math.sqrt(231) * sh_4_6 * y
            - 1 / 30 * math.sqrt(66) * sh_4_7 * z
        )
        sh_5_8 = (
            -1 / 30 * math.sqrt(22) * sh_4_0 * x
            - 1 / 15 * math.sqrt(154) * sh_4_2 * x
            + (1 / 15) * math.sqrt(154) * sh_4_6 * z
            + (4 / 15) * math.sqrt(11) * sh_4_7 * y
            - 1 / 30 * math.sqrt(22) * sh_4_8 * z
        )
        sh_5_9 = -1 / 5 * math.sqrt(22) * sh_4_1 * x + (1 / 5) * math.sqrt(22) * sh_4_7 * z + (1 / 5) * math.sqrt(11) * sh_4_8 * y
        sh_5_10 = (1 / 10) * math.sqrt(110) * (-sh_4_0 * x + sh_4_8 * z)
        components += [sh_5_0, sh_5_1, sh_5_2, sh_5_3, sh_5_4, sh_5_5, sh_5_6, sh_5_7, sh_5_8, sh_5_9, sh_5_10]
        if lmax == 5:
            return torch.stack(components, dim=-1)
        # degree 6
        sh_6_0 = (1 / 6) * math.sqrt(39) * (sh_5_0 * z + sh_5_10 * x)
        sh_6_1 = (1 / 6) * math.sqrt(13) * sh_5_0 * y + (1 / 12) * math.sqrt(130) * sh_5_1 * z + (1 / 12) * math.sqrt(130) * sh_5_9 * x
        sh_6_2 = (
            -1 / 132 * math.sqrt(286) * sh_5_0 * z
            + (1 / 33) * math.sqrt(715) * sh_5_1 * y
            + (1 / 132) * math.sqrt(286) * sh_5_10 * x
            + (1 / 44) * math.sqrt(1430) * sh_5_2 * z
            + (1 / 44) * math.sqrt(1430) * sh_5_8 * x
        )
        sh_6_3 = (
            -1 / 132 * math.sqrt(858) * sh_5_1 * z
            + (1 / 22) * math.sqrt(429) * sh_5_2 * y
            + (1 / 22) * math.sqrt(286) * sh_5_3 * z
            + (1 / 22) * math.sqrt(286) * sh_5_7 * x
            + (1 / 132) * math.sqrt(858) * sh_5_9 * x
        )
        sh_6_4 = (
            -1 / 66 * math.sqrt(429) * sh_5_2 * z
            + (2 / 33) * math.sqrt(286) * sh_5_3 * y
            + (1 / 66) * math.sqrt(2002) * sh_5_4 * z
            + (1 / 66) * math.sqrt(2002) * sh_5_6 * x
            + (1 / 66) * math.sqrt(429) * sh_5_8 * x
        )
        sh_6_5 = (
            -1 / 66 * math.sqrt(715) * sh_5_3 * z
            + (1 / 66) * math.sqrt(5005) * sh_5_4 * y
            + (1 / 66) * math.sqrt(3003) * sh_5_5 * x
            + (1 / 66) * math.sqrt(715) * sh_5_7 * x
        )
        sh_6_6 = -1 / 66 * math.sqrt(2145) * sh_5_4 * x + (1 / 11) * math.sqrt(143) * sh_5_5 * y - 1 / 66 * math.sqrt(2145) * sh_5_6 * z
        sh_6_7 = (
            -1 / 66 * math.sqrt(715) * sh_5_3 * x
            + (1 / 66) * math.sqrt(3003) * sh_5_5 * z
            + (1 / 66) * math.sqrt(5005) * sh_5_6 * y
            - 1 / 66 * math.sqrt(715) * sh_5_7 * z
        )
        sh_6_8 = (
            -1 / 66 * math.sqrt(429) * sh_5_2 * x
            - 1 / 66 * math.sqrt(2002) * sh_5_4 * x
            + (1 / 66) * math.sqrt(2002) * sh_5_6 * z
            + (2 / 33) * math.sqrt(286) * sh_5_7 * y
            - 1 / 66 * math.sqrt(429) * sh_5_8 * z
        )
        sh_6_9 = (
            -1 / 132 * math.sqrt(858) * sh_5_1 * x
            - 1 / 22 * math.sqrt(286) * sh_5_3 * x
            + (1 / 22) * math.sqrt(286) * sh_5_7 * z
            + (1 / 22) * math.sqrt(429) * sh_5_8 * y
            - 1 / 132 * math.sqrt(858) * sh_5_9 * z
        )
        sh_6_10 = (
            -1 / 132 * math.sqrt(286) * sh_5_0 * x
            - 1 / 132 * math.sqrt(286) * sh_5_10 * z
            - 1 / 44 * math.sqrt(1430) * sh_5_2 * x
            + (1 / 44) * math.sqrt(1430) * sh_5_8 * z
            + (1 / 33) * math.sqrt(715) * sh_5_9 * y
        )
        sh_6_11 = -1 / 12 * math.sqrt(130) * sh_5_1 * x + (1 / 6) * math.sqrt(13) * sh_5_10 * y + (1 / 12) * math.sqrt(130) * sh_5_9 * z
        sh_6_12 = (1 / 6) * math.sqrt(39) * (-sh_5_0 * x + sh_5_10 * z)
        components += [
            sh_6_0,
            sh_6_1,
            sh_6_2,
            sh_6_3,
            sh_6_4,
            sh_6_5,
            sh_6_6,
            sh_6_7,
            sh_6_8,
            sh_6_9,
            sh_6_10,
            sh_6_11,
            sh_6_12,
        ]
        if lmax == 6:
            return torch.stack(components, dim=-1)
        # degree 7
        sh_7_0 = (1 / 14) * math.sqrt(210) * (sh_6_0 * z + sh_6_12 * x)
        sh_7_1 = (1 / 7) * math.sqrt(15) * sh_6_0 * y + (3 / 7) * math.sqrt(5) * sh_6_1 * z + (3 / 7) * math.sqrt(5) * sh_6_11 * x
        sh_7_2 = (
            -1 / 182 * math.sqrt(390) * sh_6_0 * z
            + (6 / 91) * math.sqrt(130) * sh_6_1 * y
            + (3 / 91) * math.sqrt(715) * sh_6_10 * x
            + (1 / 182) * math.sqrt(390) * sh_6_12 * x
            + (3 / 91) * math.sqrt(715) * sh_6_2 * z
        )
        sh_7_3 = (
            -3 / 182 * math.sqrt(130) * sh_6_1 * z
            + (3 / 182) * math.sqrt(130) * sh_6_11 * x
            + (3 / 91) * math.sqrt(715) * sh_6_2 * y
            + (5 / 182) * math.sqrt(858) * sh_6_3 * z
            + (5 / 182) * math.sqrt(858) * sh_6_9 * x
        )
        sh_7_4 = (
            (3 / 91) * math.sqrt(65) * sh_6_10 * x
            - 3 / 91 * math.sqrt(65) * sh_6_2 * z
            + (10 / 91) * math.sqrt(78) * sh_6_3 * y
            + (15 / 182) * math.sqrt(78) * sh_6_4 * z
            + (15 / 182) * math.sqrt(78) * sh_6_8 * x
        )
        sh_7_5 = (
            -5 / 91 * math.sqrt(39) * sh_6_3 * z
            + (15 / 91) * math.sqrt(39) * sh_6_4 * y
            + (3 / 91) * math.sqrt(390) * sh_6_5 * z
            + (3 / 91) * math.sqrt(390) * sh_6_7 * x
            + (5 / 91) * math.sqrt(39) * sh_6_9 * x
        )
        sh_7_6 = (
            -15 / 182 * math.sqrt(26) * sh_6_4 * z
            + (12 / 91) * math.sqrt(65) * sh_6_5 * y
            + (2 / 91) * math.sqrt(1365) * sh_6_6 * x
            + (15 / 182) * math.sqrt(26) * sh_6_8 * x
        )
        sh_7_7 = -3 / 91 * math.sqrt(455) * sh_6_5 * x + (1 / 13) * math.sqrt(195) * sh_6_6 * y - 3 / 91 * math.sqrt(455) * sh_6_7 * z
        sh_7_8 = (
            -15 / 182 * math.sqrt(26) * sh_6_4 * x
            + (2 / 91) * math.sqrt(1365) * sh_6_6 * z
            + (12 / 91) * math.sqrt(65) * sh_6_7 * y
            - 15 / 182 * math.sqrt(26) * sh_6_8 * z
        )
        sh_7_9 = (
            -5 / 91 * math.sqrt(39) * sh_6_3 * x
            - 3 / 91 * math.sqrt(390) * sh_6_5 * x
            + (3 / 91) * math.sqrt(390) * sh_6_7 * z
            + (15 / 91) * math.sqrt(39) * sh_6_8 * y
            - 5 / 91 * math.sqrt(39) * sh_6_9 * z
        )
        sh_7_10 = (
            -3 / 91 * math.sqrt(65) * sh_6_10 * z
            - 3 / 91 * math.sqrt(65) * sh_6_2 * x
            - 15 / 182 * math.sqrt(78) * sh_6_4 * x
            + (15 / 182) * math.sqrt(78) * sh_6_8 * z
            + (10 / 91) * math.sqrt(78) * sh_6_9 * y
        )
        sh_7_11 = (
            -3 / 182 * math.sqrt(130) * sh_6_1 * x
            + (3 / 91) * math.sqrt(715) * sh_6_10 * y
            - 3 / 182 * math.sqrt(130) * sh_6_11 * z
            - 5 / 182 * math.sqrt(858) * sh_6_3 * x
            + (5 / 182) * math.sqrt(858) * sh_6_9 * z
        )
        sh_7_12 = (
            -1 / 182 * math.sqrt(390) * sh_6_0 * x
            + (3 / 91) * math.sqrt(715) * sh_6_10 * z
            + (6 / 91) * math.sqrt(130) * sh_6_11 * y
            - 1 / 182 * math.sqrt(390) * sh_6_12 * z
            - 3 / 91 * math.sqrt(715) * sh_6_2 * x
        )
        sh_7_13 = -3 / 7 * math.sqrt(5) * sh_6_1 * x + (3 / 7) * math.sqrt(5) * sh_6_11 * z + (1 / 7) * math.sqrt(15) * sh_6_12 * y
        sh_7_14 = (1 / 14) * math.sqrt(210) * (-sh_6_0 * x + sh_6_12 * z)
        components += [
            sh_7_0,
            sh_7_1,
            sh_7_2,
            sh_7_3,
            sh_7_4,
            sh_7_5,
            sh_7_6,
            sh_7_7,
            sh_7_8,
            sh_7_9,
            sh_7_10,
            sh_7_11,
            sh_7_12,
            sh_7_13,
            sh_7_14,
        ]
        if lmax == 7:
            return torch.stack(components, dim=-1)
        # degree 8 (lmax == 8 is guaranteed by the range check above)
        sh_8_0 = (1 / 4) * math.sqrt(17) * (sh_7_0 * z + sh_7_14 * x)
        sh_8_1 = (1 / 8) * math.sqrt(17) * sh_7_0 * y + (1 / 16) * math.sqrt(238) * sh_7_1 * z + (1 / 16) * math.sqrt(238) * sh_7_13 * x
        sh_8_2 = (
            -1 / 240 * math.sqrt(510) * sh_7_0 * z
            + (1 / 60) * math.sqrt(1785) * sh_7_1 * y
            + (1 / 240) * math.sqrt(46410) * sh_7_12 * x
            + (1 / 240) * math.sqrt(510) * sh_7_14 * x
            + (1 / 240) * math.sqrt(46410) * sh_7_2 * z
        )
        sh_8_3 = (
            (1 / 80)
            * math.sqrt(2)
            * (
                -math.sqrt(85) * sh_7_1 * z
                + math.sqrt(2210) * sh_7_11 * x
                + math.sqrt(85) * sh_7_13 * x
                + math.sqrt(2210) * sh_7_2 * y
                + math.sqrt(2210) * sh_7_3 * z
            )
        )
        sh_8_4 = (
            (1 / 40) * math.sqrt(935) * sh_7_10 * x
            + (1 / 40) * math.sqrt(85) * sh_7_12 * x
            - 1 / 40 * math.sqrt(85) * sh_7_2 * z
            + (1 / 10) * math.sqrt(85) * sh_7_3 * y
            + (1 / 40) * math.sqrt(935) * sh_7_4 * z
        )
        sh_8_5 = (
            (1 / 48)
            * math.sqrt(2)
            * (
                math.sqrt(102) * sh_7_11 * x
                - math.sqrt(102) * sh_7_3 * z
                + math.sqrt(1122) * sh_7_4 * y
                + math.sqrt(561) * sh_7_5 * z
                + math.sqrt(561) * sh_7_9 * x
            )
        )
        sh_8_6 = (
            (1 / 16) * math.sqrt(34) * sh_7_10 * x
            - 1 / 16 * math.sqrt(34) * sh_7_4 * z
            + (1 / 4) * math.sqrt(17) * sh_7_5 * y
            + (1 / 16) * math.sqrt(102) * sh_7_6 * z
            + (1 / 16) * math.sqrt(102) * sh_7_8 * x
        )
        sh_8_7 = (
            -1 / 80 * math.sqrt(1190) * sh_7_5 * z
            + (1 / 40) * math.sqrt(1785) * sh_7_6 * y
            + (1 / 20) * math.sqrt(255) * sh_7_7 * x
            + (1 / 80) * math.sqrt(1190) * sh_7_9 * x
        )
        sh_8_8 = -1 / 60 * math.sqrt(1785) * sh_7_6 * x + (1 / 15) * math.sqrt(255) * sh_7_7 * y - 1 / 60 * math.sqrt(1785) * sh_7_8 * z
        sh_8_9 = (
            -1 / 80 * math.sqrt(1190) * sh_7_5 * x
            + (1 / 20) * math.sqrt(255) * sh_7_7 * z
            + (1 / 40) * math.sqrt(1785) * sh_7_8 * y
            - 1 / 80 * math.sqrt(1190) * sh_7_9 * z
        )
        sh_8_10 = (
            -1 / 16 * math.sqrt(34) * sh_7_10 * z
            - 1 / 16 * math.sqrt(34) * sh_7_4 * x
            - 1 / 16 * math.sqrt(102) * sh_7_6 * x
            + (1 / 16) * math.sqrt(102) * sh_7_8 * z
            + (1 / 4) * math.sqrt(17) * sh_7_9 * y
        )
        sh_8_11 = (
            (1 / 48)
            * math.sqrt(2)
            * (
                math.sqrt(1122) * sh_7_10 * y
                - math.sqrt(102) * sh_7_11 * z
                - math.sqrt(102) * sh_7_3 * x
                - math.sqrt(561) * sh_7_5 * x
                + math.sqrt(561) * sh_7_9 * z
            )
        )
        sh_8_12 = (
            (1 / 40) * math.sqrt(935) * sh_7_10 * z
            + (1 / 10) * math.sqrt(85) * sh_7_11 * y
            - 1 / 40 * math.sqrt(85) * sh_7_12 * z
            - 1 / 40 * math.sqrt(85) * sh_7_2 * x
            - 1 / 40 * math.sqrt(935) * sh_7_4 * x
        )
        sh_8_13 = (
            (1 / 80)
            * math.sqrt(2)
            * (
                -math.sqrt(85) * sh_7_1 * x
                + math.sqrt(2210) * sh_7_11 * z
                + math.sqrt(2210) * sh_7_12 * y
                - math.sqrt(85) * sh_7_13 * z
                - math.sqrt(2210) * sh_7_3 * x
            )
        )
        sh_8_14 = (
            -1 / 240 * math.sqrt(510) * sh_7_0 * x
            + (1 / 240) * math.sqrt(46410) * sh_7_12 * z
            + (1 / 60) * math.sqrt(1785) * sh_7_13 * y
            - 1 / 240 * math.sqrt(510) * sh_7_14 * z
            - 1 / 240 * math.sqrt(46410) * sh_7_2 * x
        )
        sh_8_15 = -1 / 16 * math.sqrt(238) * sh_7_1 * x + (1 / 16) * math.sqrt(238) * sh_7_13 * z + (1 / 8) * math.sqrt(17) * sh_7_14 * y
        sh_8_16 = (1 / 4) * math.sqrt(17) * (-sh_7_0 * x + sh_7_14 * z)
        components += [
            sh_8_0,
            sh_8_1,
            sh_8_2,
            sh_8_3,
            sh_8_4,
            sh_8_5,
            sh_8_6,
            sh_8_7,
            sh_8_8,
            sh_8_9,
            sh_8_10,
            sh_8_11,
            sh_8_12,
            sh_8_13,
            sh_8_14,
            sh_8_15,
            sh_8_16,
        ]
        return torch.stack(components, dim=-1)
def lmax_tensor_size(lmax):
    """Number of spherical-harmonic components for degrees 1..lmax (the l = 0 term excluded)."""
    side = lmax + 1
    return side**2 - 1
def get_split_sizes_from_dim(feature_dim):
    """
    Infer lmax from a stacked spherical-harmonic feature size and return per-degree chunk sizes.

    Args:
        feature_dim: size of the component axis (shape[1] of the tensor).

    Returns:
        split_sizes: ``[3, 5, ..., 2 * lmax + 1]``, suitable for ``torch.split``.

    Raises:
        ValueError: if ``feature_dim`` is not ``lmax_tensor_size(lmax)`` for any lmax >= 1.
    """
    lmax = 1
    # smallest lmax whose total component count reaches feature_dim
    while lmax_tensor_size(lmax) < feature_dim:
        lmax += 1
    if lmax_tensor_size(lmax) == feature_dim:
        # one chunk of 2l + 1 components per degree l = 1..lmax
        return [2 * degree + 1 for degree in range(1, lmax + 1)]
    raise ValueError(f"Feature dimension {feature_dim} does not correspond to a valid lmax value")
class TensorLayerNorm(nn.Module):
    """Max-min normalization for stacked spherical-harmonic features.

    The input is split along dim 1 into per-degree chunks (sizes 3, 5, 7, ...,
    from ``get_split_sizes_from_dim``). Within each chunk, every channel's vector
    norm is rescaled to [0, 1] using the min and max norm across the channels of
    the same row, keeping the direction. The result is then multiplied by a
    per-channel ``weight``: a Parameter if ``trainable``, otherwise a buffer,
    initialized to ones.
    """

    def __init__(self, hidden_channels, trainable):
        super(TensorLayerNorm, self).__init__()
        self.hidden_channels = hidden_channels
        self.eps = 1e-12
        weight = torch.ones(self.hidden_channels)
        if trainable:
            self.register_parameter("weight", nn.Parameter(weight))
        else:
            self.register_buffer("weight", weight)
        self.reset_parameters()

    def reset_parameters(self):
        """Reset ``weight`` to ones."""
        weight = torch.ones(self.hidden_channels)
        self.weight.data.copy_(weight)

    def max_min_norm(self, tensor):
        """Rescale vector norms (taken over dim 1) to [0, 1] per row, keeping directions.

        Based on VisNet (https://www.nature.com/articles/s41467-023-43720-2).
        """
        dist = torch.norm(tensor, dim=1, keepdim=True)
        if (dist == 0).all():
            return torch.zeros_like(tensor)
        dist = dist.clamp(min=self.eps)
        direct = tensor / dist
        max_val, _ = torch.max(dist, dim=-1)
        min_val, _ = torch.min(dist, dim=-1)
        delta = (max_val - min_val).view(-1)
        # rows whose norms are all equal: avoid dividing by zero
        delta = torch.where(delta == 0, torch.ones_like(delta), delta)
        dist = (dist - min_val.view(-1, 1, 1)) / delta.view(-1, 1, 1)
        return F.relu(dist) * direct

    def forward(self, tensor):
        # presumably tensor is (num_atoms, feature_dim, hidden_channels); dim 1 is split per degree
        feature_dim = tensor.shape[1]
        try:
            split_sizes = get_split_sizes_from_dim(feature_dim)
        except ValueError as e:
            raise ValueError(f"TensorLayerNorm received unsupported feature dimension {feature_dim}: {e}") from e
        # Split the tensor into per-degree parts
        vec_parts = torch.split(tensor, split_sizes, dim=1)
        # Normalize each part separately
        normalized_parts = [self.max_min_norm(part) for part in vec_parts]
        # Concatenate the normalized parts
        normalized_vec = torch.cat(normalized_parts, dim=1)
        # Apply per-channel weight
        return normalized_vec * self.weight.unsqueeze(0).unsqueeze(0)
def normalize_string(s: str) -> str:
    """Lower-case ``s`` and strip every "-", "_" and space character."""
    return s.lower().translate(str.maketrans("", "", "-_ "))
class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``."""

    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
# Short activation names -> activation classes (instantiate before use).
act_class_mapping = dict(
    ssp=ShiftedSoftplus,
    silu=nn.SiLU,
    tanh=nn.Tanh,
    sigmoid=nn.Sigmoid,
    swish=Swish,
)
# https://github.com/sunglasses-ai/classy/blob/3e74cba1fdf1b9f9f2ba1cfcfa6c2017aa59fc04/classy/optim/factories.py#L14
def get_activations(optional=False, *args, **kwargs):
    """Build a lookup table from activation names to activation classes or callables.

    Every ``nn.Module`` subclass defined in or imported into
    ``torch.nn.modules.activation`` is registered under ``normalize_string`` of
    its class name. A fixed set of aliases is then layered on top; some map to
    plain functions (``scaled_silu``, ``shifted_softplus``) rather than classes.

    Args:
        optional: if True, also map ``""`` to ``None``.
        *args, **kwargs: accepted but unused.

    Returns:
        dict mapping name -> class or callable (or None).
    """
    table = {}
    for candidate in vars(torch.nn.modules.activation).values():
        if isinstance(candidate, type) and issubclass(candidate, torch.nn.Module):
            table[normalize_string(candidate.__name__)] = candidate
    aliases = {
        "relu": torch.nn.ReLU,
        "elu": torch.nn.ELU,
        "sigmoid": torch.nn.Sigmoid,
        "silu": torch.nn.SiLU,
        "mish": torch.nn.Mish,
        "swish": torch.nn.SiLU,
        "selu": torch.nn.SELU,
        "scaled_swish": scaled_silu,
        "softplus": shifted_softplus,
        "slrelu": SmoothLeakyReLU,
    }
    table.update(aliases)
    if optional:
        table[""] = None
    return table
def get_activations_none(optional=False, *args, **kwargs):
    """Build an activation lookup table with a small alias set and optional ``None`` entries.

    Every ``nn.Module`` subclass defined in or imported into
    ``torch.nn.modules.activation`` is registered under ``normalize_string`` of
    its class name. The aliases relu/elu/sigmoid/silu/selu are then layered on top.

    Args:
        optional: if True, map both ``""`` and ``None`` to ``None``.
        *args, **kwargs: accepted but unused.

    Returns:
        dict mapping name -> activation class (or None).
    """
    table = {}
    for candidate in vars(torch.nn.modules.activation).values():
        if isinstance(candidate, type) and issubclass(candidate, torch.nn.Module):
            table[normalize_string(candidate.__name__)] = candidate
    aliases = {
        "relu": torch.nn.ReLU,
        "elu": torch.nn.ELU,
        "sigmoid": torch.nn.Sigmoid,
        "silu": torch.nn.SiLU,
        "selu": torch.nn.SELU,
    }
    table.update(aliases)
    if optional:
        table[""] = None
        table[None] = None
    return table
def dictionary_to_option(options, selected):
    """Look up ``selected`` in ``options`` and instantiate the result if it is a class.

    Args:
        options: mapping from option key to a class, another object (e.g. a
            function), or ``None``.
        selected: key to look up.

    Returns:
        ``options[selected]()`` when the mapped value is a class, otherwise the
        mapped value unchanged.

    Raises:
        ValueError: if ``selected`` is not a key of ``options``; the message
            lists every available key.
    """
    if selected not in options:
        # str() every key: option tables may contain a ``None`` key (see
        # get_activations_none), which ``str.join`` would reject with TypeError
        # and mask the intended ValueError.
        available = ", ".join(str(key) for key in options)
        raise ValueError(f'Invalid choice "{selected}", choose one from {available} ')
    activation = options[selected]
    if inspect.isclass(activation):
        activation = activation()
    return activation
def str2act(input_str, *args, **kwargs):
    """Resolve an activation name to an activation instance (or callable).

    Args:
        input_str: key into the ``get_activations`` table; ``""`` means "no activation".
        *args, **kwargs: forwarded to ``get_activations``, which ignores them.

    Returns:
        ``None`` for ``""``. Otherwise the result of ``dictionary_to_option``:
        an instance when the table holds a class, else the table value itself.

    Raises:
        ValueError: if ``input_str`` is not a known key.
    """
    if input_str == "":
        return None
    # Pass ``optional`` positionally: in ``f(optional=True, *args)`` the *args are
    # bound first, so any extra positional argument would collide with
    # ``optional`` and raise "got multiple values for argument 'optional'".
    act = get_activations(True, *args, **kwargs)
    return dictionary_to_option(act, input_str)
class ExpNormalSmearing(nn.Module):
    """Exponential-normal radial basis expansion of distances.

    A distance ``d`` is expanded into ``n_rbf`` features::

        cutoff_fn(d) * exp(-beta_k * (exp(-alpha * d) - mu_k) ** 2)

    Here ``alpha = 5 / cutoff``. The centers ``mu_k`` start evenly spaced on
    ``[exp(-cutoff), 1]``, and every ``beta_k`` starts at
    ``(2 / n_rbf * (1 - exp(-cutoff))) ** -2``.

    Args:
        cutoff: cutoff radius; a tensor is converted to a Python number via ``.item()``.
        n_rbf: number of basis functions.
        trainable: if truthy, ``means``/``betas`` are ``nn.Parameter``s;
            otherwise they are registered as buffers.
    """

    def __init__(self, cutoff=5.0, n_rbf=50, trainable=False):
        super().__init__()
        self.cutoff = cutoff.item() if isinstance(cutoff, torch.Tensor) else cutoff
        self.n_rbf = n_rbf
        self.trainable = trainable
        self.cutoff_fn = CosineCutoff(self.cutoff)
        self.alpha = 5.0 / self.cutoff

        initial = dict(zip(("means", "betas"), self._initial_params()))
        for name, value in initial.items():
            if trainable:
                self.register_parameter(name, nn.Parameter(value))
            else:
                self.register_buffer(name, value)

    def _initial_params(self):
        """Return freshly computed ``(means, betas)`` tensors of length ``n_rbf``."""
        lowest = torch.exp(torch.scalar_tensor(-self.cutoff))
        centers = torch.linspace(lowest, 1, self.n_rbf)
        width = (2 / self.n_rbf * (1 - lowest)) ** -2
        widths = torch.tensor([width] * self.n_rbf)
        return centers, widths

    def reset_parameters(self):
        """Overwrite ``means``/``betas`` in place with their initial values."""
        fresh_means, fresh_betas = self._initial_params()
        self.means.data.copy_(fresh_means)
        self.betas.data.copy_(fresh_betas)

    def forward(self, dist):
        """Expand ``dist`` along a new trailing axis of size ``n_rbf``."""
        d = dist.unsqueeze(-1)
        envelope = self.cutoff_fn(d)
        shifted = torch.exp(self.alpha * (-d)) - self.means
        return envelope * torch.exp(-self.betas * shifted ** 2)
def str2basis(input_str):
    """Resolve a radial-basis name to its class.

    Anything that is not exactly a ``str`` is returned unchanged, which lets a
    class be passed straight through. ``"BesselBasis"`` and ``"GaussianRBF"``
    are matched case-sensitively; ``"expnorm"`` is matched case-insensitively.

    Raises:
        ValueError: for any other string.
    """
    if type(input_str) is not str:
        return input_str
    if input_str == "BesselBasis":
        return BesselBasis
    if input_str == "GaussianRBF":
        return GaussianRBF
    if input_str.lower() == "expnorm":
        return ExpNormalSmearing
    raise ValueError("Unknown radial basis: {}".format(input_str))
class MLP(nn.Module):
    """Stack of ``Dense`` layers mapping width ``hidden_dims[0]`` to ``hidden_dims[-1]``.

    Every layer except the last is built with ``activation`` and ``norm``. The
    final layer uses ``last_activation`` and receives no ``norm`` argument.
    ``hidden_dims`` needs at least two entries; fewer raises ``IndexError``.

    Args:
        hidden_dims: layer widths, e.g. ``[in, hidden, out]``.
        bias, weight_init, bias_init: forwarded to every ``Dense`` layer.
        activation: activation for all but the last layer.
        last_activation: activation for the last layer.
        norm: normalization option for all but the last layer.
    """

    def __init__(
        self,
        hidden_dims: List[int],
        bias=True,
        activation=None,
        last_activation=None,
        weight_init=xavier_uniform_,
        bias_init=zeros_initializer,
        norm="",
    ):
        super().__init__()
        make_dense = partial(Dense, bias=bias, weight_init=weight_init, bias_init=bias_init)
        # Consecutive (in, out) width pairs: [a, b, c] -> [(a, b), (b, c)].
        widths = list(zip(hidden_dims[:-1], hidden_dims[1:]))
        layers = [make_dense(d_in, d_out, activation=activation, norm=norm) for d_in, d_out in widths[:-1]]
        last_in, last_out = widths[-1]
        layers.append(make_dense(last_in, last_out, activation=last_activation))
        self.dense_layers = nn.ModuleList(layers)
        # Same module objects, exposed as a Sequential for forward().
        self.layers = nn.Sequential(*self.dense_layers)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize every Dense layer."""
        for layer in self.dense_layers:
            layer.reset_parameters()

    def forward(self, x):
        """Apply the layers in order."""
        return self.layers(x)
class NodeInit(MessagePassing):
    """Initialize node features from neighbor embeddings aggregated over edges.

    Messages are summed per node (``aggr="add"``). There are two modes:

    * ``concat=False`` (default): message ``= embedding(z_j) * distance_proj(edge_attr)
      * cutoff(edge_weight)``. The aggregate is merged with ``x`` via
      ``combine(cat([x, aggregate]))``.
    * ``concat=True``: message ``= distance_proj(cat([edge_attr, embedding(z_j),
      embedding_src(z_i)]))``. The aggregate is added to ``x``; no cosine cutoff
      is applied in this mode.

    Args:
        hidden_channels: int or list of ints giving the projection-MLP widths;
            the last entry is the embedding/output width.
        num_rbf: number of columns of ``edge_attr``.
        cutoff: radius passed to ``CosineCutoff``.
        max_z: size of the embedding tables (``z`` must index into them).
        activation: hidden activation of the MLPs.
        proj_ln: ``norm`` option for the MLPs.
        last_activation: if truthy, ``activation`` is also applied after the last MLP layer.
        weight_init, bias_init: initializers forwarded to the MLPs.
        concat: selects the mode described above.
    """

    def __init__(
        self,
        hidden_channels,
        num_rbf,
        cutoff,
        max_z=100,
        activation=F.silu,
        proj_ln="",
        last_activation=False,
        weight_init=nn.init.xavier_uniform_,
        bias_init=nn.init.zeros_,
        concat=False,
    ):
        super(NodeInit, self).__init__(aggr="add")
        if isinstance(hidden_channels, int):
            hidden_channels = [hidden_channels]
        first_channel = hidden_channels[0]
        last_channel = hidden_channels[-1]
        self.concat = concat
        # Neighbor embedding; it is x_j in message(), so its width is last_channel.
        self.embedding = nn.Embedding(max_z, last_channel)
        if self.concat:
            self.embedding_src = nn.Embedding(max_z, first_channel)
            # message() feeds cat([edge_attr, x_j, s_i]) to this MLP: widths
            # num_rbf + last_channel (self.embedding) + first_channel (self.embedding_src).
            # Declaring num_rbf + 2 * first_channel would only work when
            # first_channel == last_channel.
            self.distance_proj = MLP(
                [num_rbf + first_channel + last_channel] + hidden_channels,
                activation=activation,
                norm=proj_ln,
                weight_init=weight_init,
                bias_init=bias_init,
                last_activation=activation if last_activation else None,
            )
        else:
            # Single-layer projection (no activation) of edge_attr to per-channel edge weights.
            self.distance_proj = MLP(
                [num_rbf] + [last_channel],
                activation=None,
                norm="",
                weight_init=weight_init,
                bias_init=bias_init,
                last_activation=None,
            )
            # Merges the node's own features with the aggregated neighbor messages.
            self.combine = MLP(
                [2 * last_channel] + hidden_channels,
                activation=activation,
                norm=proj_ln,
                weight_init=weight_init,
                bias_init=bias_init,
                last_activation=activation if last_activation else None,
            )
        self.cutoff = CosineCutoff(cutoff)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize embeddings and MLPs (only those that exist in the current mode)."""
        self.embedding.reset_parameters()
        if self.concat:
            self.embedding_src.reset_parameters()
        self.distance_proj.reset_parameters()
        if not self.concat:
            self.combine.reset_parameters()

    def forward(self, z, x, edge_index, edge_weight, edge_attr):
        """Compute initial node features.

        Args:
            z: integer indices into the embedding tables (presumably atomic numbers).
            x: node features with ``hidden_channels[-1]`` columns (required by
                ``combine`` / the residual add).
            edge_index: ``[2, E]`` edge list; self-loops are dropped here.
            edge_weight: per-edge scalar passed to the cosine cutoff
                (presumably the edge length — confirm with caller).
            edge_attr: per-edge features with ``num_rbf`` columns.

        Returns:
            Updated node features with ``hidden_channels[-1]`` columns.
        """
        # Drop self-loops so a node never messages itself.
        mask = edge_index[0] != edge_index[1]
        if not mask.all():
            edge_index = edge_index[:, mask]
            edge_weight = edge_weight[mask]
            edge_attr = edge_attr[mask]
        x_neighbors = self.embedding(z)
        if not self.concat:
            C = self.cutoff(edge_weight)
            W = self.distance_proj(edge_attr) * C.view(-1, 1)
            # Not read by message() in this mode; passed only to satisfy the propagate signature.
            x_src = x_neighbors
        else:
            x_src = self.embedding_src(z)
            W = edge_attr
        # propagate_type: (x: Tensor, s:Tensor, W: Tensor)
        x_neighbors = self.propagate(edge_index, x=x_neighbors, s=x_src, W=W, size=None)
        if self.concat:
            x_neighbors = x + x_neighbors
        else:
            x_neighbors = self.combine(torch.cat([x, x_neighbors], dim=1))
        return x_neighbors

    def message(self, s_i, x_j, W):
        """Per-edge message; parameter names follow PyG's ``_i``/``_j`` lifting convention."""
        if self.concat:
            return self.distance_proj(torch.cat([W, x_j, s_i], dim=1))
        return x_j * W
class EdgeInit(MessagePassing):
    """Per-edge feature initialization without aggregation.

    For every edge the output row is ``(x_i + x_j) * edge_up(edge_attr)``.
    ``aggregate`` is overridden to return the messages unchanged, so
    ``forward`` yields one row per edge rather than one per node.

    Args:
        num_rbf: number of columns of ``edge_attr``.
        hidden_channels: int or list of ints giving the ``edge_up`` MLP widths.
        activation: hidden activation of ``edge_up``.
        proj_ln: ``norm`` option for ``edge_up``.
        last_activation: if truthy, ``activation`` is also applied after the last layer.
        weight_init, bias_init: initializers forwarded to ``edge_up``.
    """

    def __init__(
        self,
        num_rbf,
        hidden_channels,
        activation=F.silu,
        proj_ln="",
        last_activation=False,
        weight_init=nn.init.xavier_uniform_,
        bias_init=nn.init.zeros_,
    ):
        super().__init__(aggr=None)
        # Stored for external access; not used inside this class.
        self.activation = activation
        channels = [hidden_channels] if type(hidden_channels) is int else hidden_channels
        final_activation = activation if last_activation else None
        self.edge_up = MLP(
            [num_rbf] + channels,
            activation=activation,
            norm=proj_ln,
            weight_init=weight_init,
            bias_init=bias_init,
            last_activation=final_activation,
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the edge MLP."""
        self.edge_up.reset_parameters()

    def forward(self, edge_index, edge_attr, x):
        """Return one feature row per edge in ``edge_index``."""
        # propagate_type: (x: Tensor, edge_attr: Tensor)
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        """Endpoint-sum of node features gated by the projected edge attributes."""
        endpoint_sum = x_i + x_j
        return endpoint_sum * self.edge_up(edge_attr)

    def aggregate(self, features, index):
        """Identity: keep per-edge messages instead of reducing them per node."""
        return features