Yuan (Cyrus) Chiang commited on
Commit
7cc6c4a
·
unverified ·
1 Parent(s): 5716d3b

Add convenient ZBL torch calculator (#44)

Browse files

* add optimization convergence info

* add zbl and test

mlip_arena/data/collate.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
+ # TODO: consider using vesin
5
+ from matscipy.neighbours import neighbour_list
6
+ from torch_geometric.data import Data
7
+
8
+ from ase import Atoms
9
+ from ase.calculators.singlepoint import SinglePointCalculator
10
+
11
+
12
+ def get_neighbor(
13
+ atoms: Atoms, cutoff: float, self_interaction: bool = False
14
+ ):
15
+ pbc = atoms.pbc
16
+ cell = atoms.cell.array
17
+
18
+ i, j, S = neighbour_list(
19
+ quantities="ijS",
20
+ pbc=pbc,
21
+ cell=cell,
22
+ positions=atoms.positions,
23
+ cutoff=cutoff
24
+ )
25
+
26
+ if not self_interaction:
27
+ # Eliminate self-edges that don't cross periodic boundaries
28
+ true_self_edge = i == j
29
+ true_self_edge &= np.all(S == 0, axis=1)
30
+ keep_edge = ~true_self_edge
31
+
32
+ i = i[keep_edge]
33
+ j = j[keep_edge]
34
+ S = S[keep_edge]
35
+
36
+ edge_index = np.stack((i, j)).astype(np.int64)
37
+ edge_shift = np.dot(S, cell)
38
+
39
+ return edge_index, edge_shift
40
+
41
+
42
+
43
+ def collate_fn(batch: list[Atoms], cutoff: float) -> Data:
44
+ """Collate a list of Atoms objects into a single batched Atoms object."""
45
+
46
+ # Offset the edge indices for each graph to ensure they remain disconnected
47
+ offset = 0
48
+
49
+ node_batch = []
50
+
51
+ numbers_batch = []
52
+ positions_batch = []
53
+ # ec_batch = []
54
+
55
+ forces_batch = []
56
+ charges_batch = []
57
+ magmoms_batch = []
58
+ dipoles_batch = []
59
+
60
+ edge_index_batch = []
61
+ edge_shift_batch = []
62
+
63
+ cell_batch = []
64
+ natoms_batch = []
65
+
66
+ energy_batch = []
67
+ stress_batch = []
68
+
69
+ for i, atoms in enumerate(batch):
70
+
71
+ edge_index, edge_shift = get_neighbor(atoms, cutoff=cutoff, self_interaction=False)
72
+
73
+ edge_index[0] += offset
74
+ edge_index[1] += offset
75
+ edge_index_batch.append(torch.tensor(edge_index))
76
+ edge_shift_batch.append(torch.tensor(edge_shift))
77
+
78
+ natoms = len(atoms)
79
+ offset += natoms
80
+ node_batch.append(torch.ones(natoms, dtype=torch.long) * i)
81
+ natoms_batch.append(natoms)
82
+
83
+ cell_batch.append(torch.tensor(atoms.cell.array))
84
+ numbers_batch.append(torch.tensor(atoms.numbers))
85
+ positions_batch.append(torch.tensor(atoms.positions))
86
+
87
+ # ec_batch.append([Atom(int(a)).elecronic_encoding for a in atoms.numbers])
88
+
89
+ charges_batch.append(
90
+ atoms.get_initial_charges()
91
+ if atoms.get_initial_charges().any()
92
+ else torch.full((natoms,), torch.nan)
93
+ )
94
+ magmoms_batch.append(
95
+ atoms.get_initial_magnetic_moments()
96
+ if atoms.get_initial_magnetic_moments().any()
97
+ else torch.full((natoms,), torch.nan)
98
+ )
99
+
100
+ # Create the new 'arrays' data for the batch
101
+
102
+ cell_batch = torch.stack(cell_batch, dim=0)
103
+ node_batch = torch.cat(node_batch, dim=0)
104
+ positions_batch = torch.cat(positions_batch, dim=0)
105
+ numbers_batch = torch.cat(numbers_batch, dim=0)
106
+ natoms_batch = torch.tensor(natoms_batch, dtype=torch.long)
107
+
108
+ charges_batch = torch.cat(charges_batch, dim=0) if charges_batch else None
109
+ magmoms_batch = torch.cat(magmoms_batch, dim=0) if magmoms_batch else None
110
+
111
+ # ec_batch = list(map(lambda a: Atom(int(a)).elecronic_encoding, numbers_batch))
112
+ # ec_batch = torch.stack(ec_batch, dim=0)
113
+
114
+ edge_index_batch = torch.cat(edge_index_batch, dim=1)
115
+ edge_shift_batch = torch.cat(edge_shift_batch, dim=0)
116
+
117
+ arrays_batch_concatenated = {
118
+ "cell": cell_batch,
119
+ "positions": positions_batch,
120
+ "edge_index": edge_index_batch,
121
+ "edge_shift": edge_shift_batch,
122
+ "numbers": numbers_batch,
123
+ "num_nodes": offset,
124
+ "batch": node_batch,
125
+ "charges": charges_batch,
126
+ "magmoms": magmoms_batch,
127
+ # "ec": ec_batch,
128
+ "natoms": natoms_batch,
129
+ "cutoff": torch.tensor(cutoff),
130
+ }
131
+
132
+ # TODO: custom fields
133
+
134
+ # Create a new Data object with the concatenated arrays data
135
+ batch_data = Data.from_dict(arrays_batch_concatenated)
136
+
137
+ return batch_data
138
+
139
+
140
+ def decollate_fn(batch_data: Data) -> list[Atoms]:
141
+ """Decollate a batched Data object into a list of individual Atoms objects."""
142
+
143
+ # FIXME: this function is not working properly when the batch_data is on GPU.
144
+ # TODO: create a new Cell class using torch tensor to handle device placement.
145
+ # As a temporary fix, detach the batch_data from the GPU and move it to CPU.
146
+ batch_data = batch_data.detach().cpu()
147
+
148
+ # Initialize empty lists to store individual data entries
149
+ individual_entries = []
150
+
151
+ # Split the 'batch' attribute to identify data entries
152
+ unique_batches = batch_data.batch.unique(sorted=True)
153
+
154
+ for i in unique_batches:
155
+ # Identify the indices corresponding to the current data entry
156
+ entry_indices = (batch_data.batch == i).nonzero(as_tuple=True)[0]
157
+
158
+ # Extract the attributes for the current data entry
159
+ cell = batch_data.cell[i]
160
+ numbers = batch_data.numbers[entry_indices]
161
+ positions = batch_data.positions[entry_indices]
162
+ # edge_index = batch_data.edge_index[:, entry_indices]
163
+ # edge_shift = batch_data.edge_shift[entry_indices]
164
+ # batch_data.ec[entry_indices] if batch_data.ec is not None else None
165
+
166
+ # Optional fields
167
+ energy = batch_data.energy[i] if "energy" in batch_data else None
168
+ forces = batch_data.forces[entry_indices] if "forces" in batch_data else None
169
+ stress = batch_data.stress[i] if "stress" in batch_data else None
170
+
171
+ # charges = batch_data.charges[entry_indices] if "charges" in batch_data else None
172
+ # magmoms = batch_data.magmoms[entry_indices] if "magmoms" in batch_data else None
173
+ # dipoles = batch_data.dipoles[entry_indices] if "dipoles" in batch_data else None
174
+
175
+ # TODO: cumstom fields
176
+
177
+ # Create an 'Atoms' object for the current data entry
178
+ atoms = Atoms(
179
+ cell=cell,
180
+ positions=positions,
181
+ numbers=numbers,
182
+ # forces=None if torch.any(torch.isnan(forces)) else forces,
183
+ # charges=None if torch.any(torch.isnan(charges)) else charges,
184
+ # magmoms=None if torch.any(torch.isnan(magmoms)) else magmoms,
185
+ # dipoles=None if torch.any(torch.isnan(dipoles)) else dipoles,
186
+ # energy=None if torch.isnan(energy) else energy,
187
+ # stress=None if torch.any(torch.isnan(stress)) else stress,
188
+ )
189
+
190
+ atoms.calc = SinglePointCalculator(
191
+ energy=energy,
192
+ forces=forces,
193
+ stress=stress,
194
+ # charges=charges,
195
+ # magmoms=magmoms,
196
+ ) # type: ignore
197
+
198
+ # Append the individual data entry to the list
199
+ individual_entries.append(atoms)
200
+
201
+ return individual_entries
mlip_arena/models/__init__.py CHANGED
@@ -6,11 +6,21 @@ from pathlib import Path
6
 
7
  import torch
8
  import yaml
9
- from ase import Atoms
10
- from ase.calculators.calculator import Calculator, all_changes
11
  from huggingface_hub import PyTorchModelHubMixin
12
  from torch import nn
13
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  # from torch_geometric.data import Data
15
 
16
  with open(Path(__file__).parent / "registry.yaml", encoding="utf-8") as f:
@@ -20,14 +30,17 @@ MLIPMap = {}
20
 
21
  for model, metadata in REGISTRY.items():
22
  try:
23
- module = importlib.import_module(f"{__package__}.{metadata['module']}.{metadata['family']}")
 
 
24
  MLIPMap[model] = getattr(module, metadata["class"])
25
  except (ModuleNotFoundError, AttributeError, ValueError) as e:
26
- print(e)
27
  continue
28
 
29
  MLIPEnum = Enum("MLIPEnum", MLIPMap)
30
 
 
31
  class MLIP(
32
  nn.Module,
33
  PyTorchModelHubMixin,
@@ -35,6 +48,9 @@ class MLIP(
35
  ):
36
  def __init__(self, model: nn.Module) -> None:
37
  super().__init__()
 
 
 
38
  self.model = model
39
 
40
  def forward(self, x):
@@ -47,7 +63,9 @@ class MLIPCalculator(MLIP, Calculator):
47
 
48
  def __init__(
49
  self,
50
- model,
 
 
51
  # ASE Calculator
52
  restart=None,
53
  atoms=None,
@@ -60,12 +78,24 @@ class MLIPCalculator(MLIP, Calculator):
60
  ) # Initialize ASE Calculator part
61
  # Additional initialization if needed
62
  # self.name: str = self.__class__.__name__
 
 
 
63
  # self.device = device or torch.device(
64
  # "cuda" if torch.cuda.is_available() else "cpu"
65
  # )
66
  # self.model: MLIP = MLIP.from_pretrained(model_path, map_location=self.device)
67
  # self.implemented_properties = ["energy", "forces", "stress"]
68
 
 
 
 
 
 
 
 
 
 
69
  def calculate(
70
  self,
71
  atoms: Atoms,
@@ -75,7 +105,11 @@ class MLIPCalculator(MLIP, Calculator):
75
  """Calculate energies and forces for the given Atoms object"""
76
  super().calculate(atoms, properties, system_changes)
77
 
78
- output = self.forward(atoms)
 
 
 
 
79
 
80
  self.results = {}
81
  if "energy" in properties:
@@ -85,13 +119,14 @@ class MLIPCalculator(MLIP, Calculator):
85
  if "stress" in properties:
86
  self.results["stress"] = output["stress"].squeeze().cpu().detach().numpy()
87
 
88
- def forward(self, x: Atoms) -> dict[str, torch.Tensor]:
89
- """Implement data conversion, graph creation, and model forward pass
 
 
 
 
 
90
 
91
- Example implementation:
92
- 1. Use `ase.neighborlist.NeighborList` to get neighbor list
93
- 2. Create `torch_geometric.data.Data` object and copy the data
94
- 3. Pass the `Data` object to the model and return the output
95
 
96
- """
97
- raise NotImplementedError
 
6
 
7
  import torch
8
  import yaml
 
 
9
  from huggingface_hub import PyTorchModelHubMixin
10
  from torch import nn
11
 
12
+ from ase import Atoms
13
+ from ase.calculators.calculator import Calculator, all_changes
14
+ from mlip_arena.data.collate import collate_fn
15
+ from mlip_arena.models.utils import get_freer_device
16
+
17
+ try:
18
+ from prefect.logging import get_run_logger
19
+
20
+ logger = get_run_logger()
21
+ except (ImportError, RuntimeError):
22
+ from loguru import logger
23
+
24
  # from torch_geometric.data import Data
25
 
26
  with open(Path(__file__).parent / "registry.yaml", encoding="utf-8") as f:
 
30
 
31
  for model, metadata in REGISTRY.items():
32
  try:
33
+ module = importlib.import_module(
34
+ f"{__package__}.{metadata['module']}.{metadata['family']}"
35
+ )
36
  MLIPMap[model] = getattr(module, metadata["class"])
37
  except (ModuleNotFoundError, AttributeError, ValueError) as e:
38
+ logger.warning(e)
39
  continue
40
 
41
  MLIPEnum = Enum("MLIPEnum", MLIPMap)
42
 
43
+
44
  class MLIP(
45
  nn.Module,
46
  PyTorchModelHubMixin,
 
48
  ):
49
  def __init__(self, model: nn.Module) -> None:
50
  super().__init__()
51
+ # https://github.com/pytorch/pytorch/blob/3cbc8c54fd37eb590e2a9206aecf3ab568b3e63c/torch/_dynamo/config.py#L534
52
+ # torch._dynamo.config.compiled_autograd = True
53
+ # self.model = torch.compile(model)
54
  self.model = model
55
 
56
  def forward(self, x):
 
63
 
64
  def __init__(
65
  self,
66
+ model: nn.Module,
67
+ device: torch.device | None = None,
68
+ cutoff: float = 6.0,
69
  # ASE Calculator
70
  restart=None,
71
  atoms=None,
 
78
  ) # Initialize ASE Calculator part
79
  # Additional initialization if needed
80
  # self.name: str = self.__class__.__name__
81
+ self.device = device or get_freer_device()
82
+ self.cutoff = cutoff
83
+ self.model.to(self.device)
84
  # self.device = device or torch.device(
85
  # "cuda" if torch.cuda.is_available() else "cpu"
86
  # )
87
  # self.model: MLIP = MLIP.from_pretrained(model_path, map_location=self.device)
88
  # self.implemented_properties = ["energy", "forces", "stress"]
89
 
90
+ # def __getstate__(self):
91
+ # state = self.__dict__.copy()
92
+ # state["_modules"]["model"] = state["_modules"]["model"]._orig_mod
93
+ # return state
94
+
95
+ # def __setstate__(self, state):
96
+ # self.__dict__.update(state)
97
+ # self.model = torch.compile(state["_modules"]["model"])
98
+
99
  def calculate(
100
  self,
101
  atoms: Atoms,
 
105
  """Calculate energies and forces for the given Atoms object"""
106
  super().calculate(atoms, properties, system_changes)
107
 
108
+ # TODO: move collate_fn to here in MLIPCalculator
109
+ data = collate_fn([atoms], cutoff=self.cutoff).to(self.device)
110
+ output = self.forward(data)
111
+
112
+ # TODO: decollate_fn
113
 
114
  self.results = {}
115
  if "energy" in properties:
 
119
  if "stress" in properties:
120
  self.results["stress"] = output["stress"].squeeze().cpu().detach().numpy()
121
 
122
+ # def forward(self, x: Atoms) -> dict[str, torch.Tensor]:
123
+ # """Implement data conversion, graph creation, and model forward pass
124
+
125
+ # Example implementation:
126
+ # 1. Use `ase.neighborlist.NeighborList` to get neighbor list
127
+ # 2. Create `torch_geometric.data.Data` object and copy the data
128
+ # 3. Pass the `Data` object to the model and return the output
129
 
130
+ # """
 
 
 
131
 
132
+ # raise NotImplementedError
 
mlip_arena/models/classicals/__init__.py ADDED
File without changes
mlip_arena/models/classicals/zbl.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.linalg as LA
3
+ import torch.nn as nn
4
+ import torch_scatter
5
+ from torch_geometric.data import Data
6
+
7
+ from ase.data import covalent_radii
8
+ from ase.units import _e, _eps0, m, pi
9
+ from e3nn.util.jit import compile_mode # TODO: e3nn allows autograd in compiled model
10
+
11
+
12
+ @compile_mode("script")
13
+ class ZBL(nn.Module):
14
+ """Ziegler-Biersack-Littmark (ZBL) screened nuclear repulsion"""
15
+
16
+ def __init__(
17
+ self,
18
+ trianable: bool = False,
19
+ **kwargs,
20
+ ) -> None:
21
+ nn.Module.__init__(self, **kwargs)
22
+
23
+ torch.set_default_dtype(torch.double)
24
+
25
+ self.a = torch.nn.parameter.Parameter(
26
+ torch.tensor(
27
+ [0.18175, 0.50986, 0.28022, 0.02817], dtype=torch.get_default_dtype()
28
+ ),
29
+ requires_grad=trianable,
30
+ )
31
+ self.b = torch.nn.parameter.Parameter(
32
+ torch.tensor(
33
+ [-3.19980, -0.94229, -0.40290, -0.20162],
34
+ dtype=torch.get_default_dtype(),
35
+ ),
36
+ requires_grad=trianable,
37
+ )
38
+
39
+ self.a0 = torch.nn.parameter.Parameter(
40
+ torch.tensor(0.46850, dtype=torch.get_default_dtype()),
41
+ requires_grad=trianable,
42
+ )
43
+
44
+ self.p = torch.nn.parameter.Parameter(
45
+ torch.tensor(0.23, dtype=torch.get_default_dtype()), requires_grad=trianable
46
+ )
47
+
48
+ self.register_buffer(
49
+ "covalent_radii",
50
+ torch.tensor(
51
+ covalent_radii,
52
+ dtype=torch.get_default_dtype(),
53
+ ),
54
+ )
55
+
56
+ def phi(self, x):
57
+ return torch.einsum("i,ij->j", self.a, torch.exp(torch.outer(self.b, x)))
58
+
59
+ def d_phi(self, x):
60
+ return torch.einsum(
61
+ "i,ij->j", self.a * self.b, torch.exp(torch.outer(self.b, x))
62
+ )
63
+
64
+ def dd_phi(self, x):
65
+ return torch.einsum(
66
+ "i,ij->j", self.a * self.b**2, torch.exp(torch.outer(self.b, x))
67
+ )
68
+
69
+ def eij(
70
+ self, zi: torch.Tensor, zj: torch.Tensor, rij: torch.Tensor
71
+ ) -> torch.Tensor: # [eV]
72
+ return _e * m / (4 * pi * _eps0) * torch.div(torch.mul(zi, zj), rij)
73
+
74
+ def d_eij(
75
+ self, zi: torch.Tensor, zj: torch.Tensor, rij: torch.Tensor
76
+ ) -> torch.Tensor: # [eV / A]
77
+ return -_e * m / (4 * pi * _eps0) * torch.div(torch.mul(zi, zj), rij**2)
78
+
79
+ def dd_eij(
80
+ self, zi: torch.Tensor, zj: torch.Tensor, rij: torch.Tensor
81
+ ) -> torch.Tensor: # [eV / A^2]
82
+ return _e * m / (2 * pi * _eps0) * torch.div(torch.mul(zi, zj), rij**3)
83
+
84
+ def switch_fn(
85
+ self,
86
+ zi: torch.Tensor,
87
+ zj: torch.Tensor,
88
+ rij: torch.Tensor,
89
+ aij: torch.Tensor,
90
+ router: torch.Tensor,
91
+ rinner: torch.Tensor,
92
+ ) -> torch.Tensor: # [eV]
93
+ # aij = self.a0 / (torch.pow(zi, self.p) + torch.pow(zj, self.p))
94
+
95
+ xrouter = router / aij
96
+
97
+ energy = self.eij(zi, zj, router) * self.phi(xrouter)
98
+
99
+ grad1 = self.d_eij(zi, zj, router) * self.phi(xrouter) + self.eij(
100
+ zi, zj, router
101
+ ) * self.d_phi(xrouter)
102
+
103
+ grad2 = (
104
+ self.dd_eij(zi, zj, router) * self.phi(xrouter)
105
+ + self.d_eij(zi, zj, router) * self.d_phi(xrouter)
106
+ + self.d_eij(zi, zj, router) * self.d_phi(xrouter)
107
+ + self.eij(zi, zj, router) * self.dd_phi(xrouter)
108
+ )
109
+
110
+ A = (-3 * grad1 + (router - rinner) * grad2) / (router - rinner) ** 2
111
+ B = (2 * grad1 - (router - rinner) * grad2) / (router - rinner) ** 3
112
+ C = (
113
+ -energy
114
+ + 1.0 / 2.0 * (router - rinner) * grad1
115
+ - 1.0 / 12.0 * (router - rinner) ** 2 * grad2
116
+ )
117
+
118
+ switching = torch.where(
119
+ rij < rinner,
120
+ C,
121
+ A / 3.0 * (rij - rinner) ** 3 + B / 4.0 * (rij - rinner) ** 4 + C,
122
+ )
123
+
124
+ return switching
125
+
126
+ def envelope(self, r: torch.Tensor, rc: torch.Tensor, p: int = 6):
127
+ x = r / rc
128
+ y = (
129
+ 1.0
130
+ - ((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(x, p)
131
+ + p * (p + 2.0) * torch.pow(x, p + 1)
132
+ - (p * (p + 1.0) / 2) * torch.pow(x, p + 2)
133
+ ) * (x < 1)
134
+ return y
135
+
136
+ def _get_derivatives(self, energy: torch.Tensor, data: Data):
137
+ egradi, egradij = torch.autograd.grad(
138
+ outputs=[energy], # TODO: generalized derivatives
139
+ inputs=[data.positions, data.vij], # TODO: generalized derivatives
140
+ grad_outputs=[torch.ones_like(energy)],
141
+ retain_graph=True,
142
+ create_graph=True,
143
+ allow_unused=True,
144
+ )
145
+
146
+ volume = torch.det(data.cell) # (batch,)
147
+ rfaxy = torch.einsum("ax,ay->axy", data.vij, -egradij)
148
+
149
+ edge_batch = data.batch[data.edge_index[0]]
150
+
151
+ stress = (
152
+ -0.5
153
+ * torch_scatter.scatter_sum(rfaxy, edge_batch, dim=0)
154
+ / volume.view(-1, 1)
155
+ )
156
+
157
+ return -egradi, stress
158
+
159
+ def forward(
160
+ self,
161
+ data: Data,
162
+ ) -> dict[str, torch.Tensor]:
163
+ # TODO: generalized derivatives
164
+ data.positions.requires_grad_(True)
165
+
166
+ numbers = data.numbers # (sum(N), )
167
+ positions = data.positions # (sum(N), 3)
168
+ edge_index = data.edge_index # (2, sum(E))
169
+ edge_shift = data.edge_shift # (sum(E), 3)
170
+ batch = data.batch # (sum(N), )
171
+
172
+ edge_src, edge_dst = edge_index[0], edge_index[1]
173
+
174
+ if "rij" not in data or "vij" not in data:
175
+ data.vij = positions[edge_dst] - positions[edge_src] + edge_shift
176
+ data.rij = LA.norm(data.vij, dim=-1)
177
+
178
+ rbond = (
179
+ self.covalent_radii[numbers[edge_src]]
180
+ + self.covalent_radii[numbers[edge_dst]]
181
+ )
182
+
183
+ rij = data.rij
184
+ zi = numbers[edge_src] # (sum(E), )
185
+ zj = numbers[edge_dst] # (sum(E), )
186
+
187
+ aij = self.a0 / (torch.pow(zi, self.p) + torch.pow(zj, self.p)) # (sum(E), )
188
+
189
+ energy_pairs = (
190
+ self.eij(zi, zj, rij)
191
+ * self.phi(rij / aij.to(rij))
192
+ * self.envelope(rij, torch.min(data.cutoff, rbond))
193
+ )
194
+
195
+ energy_nodes = 0.5 * torch_scatter.scatter_add(
196
+ src=energy_pairs,
197
+ index=edge_dst,
198
+ dim=0,
199
+ ) # (sum(N), )
200
+
201
+ energies = torch_scatter.scatter_add(
202
+ src=energy_nodes,
203
+ index=batch,
204
+ dim=0,
205
+ ) # (B, )
206
+
207
+ # TODO: generalized derivatives
208
+ forces, stress = self._get_derivatives(energies, data)
209
+
210
+ return {
211
+ "energy": energies,
212
+ "forces": forces,
213
+ "stress": stress,
214
+ }
mlip_arena/models/registry.yaml CHANGED
@@ -84,6 +84,7 @@ MatterSim:
84
  - eos_alloy
85
  gpu-tasks:
86
  - homonuclear-diatomics
 
87
  github: https://github.com/microsoft/mattersim
88
  doi: https://arxiv.org/abs/2405.04967
89
  date: 2024-12-05
@@ -264,6 +265,7 @@ ALIGNN:
264
  - MP22
265
  gpu-tasks:
266
  - homonuclear-diatomics
 
267
  # - combustion
268
  prediction: EFS
269
  nvt: true
@@ -309,6 +311,7 @@ ORBv2:
309
  gpu-tasks:
310
  - homonuclear-diatomics
311
  - combustion
 
312
  github: https://github.com/orbital-materials/orb-models
313
  doi:
314
  date: 2024-10-15
 
84
  - eos_alloy
85
  gpu-tasks:
86
  - homonuclear-diatomics
87
+ - stability
88
  github: https://github.com/microsoft/mattersim
89
  doi: https://arxiv.org/abs/2405.04967
90
  date: 2024-12-05
 
265
  - MP22
266
  gpu-tasks:
267
  - homonuclear-diatomics
268
+ - stability
269
  # - combustion
270
  prediction: EFS
271
  nvt: true
 
311
  gpu-tasks:
312
  - homonuclear-diatomics
313
  - combustion
314
+ - stability
315
  github: https://github.com/orbital-materials/orb-models
316
  doi:
317
  date: 2024-10-15
mlip_arena/tasks/optimize.py CHANGED
@@ -111,6 +111,9 @@ def run(
111
  logger.info(f"Criterion: {pformat(criterion)}")
112
  optimizer_instance.run(**criterion)
113
 
 
114
  return {
115
  "atoms": atoms,
 
 
116
  }
 
111
  logger.info(f"Criterion: {pformat(criterion)}")
112
  optimizer_instance.run(**criterion)
113
 
114
+
115
  return {
116
  "atoms": atoms,
117
+ "steps": optimizer_instance.nsteps,
118
+ "converged": optimizer_instance.converged(),
119
  }
tests/test_internal_calculators.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from mlip_arena.models import MLIPCalculator
3
+ from mlip_arena.models.classicals.zbl import ZBL
4
+
5
+ from ase.build import bulk
6
+
7
+
8
+ def test_zbl():
9
+ calc = MLIPCalculator(model=ZBL(), cutoff=6.0)
10
+
11
+ energies = []
12
+ forces = []
13
+ stresses = []
14
+
15
+ lattice_constants = [1, 3, 5, 7]
16
+
17
+ for a in lattice_constants:
18
+ atoms = bulk("Cu", "fcc", a=a) * (2, 2, 2)
19
+ atoms.calc = calc
20
+
21
+ energies.append(atoms.get_potential_energy())
22
+ forces.append(atoms.get_forces())
23
+ stresses.append(atoms.get_stress(voigt=False))
24
+
25
+ # test energy monotonicity
26
+ assert all(np.diff(energies) <= 0), "Energy is not monotonically decreasing with increasing lattice constant"
27
+
28
+ # test force vectors are all zeros due to symmetry
29
+ for f in forces:
30
+ assert np.allclose(f, 0), "Forces should be zero due to symmetry"
31
+
32
+ # test trace of stress is monotonically increasing (less negative) and zero beyond cutoff
33
+ traces = [np.trace(s) for s in stresses]
34
+
35
+ assert all(np.diff(traces) >= 0), "Trace of stress is not monotonically increasing with increasing lattice constant"
36
+ assert np.allclose(stresses[-1], 0), "Stress should be zero beyond cutoff"