Spaces:
Running
Running
import torch | |
import torch.linalg as LA | |
import torch.nn as nn | |
import torch_scatter | |
from torch_geometric.data import Data | |
from ase.data import covalent_radii | |
from ase.units import _e, _eps0, m, pi | |
from e3nn.util.jit import compile_mode # TODO: e3nn allows autograd in compiled model | |
class ZBL(nn.Module):
    """Ziegler-Biersack-Littmark (ZBL) screened nuclear repulsion.

    The pair energy is a Coulomb term ``eij`` multiplied by a four-term
    exponential screening function ``phi`` evaluated at ``rij / aij``, where
    ``aij = a0 / (zi**p + zj**p)``.  In ``forward`` the pair energy is
    additionally damped by a polynomial ``envelope`` whose cutoff is the
    smaller of ``data.cutoff`` and the sum of the two covalent radii.

    Parameters ``a``, ``b``, ``a0`` and ``p`` are stored as ``nn.Parameter``
    objects; the ASE covalent-radii table is stored as a buffer indexed by
    atomic number.
    """

    def __init__(
        self,
        trianable: bool = False,
        **kwargs,
    ) -> None:
        """Create the screening parameters and the covalent-radii buffer.

        Args:
            trianable: Whether ``a``, ``b``, ``a0`` and ``p`` require
                gradients.  The spelling (sic, "trainable") is kept because
                it is part of the public signature.
            **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
        """
        nn.Module.__init__(self, **kwargs)

        # NOTE(review): this changes the process-wide default dtype as a side
        # effect of constructing the module.  Kept because callers may rely
        # on it; consider passing an explicit dtype instead.
        torch.set_default_dtype(torch.double)
        dtype = torch.get_default_dtype()

        # Weights and exponents of the four-term screening function `phi`.
        self.a = torch.nn.parameter.Parameter(
            torch.tensor([0.18175, 0.50986, 0.28022, 0.02817], dtype=dtype),
            requires_grad=trianable,
        )
        self.b = torch.nn.parameter.Parameter(
            torch.tensor([-3.19980, -0.94229, -0.40290, -0.20162], dtype=dtype),
            requires_grad=trianable,
        )
        # Prefactor and exponent of the screening length `aij` (see forward).
        self.a0 = torch.nn.parameter.Parameter(
            torch.tensor(0.46850, dtype=dtype),
            requires_grad=trianable,
        )
        self.p = torch.nn.parameter.Parameter(
            torch.tensor(0.23, dtype=dtype),
            requires_grad=trianable,
        )
        # Lookup table indexed by atomic number.
        self.register_buffer(
            "covalent_radii",
            torch.tensor(covalent_radii, dtype=dtype),
        )

    def phi(self, x):
        """Return ``sum_i a_i * exp(b_i * x)`` for every entry of 1-D ``x``."""
        return torch.einsum("i,ij->j", self.a, torch.exp(torch.outer(self.b, x)))

    def d_phi(self, x):
        """Return the first derivative of ``phi`` with respect to ``x``."""
        return torch.einsum(
            "i,ij->j", self.a * self.b, torch.exp(torch.outer(self.b, x))
        )

    def dd_phi(self, x):
        """Return the second derivative of ``phi`` with respect to ``x``."""
        return torch.einsum(
            "i,ij->j", self.a * self.b**2, torch.exp(torch.outer(self.b, x))
        )

    def eij(
        self, zi: torch.Tensor, zj: torch.Tensor, rij: torch.Tensor
    ) -> torch.Tensor:  # [eV]
        """Return the unscreened Coulomb energy ``k * zi * zj / rij``."""
        return _e * m / (4 * pi * _eps0) * torch.div(torch.mul(zi, zj), rij)

    def d_eij(
        self, zi: torch.Tensor, zj: torch.Tensor, rij: torch.Tensor
    ) -> torch.Tensor:  # [eV / A]
        """Return the first derivative of ``eij`` with respect to ``rij``."""
        return -_e * m / (4 * pi * _eps0) * torch.div(torch.mul(zi, zj), rij**2)

    def dd_eij(
        self, zi: torch.Tensor, zj: torch.Tensor, rij: torch.Tensor
    ) -> torch.Tensor:  # [eV / A^2]
        """Return the second derivative of ``eij`` with respect to ``rij``."""
        # 2 * k / r^3 with k = e * m / (4 pi eps0), hence the 2 pi below.
        return _e * m / (2 * pi * _eps0) * torch.div(torch.mul(zi, zj), rij**3)

    def switch_fn(
        self,
        zi: torch.Tensor,
        zj: torch.Tensor,
        rij: torch.Tensor,
        aij: torch.Tensor,
        router: torch.Tensor,
        rinner: torch.Tensor,
    ) -> torch.Tensor:  # [eV]
        """Return the additive switching term for ``eij(r) * phi(r / aij)``.

        The term is a constant ``C`` for ``rij < rinner`` and
        ``A/3 * t**3 + B/4 * t**4 + C`` with ``t = rij - rinner`` otherwise.
        ``A``, ``B`` and ``C`` are chosen so that the screened energy plus
        this term has zero value, slope and curvature at ``rij == router``.

        Args:
            zi, zj: Nuclear charges of the two atoms of each pair.
            rij: Pair distances at which the switching term is evaluated.
            aij: Screening length of each pair.
            router: Outer radius, where the switched energy reaches zero.
            rinner: Inner radius, below which only the constant shift applies.
        """
        # aij = self.a0 / (torch.pow(zi, self.p) + torch.pow(zj, self.p))
        xrouter = router / aij

        coul = self.eij(zi, zj, router)
        d_coul = self.d_eij(zi, zj, router)
        dd_coul = self.dd_eij(zi, zj, router)

        # d_phi / dd_phi differentiate with respect to x = r / aij; the chain
        # rule contributes 1 / aij per derivative when differentiating in r.
        screen = self.phi(xrouter)
        d_screen = self.d_phi(xrouter) / aij
        dd_screen = self.dd_phi(xrouter) / aij**2

        # Value, slope and curvature of the screened energy at `router`.
        energy = coul * screen
        grad1 = d_coul * screen + coul * d_screen
        grad2 = dd_coul * screen + 2.0 * d_coul * d_screen + coul * dd_screen

        width = router - rinner
        A = (-3 * grad1 + width * grad2) / width**2
        B = (2 * grad1 - width * grad2) / width**3
        C = -energy + 1.0 / 2.0 * width * grad1 - 1.0 / 12.0 * width**2 * grad2

        t = rij - rinner
        switching = torch.where(
            rij < rinner,
            C,
            A / 3.0 * t**3 + B / 4.0 * t**4 + C,
        )
        return switching

    def envelope(self, r: torch.Tensor, rc: torch.Tensor, p: int = 6):
        """Return a polynomial cutoff envelope of ``r / rc``.

        The value is 1 at ``r == 0`` and exactly 0 wherever ``r >= rc``.

        Args:
            r: Distances.
            rc: Cutoff radius, broadcastable against ``r``.
            p: Polynomial order of the envelope.
        """
        x = r / rc
        poly = (
            1.0
            - ((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(x, p)
            + p * (p + 2.0) * torch.pow(x, p + 1)
            - (p * (p + 1.0) / 2) * torch.pow(x, p + 2)
        )
        # The boolean mask zeroes everything at or beyond the cutoff.
        return poly * (x < 1)

    def _get_derivatives(self, energy: torch.Tensor, data: "Data"):
        """Differentiate ``energy`` to obtain forces and stress.

        Args:
            energy: Energies to differentiate (``forward`` passes one value
                per graph).
            data: Batch providing ``positions``, ``vij``, ``cell``, ``batch``
                and ``edge_index``; ``vij`` must be part of the autograd
                graph of ``energy``.

        Returns:
            Tuple ``(forces, stress)``: the negative gradient with respect to
            ``data.positions`` and the per-graph 3x3 stress.
        """
        grad_pos, grad_vij = torch.autograd.grad(
            outputs=[energy],  # TODO: generalized derivatives
            inputs=[data.positions, data.vij],  # TODO: generalized derivatives
            grad_outputs=[torch.ones_like(energy)],
            retain_graph=True,
            create_graph=True,
            allow_unused=True,
        )

        # A cell volume is positive even for a left-handed set of vectors.
        volume = torch.det(data.cell).abs()  # (batch,)
        rfaxy = torch.einsum("ax,ay->axy", data.vij, -grad_vij)

        edge_batch = data.batch[data.edge_index[0]]

        # NOTE(review): `energy` already includes the 0.5 double-counting
        # factor applied in `forward`, so this extra 0.5 may halve the
        # virial - confirm against a finite-difference strain derivative.
        # The volume is reshaped to (batch, 1, 1) so that it divides each
        # graph's 3x3 matrix; (batch, 1) only broadcasts for a single graph.
        stress = (
            -0.5
            * torch_scatter.scatter_sum(rfaxy, edge_batch, dim=0)
            / volume.view(-1, 1, 1)
        )

        return -grad_pos, stress

    def forward(
        self,
        data: "Data",
    ) -> dict[str, torch.Tensor]:
        """Compute ZBL energies, forces and stress for a batch of graphs.

        Side effects: enables gradients on ``data.positions`` and, when either
        is missing, stores ``data.vij`` and ``data.rij`` on ``data``.

        Args:
            data: Batch providing ``numbers``, ``positions``, ``edge_index``,
                ``edge_shift``, ``batch``, ``cutoff`` and ``cell``.
                NOTE(review): ``data.cutoff`` must broadcast against the
                per-edge bond radii - confirm its shape for batched input.

        Returns:
            Dict with keys ``"energy"``, ``"forces"`` and ``"stress"``.
        """
        # TODO: generalized derivatives
        data.positions.requires_grad_(True)

        numbers = data.numbers  # (sum(N), )
        positions = data.positions  # (sum(N), 3)
        edge_index = data.edge_index  # (2, sum(E))
        edge_shift = data.edge_shift  # (sum(E), 3)
        batch = data.batch  # (sum(N), )

        edge_src, edge_dst = edge_index[0], edge_index[1]

        if "rij" not in data or "vij" not in data:
            data.vij = positions[edge_dst] - positions[edge_src] + edge_shift
            data.rij = LA.norm(data.vij, dim=-1)

        rij = data.rij
        zi = numbers[edge_src]  # (sum(E), )
        zj = numbers[edge_dst]  # (sum(E), )

        # Per-edge sum of the two covalent radii, used as an upper cutoff.
        rbond = self.covalent_radii[zi] + self.covalent_radii[zj]

        aij = self.a0 / (torch.pow(zi, self.p) + torch.pow(zj, self.p))  # (sum(E), )

        energy_pairs = (
            self.eij(zi, zj, rij)
            * self.phi(rij / aij.to(rij))
            * self.envelope(rij, torch.min(data.cutoff, rbond))
        )

        # dim_size pins the output to one entry per atom; without it, trailing
        # atoms that have no edges would make the result shorter than `batch`.
        energy_nodes = 0.5 * torch_scatter.scatter_add(
            src=energy_pairs,
            index=edge_dst,
            dim=0,
            dim_size=numbers.shape[0],
        )  # (sum(N), )

        energies = torch_scatter.scatter_add(
            src=energy_nodes,
            index=batch,
            dim=0,
        )  # (B, )

        # TODO: generalized derivatives
        forces, stress = self._get_derivatives(energies, data)

        return {
            "energy": energies,
            "forces": forces,
            "stress": stress,
        }