import torch
import torch.linalg as LA
import torch.nn as nn
import torch_scatter
from torch_geometric.data import Data
from ase.data import covalent_radii
from ase.units import _e, _eps0, m, pi
from e3nn.util.jit import compile_mode

# TODO: e3nn allows autograd in compiled model


@compile_mode("script")
class ZBL(nn.Module):
    """Ziegler-Biersack-Littmark (ZBL) screened nuclear repulsion.

    The pair energy is ``k * Zi * Zj / r * phi(r / a_ij)`` with the universal
    screening function ``phi(x) = sum_k a_k * exp(b_k * x)`` and screening
    length ``a_ij = a0 / (Zi**p + Zj**p)``. In :meth:`forward` each pair energy
    is additionally multiplied by a polynomial envelope that vanishes at
    ``min(data.cutoff, r_cov_i + r_cov_j)`` (sum of ASE covalent radii).

    Energies are in eV; distances are assumed to be in Angstrom (ASE units).
    """

    def __init__(
        self,
        trianable: bool = False,
        **kwargs,
    ) -> None:
        """Create the ZBL module with the standard universal coefficients.

        Args:
            trianable: if True, the coefficients ``a``, ``b``, ``a0`` and ``p``
                are trainable parameters (``requires_grad=True``). The keyword
                is misspelled but kept as-is for backward compatibility.
            **kwargs: forwarded to ``nn.Module.__init__``.
        """
        nn.Module.__init__(self, **kwargs)

        # NOTE(review): this changes the *global* torch default dtype as a side
        # effect of constructing the module (kept for backward compatibility).
        torch.set_default_dtype(torch.double)

        # Screening-function amplitudes a_k.
        self.a = torch.nn.parameter.Parameter(
            torch.tensor(
                [0.18175, 0.50986, 0.28022, 0.02817], dtype=torch.get_default_dtype()
            ),
            requires_grad=trianable,
        )
        # Screening-function exponents b_k (sign included, so phi = sum a*exp(b*x)).
        self.b = torch.nn.parameter.Parameter(
            torch.tensor(
                [-3.19980, -0.94229, -0.40290, -0.20162],
                dtype=torch.get_default_dtype(),
            ),
            requires_grad=trianable,
        )
        # Screening-length prefactor a0 in a_ij = a0 / (Zi**p + Zj**p).
        self.a0 = torch.nn.parameter.Parameter(
            torch.tensor(0.46850, dtype=torch.get_default_dtype()),
            requires_grad=trianable,
        )
        # Screening-length exponent p.
        self.p = torch.nn.parameter.Parameter(
            torch.tensor(0.23, dtype=torch.get_default_dtype()), requires_grad=trianable
        )
        # Covalent radii indexed by atomic number (from ase.data).
        self.register_buffer(
            "covalent_radii",
            torch.tensor(
                covalent_radii,
                dtype=torch.get_default_dtype(),
            ),
        )

    def phi(self, x):
        """Screening function ``sum_k a_k * exp(b_k * x)`` for a 1-D tensor ``x``."""
        return torch.einsum("i,ij->j", self.a, torch.exp(torch.outer(self.b, x)))

    def d_phi(self, x):
        """First derivative of :meth:`phi` with respect to ``x`` (not ``r``)."""
        return torch.einsum(
            "i,ij->j", self.a * self.b, torch.exp(torch.outer(self.b, x))
        )

    def dd_phi(self, x):
        """Second derivative of :meth:`phi` with respect to ``x`` (not ``r``)."""
        return torch.einsum(
            "i,ij->j", self.a * self.b**2, torch.exp(torch.outer(self.b, x))
        )

    def eij(
        self, zi: torch.Tensor, zj: torch.Tensor, rij: torch.Tensor
    ) -> torch.Tensor:  # [eV]
        """Bare Coulomb repulsion ``k * zi * zj / rij`` with k = e / (4 pi eps0) in eV*A."""
        return _e * m / (4 * pi * _eps0) * torch.div(torch.mul(zi, zj), rij)

    def d_eij(
        self, zi: torch.Tensor, zj: torch.Tensor, rij: torch.Tensor
    ) -> torch.Tensor:  # [eV / A]
        """First derivative of :meth:`eij` with respect to ``rij``."""
        return -_e * m / (4 * pi * _eps0) * torch.div(torch.mul(zi, zj), rij**2)

    def dd_eij(
        self, zi: torch.Tensor, zj: torch.Tensor, rij: torch.Tensor
    ) -> torch.Tensor:  # [eV / A^2]
        """Second derivative of :meth:`eij` with respect to ``rij``."""
        return _e * m / (2 * pi * _eps0) * torch.div(torch.mul(zi, zj), rij**3)

    def switch_fn(
        self,
        zi: torch.Tensor,
        zj: torch.Tensor,
        rij: torch.Tensor,
        aij: torch.Tensor,
        router: torch.Tensor,
        rinner: torch.Tensor,
    ) -> torch.Tensor:  # [eV]
        """Polynomial switching term that smoothly brings E(r) to zero at ``router``.

        Uses the cubic/quartic form ``A/3 (r-rinner)^3 + B/4 (r-rinner)^4 + C``
        for ``rij >= rinner`` and the constant ``C`` below ``rinner``, with A, B, C
        chosen from the value, first and second derivative of the screened
        energy ``E(r) = eij(r) * phi(r / aij)`` at ``router``.
        Values for ``rij > router`` are not masked here; callers must do that.

        Args:
            zi, zj: atomic numbers of the two atoms of each pair.
            rij: pair distances.
            aij: screening lengths ``a0 / (zi**p + zj**p)``.
            router: outer switching radius.
            rinner: inner switching radius.
        """
        # aij = self.a0 / (torch.pow(zi, self.p) + torch.pow(zj, self.p))
        xrouter = router / aij

        e_out = self.eij(zi, zj, router)
        de_out = self.d_eij(zi, zj, router)
        dde_out = self.dd_eij(zi, zj, router)

        phi_out = self.phi(xrouter)
        # phi is evaluated at x = r / aij, so every derivative with respect to r
        # picks up a factor 1 / aij (chain rule). The original code omitted it.
        dphi_out = self.d_phi(xrouter) / aij
        ddphi_out = self.dd_phi(xrouter) / aij**2

        energy = e_out * phi_out
        grad1 = de_out * phi_out + e_out * dphi_out
        grad2 = dde_out * phi_out + 2.0 * de_out * dphi_out + e_out * ddphi_out

        width = router - rinner
        A = (-3 * grad1 + width * grad2) / width**2
        B = (2 * grad1 - width * grad2) / width**3
        C = -energy + 0.5 * width * grad1 - 1.0 / 12.0 * width**2 * grad2

        switching = torch.where(
            rij < rinner,
            C,
            A / 3.0 * (rij - rinner) ** 3 + B / 4.0 * (rij - rinner) ** 4 + C,
        )
        return switching

    def envelope(self, r: torch.Tensor, rc: torch.Tensor, p: int = 6):
        """Smooth polynomial cutoff envelope.

        With ``x = r / rc`` this returns
        ``1 - (p+1)(p+2)/2 x^p + p(p+2) x^(p+1) - p(p+1)/2 x^(p+2)`` for ``x < 1``
        and exactly 0 for ``x >= 1``. The polynomial equals 1 at x = 0 and 0 at x = 1.
        """
        x = r / rc
        y = (
            1.0
            - ((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(x, p)
            + p * (p + 2.0) * torch.pow(x, p + 1)
            - (p * (p + 1.0) / 2) * torch.pow(x, p + 2)
        ) * (x < 1)
        return y

    def _get_derivatives(self, energy: torch.Tensor, data: Data):
        """Return ``(forces, stress)`` from the per-graph energies via autograd.

        Forces are ``-dE/d positions``. The stress is the strain derivative
        ``(1/V) * sum_e vij (x) dE/dvij``, accumulated per graph.

        Assumes ``data.cell`` is ``(batch, 3, 3)``; this is TODO-confirm against
        the data pipeline. A singular (zero) cell gives a zero volume and
        therefore a non-finite stress.
        """
        egradi, egradij = torch.autograd.grad(
            outputs=[energy],  # TODO: generalized derivatives
            inputs=[data.positions, data.vij],  # TODO: generalized derivatives
            grad_outputs=[torch.ones_like(energy)],
            retain_graph=True,
            create_graph=True,
            allow_unused=True,
        )

        # |det| keeps the volume positive for left-handed cells.
        volume = torch.abs(torch.det(data.cell))  # (batch,)
        edge_batch = data.batch[data.edge_index[0]]

        # When vij is built in forward(), the energy depends on positions only
        # through the edge vectors vij. So dE/d(strain) = sum_e vij (x) dE/dvij.
        # The 1/2 double-counting factor is already inside `energy` (and hence
        # inside egradij), so no extra 0.5 is applied here.
        rfaxy = torch.einsum("ax,ay->axy", data.vij, egradij)
        virial = torch_scatter.scatter_sum(
            rfaxy, edge_batch, dim=0, dim_size=volume.numel()
        )  # (batch, 3, 3); graphs without edges get zeros
        # view(-1, 1, 1) so that each graph's tensor is divided by its own volume.
        stress = virial / volume.view(-1, 1, 1)

        return -egradi, stress

    def forward(
        self,
        data: Data,
    ) -> dict[str, torch.Tensor]:
        """Compute ZBL energies, forces and stresses for a batch of graphs.

        Reads ``numbers``, ``positions``, ``edge_index``, ``edge_shift``,
        ``batch``, ``cell`` and ``cutoff`` from ``data``. It also reuses
        ``data.rij`` and ``data.vij`` when both are present; otherwise it
        computes them and writes them back onto ``data``.

        NOTE(review): ``torch.min(data.cutoff, rbond)`` requires ``data.cutoff``
        to be a tensor broadcastable against the per-edge ``rbond``
        (e.g. a scalar tensor). TODO confirm.

        Returns:
            dict with ``"energy"`` (per graph), ``"forces"`` (per atom, 3) and
            ``"stress"`` (per graph, 3x3).
        """
        # TODO: generalized derivatives
        data.positions.requires_grad_(True)

        numbers = data.numbers  # (sum(N), )
        positions = data.positions  # (sum(N), 3)
        edge_index = data.edge_index  # (2, sum(E))
        edge_shift = data.edge_shift  # (sum(E), 3)
        batch = data.batch  # (sum(N), )

        edge_src, edge_dst = edge_index[0], edge_index[1]

        if "rij" not in data or "vij" not in data:
            data.vij = positions[edge_dst] - positions[edge_src] + edge_shift
            data.rij = LA.norm(data.vij, dim=-1)

        zi = numbers[edge_src]  # (sum(E), )
        zj = numbers[edge_dst]  # (sum(E), )
        # Sum of covalent radii; the envelope cutoff is min(cutoff, rbond).
        rbond = self.covalent_radii[zi] + self.covalent_radii[zj]
        rij = data.rij

        aij = self.a0 / (torch.pow(zi, self.p) + torch.pow(zj, self.p))  # (sum(E), )

        energy_pairs = (
            self.eij(zi, zj, rij)
            * self.phi(rij / aij.to(rij))
            * self.envelope(rij, torch.min(data.cutoff, rbond))
        )

        # 0.5 because every pair is presumably listed in both directions
        # (i->j and j->i). dim_size keeps one entry per atom even when the
        # trailing atoms have no edges, so the per-graph scatter below lines up.
        energy_nodes = 0.5 * torch_scatter.scatter_add(
            src=energy_pairs,
            index=edge_dst,
            dim=0,
            dim_size=numbers.size(0),
        )  # (sum(N), )

        energies = torch_scatter.scatter_add(
            src=energy_nodes,
            index=batch,
            dim=0,
        )  # (B, )

        # TODO: generalized derivatives
        forces, stress = self._get_derivatives(energies, data)

        return {
            "energy": energies,
            "forces": forces,
            "stress": stress,
        }