Spaces:
Running
Running
Yuan (Cyrus) Chiang
commited on
Add convenient ZBL torch calculator (#44)
Browse files* add optimization convergence info
* add zbl and test
mlip_arena/data/collate.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
|
4 |
+
# TODO: consider using vesin
|
5 |
+
from matscipy.neighbours import neighbour_list
|
6 |
+
from torch_geometric.data import Data
|
7 |
+
|
8 |
+
from ase import Atoms
|
9 |
+
from ase.calculators.singlepoint import SinglePointCalculator
|
10 |
+
|
11 |
+
|
12 |
+
def get_neighbor(
|
13 |
+
atoms: Atoms, cutoff: float, self_interaction: bool = False
|
14 |
+
):
|
15 |
+
pbc = atoms.pbc
|
16 |
+
cell = atoms.cell.array
|
17 |
+
|
18 |
+
i, j, S = neighbour_list(
|
19 |
+
quantities="ijS",
|
20 |
+
pbc=pbc,
|
21 |
+
cell=cell,
|
22 |
+
positions=atoms.positions,
|
23 |
+
cutoff=cutoff
|
24 |
+
)
|
25 |
+
|
26 |
+
if not self_interaction:
|
27 |
+
# Eliminate self-edges that don't cross periodic boundaries
|
28 |
+
true_self_edge = i == j
|
29 |
+
true_self_edge &= np.all(S == 0, axis=1)
|
30 |
+
keep_edge = ~true_self_edge
|
31 |
+
|
32 |
+
i = i[keep_edge]
|
33 |
+
j = j[keep_edge]
|
34 |
+
S = S[keep_edge]
|
35 |
+
|
36 |
+
edge_index = np.stack((i, j)).astype(np.int64)
|
37 |
+
edge_shift = np.dot(S, cell)
|
38 |
+
|
39 |
+
return edge_index, edge_shift
|
40 |
+
|
41 |
+
|
42 |
+
|
43 |
+
def collate_fn(batch: list[Atoms], cutoff: float) -> Data:
|
44 |
+
"""Collate a list of Atoms objects into a single batched Atoms object."""
|
45 |
+
|
46 |
+
# Offset the edge indices for each graph to ensure they remain disconnected
|
47 |
+
offset = 0
|
48 |
+
|
49 |
+
node_batch = []
|
50 |
+
|
51 |
+
numbers_batch = []
|
52 |
+
positions_batch = []
|
53 |
+
# ec_batch = []
|
54 |
+
|
55 |
+
forces_batch = []
|
56 |
+
charges_batch = []
|
57 |
+
magmoms_batch = []
|
58 |
+
dipoles_batch = []
|
59 |
+
|
60 |
+
edge_index_batch = []
|
61 |
+
edge_shift_batch = []
|
62 |
+
|
63 |
+
cell_batch = []
|
64 |
+
natoms_batch = []
|
65 |
+
|
66 |
+
energy_batch = []
|
67 |
+
stress_batch = []
|
68 |
+
|
69 |
+
for i, atoms in enumerate(batch):
|
70 |
+
|
71 |
+
edge_index, edge_shift = get_neighbor(atoms, cutoff=cutoff, self_interaction=False)
|
72 |
+
|
73 |
+
edge_index[0] += offset
|
74 |
+
edge_index[1] += offset
|
75 |
+
edge_index_batch.append(torch.tensor(edge_index))
|
76 |
+
edge_shift_batch.append(torch.tensor(edge_shift))
|
77 |
+
|
78 |
+
natoms = len(atoms)
|
79 |
+
offset += natoms
|
80 |
+
node_batch.append(torch.ones(natoms, dtype=torch.long) * i)
|
81 |
+
natoms_batch.append(natoms)
|
82 |
+
|
83 |
+
cell_batch.append(torch.tensor(atoms.cell.array))
|
84 |
+
numbers_batch.append(torch.tensor(atoms.numbers))
|
85 |
+
positions_batch.append(torch.tensor(atoms.positions))
|
86 |
+
|
87 |
+
# ec_batch.append([Atom(int(a)).elecronic_encoding for a in atoms.numbers])
|
88 |
+
|
89 |
+
charges_batch.append(
|
90 |
+
atoms.get_initial_charges()
|
91 |
+
if atoms.get_initial_charges().any()
|
92 |
+
else torch.full((natoms,), torch.nan)
|
93 |
+
)
|
94 |
+
magmoms_batch.append(
|
95 |
+
atoms.get_initial_magnetic_moments()
|
96 |
+
if atoms.get_initial_magnetic_moments().any()
|
97 |
+
else torch.full((natoms,), torch.nan)
|
98 |
+
)
|
99 |
+
|
100 |
+
# Create the new 'arrays' data for the batch
|
101 |
+
|
102 |
+
cell_batch = torch.stack(cell_batch, dim=0)
|
103 |
+
node_batch = torch.cat(node_batch, dim=0)
|
104 |
+
positions_batch = torch.cat(positions_batch, dim=0)
|
105 |
+
numbers_batch = torch.cat(numbers_batch, dim=0)
|
106 |
+
natoms_batch = torch.tensor(natoms_batch, dtype=torch.long)
|
107 |
+
|
108 |
+
charges_batch = torch.cat(charges_batch, dim=0) if charges_batch else None
|
109 |
+
magmoms_batch = torch.cat(magmoms_batch, dim=0) if magmoms_batch else None
|
110 |
+
|
111 |
+
# ec_batch = list(map(lambda a: Atom(int(a)).elecronic_encoding, numbers_batch))
|
112 |
+
# ec_batch = torch.stack(ec_batch, dim=0)
|
113 |
+
|
114 |
+
edge_index_batch = torch.cat(edge_index_batch, dim=1)
|
115 |
+
edge_shift_batch = torch.cat(edge_shift_batch, dim=0)
|
116 |
+
|
117 |
+
arrays_batch_concatenated = {
|
118 |
+
"cell": cell_batch,
|
119 |
+
"positions": positions_batch,
|
120 |
+
"edge_index": edge_index_batch,
|
121 |
+
"edge_shift": edge_shift_batch,
|
122 |
+
"numbers": numbers_batch,
|
123 |
+
"num_nodes": offset,
|
124 |
+
"batch": node_batch,
|
125 |
+
"charges": charges_batch,
|
126 |
+
"magmoms": magmoms_batch,
|
127 |
+
# "ec": ec_batch,
|
128 |
+
"natoms": natoms_batch,
|
129 |
+
"cutoff": torch.tensor(cutoff),
|
130 |
+
}
|
131 |
+
|
132 |
+
# TODO: custom fields
|
133 |
+
|
134 |
+
# Create a new Data object with the concatenated arrays data
|
135 |
+
batch_data = Data.from_dict(arrays_batch_concatenated)
|
136 |
+
|
137 |
+
return batch_data
|
138 |
+
|
139 |
+
|
140 |
+
def decollate_fn(batch_data: Data) -> list[Atoms]:
|
141 |
+
"""Decollate a batched Data object into a list of individual Atoms objects."""
|
142 |
+
|
143 |
+
# FIXME: this function is not working properly when the batch_data is on GPU.
|
144 |
+
# TODO: create a new Cell class using torch tensor to handle device placement.
|
145 |
+
# As a temporary fix, detach the batch_data from the GPU and move it to CPU.
|
146 |
+
batch_data = batch_data.detach().cpu()
|
147 |
+
|
148 |
+
# Initialize empty lists to store individual data entries
|
149 |
+
individual_entries = []
|
150 |
+
|
151 |
+
# Split the 'batch' attribute to identify data entries
|
152 |
+
unique_batches = batch_data.batch.unique(sorted=True)
|
153 |
+
|
154 |
+
for i in unique_batches:
|
155 |
+
# Identify the indices corresponding to the current data entry
|
156 |
+
entry_indices = (batch_data.batch == i).nonzero(as_tuple=True)[0]
|
157 |
+
|
158 |
+
# Extract the attributes for the current data entry
|
159 |
+
cell = batch_data.cell[i]
|
160 |
+
numbers = batch_data.numbers[entry_indices]
|
161 |
+
positions = batch_data.positions[entry_indices]
|
162 |
+
# edge_index = batch_data.edge_index[:, entry_indices]
|
163 |
+
# edge_shift = batch_data.edge_shift[entry_indices]
|
164 |
+
# batch_data.ec[entry_indices] if batch_data.ec is not None else None
|
165 |
+
|
166 |
+
# Optional fields
|
167 |
+
energy = batch_data.energy[i] if "energy" in batch_data else None
|
168 |
+
forces = batch_data.forces[entry_indices] if "forces" in batch_data else None
|
169 |
+
stress = batch_data.stress[i] if "stress" in batch_data else None
|
170 |
+
|
171 |
+
# charges = batch_data.charges[entry_indices] if "charges" in batch_data else None
|
172 |
+
# magmoms = batch_data.magmoms[entry_indices] if "magmoms" in batch_data else None
|
173 |
+
# dipoles = batch_data.dipoles[entry_indices] if "dipoles" in batch_data else None
|
174 |
+
|
175 |
+
# TODO: cumstom fields
|
176 |
+
|
177 |
+
# Create an 'Atoms' object for the current data entry
|
178 |
+
atoms = Atoms(
|
179 |
+
cell=cell,
|
180 |
+
positions=positions,
|
181 |
+
numbers=numbers,
|
182 |
+
# forces=None if torch.any(torch.isnan(forces)) else forces,
|
183 |
+
# charges=None if torch.any(torch.isnan(charges)) else charges,
|
184 |
+
# magmoms=None if torch.any(torch.isnan(magmoms)) else magmoms,
|
185 |
+
# dipoles=None if torch.any(torch.isnan(dipoles)) else dipoles,
|
186 |
+
# energy=None if torch.isnan(energy) else energy,
|
187 |
+
# stress=None if torch.any(torch.isnan(stress)) else stress,
|
188 |
+
)
|
189 |
+
|
190 |
+
atoms.calc = SinglePointCalculator(
|
191 |
+
energy=energy,
|
192 |
+
forces=forces,
|
193 |
+
stress=stress,
|
194 |
+
# charges=charges,
|
195 |
+
# magmoms=magmoms,
|
196 |
+
) # type: ignore
|
197 |
+
|
198 |
+
# Append the individual data entry to the list
|
199 |
+
individual_entries.append(atoms)
|
200 |
+
|
201 |
+
return individual_entries
|
mlip_arena/models/__init__.py
CHANGED
@@ -6,11 +6,21 @@ from pathlib import Path
|
|
6 |
|
7 |
import torch
|
8 |
import yaml
|
9 |
-
from ase import Atoms
|
10 |
-
from ase.calculators.calculator import Calculator, all_changes
|
11 |
from huggingface_hub import PyTorchModelHubMixin
|
12 |
from torch import nn
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
# from torch_geometric.data import Data
|
15 |
|
16 |
with open(Path(__file__).parent / "registry.yaml", encoding="utf-8") as f:
|
@@ -20,14 +30,17 @@ MLIPMap = {}
|
|
20 |
|
21 |
for model, metadata in REGISTRY.items():
|
22 |
try:
|
23 |
-
module = importlib.import_module(
|
|
|
|
|
24 |
MLIPMap[model] = getattr(module, metadata["class"])
|
25 |
except (ModuleNotFoundError, AttributeError, ValueError) as e:
|
26 |
-
|
27 |
continue
|
28 |
|
29 |
MLIPEnum = Enum("MLIPEnum", MLIPMap)
|
30 |
|
|
|
31 |
class MLIP(
|
32 |
nn.Module,
|
33 |
PyTorchModelHubMixin,
|
@@ -35,6 +48,9 @@ class MLIP(
|
|
35 |
):
|
36 |
def __init__(self, model: nn.Module) -> None:
|
37 |
super().__init__()
|
|
|
|
|
|
|
38 |
self.model = model
|
39 |
|
40 |
def forward(self, x):
|
@@ -47,7 +63,9 @@ class MLIPCalculator(MLIP, Calculator):
|
|
47 |
|
48 |
def __init__(
|
49 |
self,
|
50 |
-
model,
|
|
|
|
|
51 |
# ASE Calculator
|
52 |
restart=None,
|
53 |
atoms=None,
|
@@ -60,12 +78,24 @@ class MLIPCalculator(MLIP, Calculator):
|
|
60 |
) # Initialize ASE Calculator part
|
61 |
# Additional initialization if needed
|
62 |
# self.name: str = self.__class__.__name__
|
|
|
|
|
|
|
63 |
# self.device = device or torch.device(
|
64 |
# "cuda" if torch.cuda.is_available() else "cpu"
|
65 |
# )
|
66 |
# self.model: MLIP = MLIP.from_pretrained(model_path, map_location=self.device)
|
67 |
# self.implemented_properties = ["energy", "forces", "stress"]
|
68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
def calculate(
|
70 |
self,
|
71 |
atoms: Atoms,
|
@@ -75,7 +105,11 @@ class MLIPCalculator(MLIP, Calculator):
|
|
75 |
"""Calculate energies and forces for the given Atoms object"""
|
76 |
super().calculate(atoms, properties, system_changes)
|
77 |
|
78 |
-
|
|
|
|
|
|
|
|
|
79 |
|
80 |
self.results = {}
|
81 |
if "energy" in properties:
|
@@ -85,13 +119,14 @@ class MLIPCalculator(MLIP, Calculator):
|
|
85 |
if "stress" in properties:
|
86 |
self.results["stress"] = output["stress"].squeeze().cpu().detach().numpy()
|
87 |
|
88 |
-
def forward(self, x: Atoms) -> dict[str, torch.Tensor]:
|
89 |
-
|
|
|
|
|
|
|
|
|
|
|
90 |
|
91 |
-
|
92 |
-
1. Use `ase.neighborlist.NeighborList` to get neighbor list
|
93 |
-
2. Create `torch_geometric.data.Data` object and copy the data
|
94 |
-
3. Pass the `Data` object to the model and return the output
|
95 |
|
96 |
-
|
97 |
-
raise NotImplementedError
|
|
|
6 |
|
7 |
import torch
|
8 |
import yaml
|
|
|
|
|
9 |
from huggingface_hub import PyTorchModelHubMixin
|
10 |
from torch import nn
|
11 |
|
12 |
+
from ase import Atoms
|
13 |
+
from ase.calculators.calculator import Calculator, all_changes
|
14 |
+
from mlip_arena.data.collate import collate_fn
|
15 |
+
from mlip_arena.models.utils import get_freer_device
|
16 |
+
|
17 |
+
try:
|
18 |
+
from prefect.logging import get_run_logger
|
19 |
+
|
20 |
+
logger = get_run_logger()
|
21 |
+
except (ImportError, RuntimeError):
|
22 |
+
from loguru import logger
|
23 |
+
|
24 |
# from torch_geometric.data import Data
|
25 |
|
26 |
with open(Path(__file__).parent / "registry.yaml", encoding="utf-8") as f:
|
|
|
30 |
|
31 |
for model, metadata in REGISTRY.items():
|
32 |
try:
|
33 |
+
module = importlib.import_module(
|
34 |
+
f"{__package__}.{metadata['module']}.{metadata['family']}"
|
35 |
+
)
|
36 |
MLIPMap[model] = getattr(module, metadata["class"])
|
37 |
except (ModuleNotFoundError, AttributeError, ValueError) as e:
|
38 |
+
logger.warning(e)
|
39 |
continue
|
40 |
|
41 |
MLIPEnum = Enum("MLIPEnum", MLIPMap)
|
42 |
|
43 |
+
|
44 |
class MLIP(
|
45 |
nn.Module,
|
46 |
PyTorchModelHubMixin,
|
|
|
48 |
):
|
49 |
def __init__(self, model: nn.Module) -> None:
|
50 |
super().__init__()
|
51 |
+
# https://github.com/pytorch/pytorch/blob/3cbc8c54fd37eb590e2a9206aecf3ab568b3e63c/torch/_dynamo/config.py#L534
|
52 |
+
# torch._dynamo.config.compiled_autograd = True
|
53 |
+
# self.model = torch.compile(model)
|
54 |
self.model = model
|
55 |
|
56 |
def forward(self, x):
|
|
|
63 |
|
64 |
def __init__(
|
65 |
self,
|
66 |
+
model: nn.Module,
|
67 |
+
device: torch.device | None = None,
|
68 |
+
cutoff: float = 6.0,
|
69 |
# ASE Calculator
|
70 |
restart=None,
|
71 |
atoms=None,
|
|
|
78 |
) # Initialize ASE Calculator part
|
79 |
# Additional initialization if needed
|
80 |
# self.name: str = self.__class__.__name__
|
81 |
+
self.device = device or get_freer_device()
|
82 |
+
self.cutoff = cutoff
|
83 |
+
self.model.to(self.device)
|
84 |
# self.device = device or torch.device(
|
85 |
# "cuda" if torch.cuda.is_available() else "cpu"
|
86 |
# )
|
87 |
# self.model: MLIP = MLIP.from_pretrained(model_path, map_location=self.device)
|
88 |
# self.implemented_properties = ["energy", "forces", "stress"]
|
89 |
|
90 |
+
# def __getstate__(self):
|
91 |
+
# state = self.__dict__.copy()
|
92 |
+
# state["_modules"]["model"] = state["_modules"]["model"]._orig_mod
|
93 |
+
# return state
|
94 |
+
|
95 |
+
# def __setstate__(self, state):
|
96 |
+
# self.__dict__.update(state)
|
97 |
+
# self.model = torch.compile(state["_modules"]["model"])
|
98 |
+
|
99 |
def calculate(
|
100 |
self,
|
101 |
atoms: Atoms,
|
|
|
105 |
"""Calculate energies and forces for the given Atoms object"""
|
106 |
super().calculate(atoms, properties, system_changes)
|
107 |
|
108 |
+
# TODO: move collate_fn to here in MLIPCalculator
|
109 |
+
data = collate_fn([atoms], cutoff=self.cutoff).to(self.device)
|
110 |
+
output = self.forward(data)
|
111 |
+
|
112 |
+
# TODO: decollate_fn
|
113 |
|
114 |
self.results = {}
|
115 |
if "energy" in properties:
|
|
|
119 |
if "stress" in properties:
|
120 |
self.results["stress"] = output["stress"].squeeze().cpu().detach().numpy()
|
121 |
|
122 |
+
# def forward(self, x: Atoms) -> dict[str, torch.Tensor]:
|
123 |
+
# """Implement data conversion, graph creation, and model forward pass
|
124 |
+
|
125 |
+
# Example implementation:
|
126 |
+
# 1. Use `ase.neighborlist.NeighborList` to get neighbor list
|
127 |
+
# 2. Create `torch_geometric.data.Data` object and copy the data
|
128 |
+
# 3. Pass the `Data` object to the model and return the output
|
129 |
|
130 |
+
# """
|
|
|
|
|
|
|
131 |
|
132 |
+
# raise NotImplementedError
|
|
mlip_arena/models/classicals/__init__.py
ADDED
File without changes
|
mlip_arena/models/classicals/zbl.py
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.linalg as LA
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch_scatter
|
5 |
+
from torch_geometric.data import Data
|
6 |
+
|
7 |
+
from ase.data import covalent_radii
|
8 |
+
from ase.units import _e, _eps0, m, pi
|
9 |
+
from e3nn.util.jit import compile_mode # TODO: e3nn allows autograd in compiled model
|
10 |
+
|
11 |
+
|
12 |
+
@compile_mode("script")
|
13 |
+
class ZBL(nn.Module):
|
14 |
+
"""Ziegler-Biersack-Littmark (ZBL) screened nuclear repulsion"""
|
15 |
+
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
trianable: bool = False,
|
19 |
+
**kwargs,
|
20 |
+
) -> None:
|
21 |
+
nn.Module.__init__(self, **kwargs)
|
22 |
+
|
23 |
+
torch.set_default_dtype(torch.double)
|
24 |
+
|
25 |
+
self.a = torch.nn.parameter.Parameter(
|
26 |
+
torch.tensor(
|
27 |
+
[0.18175, 0.50986, 0.28022, 0.02817], dtype=torch.get_default_dtype()
|
28 |
+
),
|
29 |
+
requires_grad=trianable,
|
30 |
+
)
|
31 |
+
self.b = torch.nn.parameter.Parameter(
|
32 |
+
torch.tensor(
|
33 |
+
[-3.19980, -0.94229, -0.40290, -0.20162],
|
34 |
+
dtype=torch.get_default_dtype(),
|
35 |
+
),
|
36 |
+
requires_grad=trianable,
|
37 |
+
)
|
38 |
+
|
39 |
+
self.a0 = torch.nn.parameter.Parameter(
|
40 |
+
torch.tensor(0.46850, dtype=torch.get_default_dtype()),
|
41 |
+
requires_grad=trianable,
|
42 |
+
)
|
43 |
+
|
44 |
+
self.p = torch.nn.parameter.Parameter(
|
45 |
+
torch.tensor(0.23, dtype=torch.get_default_dtype()), requires_grad=trianable
|
46 |
+
)
|
47 |
+
|
48 |
+
self.register_buffer(
|
49 |
+
"covalent_radii",
|
50 |
+
torch.tensor(
|
51 |
+
covalent_radii,
|
52 |
+
dtype=torch.get_default_dtype(),
|
53 |
+
),
|
54 |
+
)
|
55 |
+
|
56 |
+
def phi(self, x):
|
57 |
+
return torch.einsum("i,ij->j", self.a, torch.exp(torch.outer(self.b, x)))
|
58 |
+
|
59 |
+
def d_phi(self, x):
|
60 |
+
return torch.einsum(
|
61 |
+
"i,ij->j", self.a * self.b, torch.exp(torch.outer(self.b, x))
|
62 |
+
)
|
63 |
+
|
64 |
+
def dd_phi(self, x):
|
65 |
+
return torch.einsum(
|
66 |
+
"i,ij->j", self.a * self.b**2, torch.exp(torch.outer(self.b, x))
|
67 |
+
)
|
68 |
+
|
69 |
+
def eij(
|
70 |
+
self, zi: torch.Tensor, zj: torch.Tensor, rij: torch.Tensor
|
71 |
+
) -> torch.Tensor: # [eV]
|
72 |
+
return _e * m / (4 * pi * _eps0) * torch.div(torch.mul(zi, zj), rij)
|
73 |
+
|
74 |
+
def d_eij(
|
75 |
+
self, zi: torch.Tensor, zj: torch.Tensor, rij: torch.Tensor
|
76 |
+
) -> torch.Tensor: # [eV / A]
|
77 |
+
return -_e * m / (4 * pi * _eps0) * torch.div(torch.mul(zi, zj), rij**2)
|
78 |
+
|
79 |
+
def dd_eij(
|
80 |
+
self, zi: torch.Tensor, zj: torch.Tensor, rij: torch.Tensor
|
81 |
+
) -> torch.Tensor: # [eV / A^2]
|
82 |
+
return _e * m / (2 * pi * _eps0) * torch.div(torch.mul(zi, zj), rij**3)
|
83 |
+
|
84 |
+
def switch_fn(
|
85 |
+
self,
|
86 |
+
zi: torch.Tensor,
|
87 |
+
zj: torch.Tensor,
|
88 |
+
rij: torch.Tensor,
|
89 |
+
aij: torch.Tensor,
|
90 |
+
router: torch.Tensor,
|
91 |
+
rinner: torch.Tensor,
|
92 |
+
) -> torch.Tensor: # [eV]
|
93 |
+
# aij = self.a0 / (torch.pow(zi, self.p) + torch.pow(zj, self.p))
|
94 |
+
|
95 |
+
xrouter = router / aij
|
96 |
+
|
97 |
+
energy = self.eij(zi, zj, router) * self.phi(xrouter)
|
98 |
+
|
99 |
+
grad1 = self.d_eij(zi, zj, router) * self.phi(xrouter) + self.eij(
|
100 |
+
zi, zj, router
|
101 |
+
) * self.d_phi(xrouter)
|
102 |
+
|
103 |
+
grad2 = (
|
104 |
+
self.dd_eij(zi, zj, router) * self.phi(xrouter)
|
105 |
+
+ self.d_eij(zi, zj, router) * self.d_phi(xrouter)
|
106 |
+
+ self.d_eij(zi, zj, router) * self.d_phi(xrouter)
|
107 |
+
+ self.eij(zi, zj, router) * self.dd_phi(xrouter)
|
108 |
+
)
|
109 |
+
|
110 |
+
A = (-3 * grad1 + (router - rinner) * grad2) / (router - rinner) ** 2
|
111 |
+
B = (2 * grad1 - (router - rinner) * grad2) / (router - rinner) ** 3
|
112 |
+
C = (
|
113 |
+
-energy
|
114 |
+
+ 1.0 / 2.0 * (router - rinner) * grad1
|
115 |
+
- 1.0 / 12.0 * (router - rinner) ** 2 * grad2
|
116 |
+
)
|
117 |
+
|
118 |
+
switching = torch.where(
|
119 |
+
rij < rinner,
|
120 |
+
C,
|
121 |
+
A / 3.0 * (rij - rinner) ** 3 + B / 4.0 * (rij - rinner) ** 4 + C,
|
122 |
+
)
|
123 |
+
|
124 |
+
return switching
|
125 |
+
|
126 |
+
def envelope(self, r: torch.Tensor, rc: torch.Tensor, p: int = 6):
|
127 |
+
x = r / rc
|
128 |
+
y = (
|
129 |
+
1.0
|
130 |
+
- ((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(x, p)
|
131 |
+
+ p * (p + 2.0) * torch.pow(x, p + 1)
|
132 |
+
- (p * (p + 1.0) / 2) * torch.pow(x, p + 2)
|
133 |
+
) * (x < 1)
|
134 |
+
return y
|
135 |
+
|
136 |
+
def _get_derivatives(self, energy: torch.Tensor, data: Data):
|
137 |
+
egradi, egradij = torch.autograd.grad(
|
138 |
+
outputs=[energy], # TODO: generalized derivatives
|
139 |
+
inputs=[data.positions, data.vij], # TODO: generalized derivatives
|
140 |
+
grad_outputs=[torch.ones_like(energy)],
|
141 |
+
retain_graph=True,
|
142 |
+
create_graph=True,
|
143 |
+
allow_unused=True,
|
144 |
+
)
|
145 |
+
|
146 |
+
volume = torch.det(data.cell) # (batch,)
|
147 |
+
rfaxy = torch.einsum("ax,ay->axy", data.vij, -egradij)
|
148 |
+
|
149 |
+
edge_batch = data.batch[data.edge_index[0]]
|
150 |
+
|
151 |
+
stress = (
|
152 |
+
-0.5
|
153 |
+
* torch_scatter.scatter_sum(rfaxy, edge_batch, dim=0)
|
154 |
+
/ volume.view(-1, 1)
|
155 |
+
)
|
156 |
+
|
157 |
+
return -egradi, stress
|
158 |
+
|
159 |
+
def forward(
|
160 |
+
self,
|
161 |
+
data: Data,
|
162 |
+
) -> dict[str, torch.Tensor]:
|
163 |
+
# TODO: generalized derivatives
|
164 |
+
data.positions.requires_grad_(True)
|
165 |
+
|
166 |
+
numbers = data.numbers # (sum(N), )
|
167 |
+
positions = data.positions # (sum(N), 3)
|
168 |
+
edge_index = data.edge_index # (2, sum(E))
|
169 |
+
edge_shift = data.edge_shift # (sum(E), 3)
|
170 |
+
batch = data.batch # (sum(N), )
|
171 |
+
|
172 |
+
edge_src, edge_dst = edge_index[0], edge_index[1]
|
173 |
+
|
174 |
+
if "rij" not in data or "vij" not in data:
|
175 |
+
data.vij = positions[edge_dst] - positions[edge_src] + edge_shift
|
176 |
+
data.rij = LA.norm(data.vij, dim=-1)
|
177 |
+
|
178 |
+
rbond = (
|
179 |
+
self.covalent_radii[numbers[edge_src]]
|
180 |
+
+ self.covalent_radii[numbers[edge_dst]]
|
181 |
+
)
|
182 |
+
|
183 |
+
rij = data.rij
|
184 |
+
zi = numbers[edge_src] # (sum(E), )
|
185 |
+
zj = numbers[edge_dst] # (sum(E), )
|
186 |
+
|
187 |
+
aij = self.a0 / (torch.pow(zi, self.p) + torch.pow(zj, self.p)) # (sum(E), )
|
188 |
+
|
189 |
+
energy_pairs = (
|
190 |
+
self.eij(zi, zj, rij)
|
191 |
+
* self.phi(rij / aij.to(rij))
|
192 |
+
* self.envelope(rij, torch.min(data.cutoff, rbond))
|
193 |
+
)
|
194 |
+
|
195 |
+
energy_nodes = 0.5 * torch_scatter.scatter_add(
|
196 |
+
src=energy_pairs,
|
197 |
+
index=edge_dst,
|
198 |
+
dim=0,
|
199 |
+
) # (sum(N), )
|
200 |
+
|
201 |
+
energies = torch_scatter.scatter_add(
|
202 |
+
src=energy_nodes,
|
203 |
+
index=batch,
|
204 |
+
dim=0,
|
205 |
+
) # (B, )
|
206 |
+
|
207 |
+
# TODO: generalized derivatives
|
208 |
+
forces, stress = self._get_derivatives(energies, data)
|
209 |
+
|
210 |
+
return {
|
211 |
+
"energy": energies,
|
212 |
+
"forces": forces,
|
213 |
+
"stress": stress,
|
214 |
+
}
|
mlip_arena/models/registry.yaml
CHANGED
@@ -84,6 +84,7 @@ MatterSim:
|
|
84 |
- eos_alloy
|
85 |
gpu-tasks:
|
86 |
- homonuclear-diatomics
|
|
|
87 |
github: https://github.com/microsoft/mattersim
|
88 |
doi: https://arxiv.org/abs/2405.04967
|
89 |
date: 2024-12-05
|
@@ -264,6 +265,7 @@ ALIGNN:
|
|
264 |
- MP22
|
265 |
gpu-tasks:
|
266 |
- homonuclear-diatomics
|
|
|
267 |
# - combustion
|
268 |
prediction: EFS
|
269 |
nvt: true
|
@@ -309,6 +311,7 @@ ORBv2:
|
|
309 |
gpu-tasks:
|
310 |
- homonuclear-diatomics
|
311 |
- combustion
|
|
|
312 |
github: https://github.com/orbital-materials/orb-models
|
313 |
doi:
|
314 |
date: 2024-10-15
|
|
|
84 |
- eos_alloy
|
85 |
gpu-tasks:
|
86 |
- homonuclear-diatomics
|
87 |
+
- stability
|
88 |
github: https://github.com/microsoft/mattersim
|
89 |
doi: https://arxiv.org/abs/2405.04967
|
90 |
date: 2024-12-05
|
|
|
265 |
- MP22
|
266 |
gpu-tasks:
|
267 |
- homonuclear-diatomics
|
268 |
+
- stability
|
269 |
# - combustion
|
270 |
prediction: EFS
|
271 |
nvt: true
|
|
|
311 |
gpu-tasks:
|
312 |
- homonuclear-diatomics
|
313 |
- combustion
|
314 |
+
- stability
|
315 |
github: https://github.com/orbital-materials/orb-models
|
316 |
doi:
|
317 |
date: 2024-10-15
|
mlip_arena/tasks/optimize.py
CHANGED
@@ -111,6 +111,9 @@ def run(
|
|
111 |
logger.info(f"Criterion: {pformat(criterion)}")
|
112 |
optimizer_instance.run(**criterion)
|
113 |
|
|
|
114 |
return {
|
115 |
"atoms": atoms,
|
|
|
|
|
116 |
}
|
|
|
111 |
logger.info(f"Criterion: {pformat(criterion)}")
|
112 |
optimizer_instance.run(**criterion)
|
113 |
|
114 |
+
|
115 |
return {
|
116 |
"atoms": atoms,
|
117 |
+
"steps": optimizer_instance.nsteps,
|
118 |
+
"converged": optimizer_instance.converged(),
|
119 |
}
|
tests/test_internal_calculators.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from mlip_arena.models import MLIPCalculator
|
3 |
+
from mlip_arena.models.classicals.zbl import ZBL
|
4 |
+
|
5 |
+
from ase.build import bulk
|
6 |
+
|
7 |
+
|
8 |
+
def test_zbl():
|
9 |
+
calc = MLIPCalculator(model=ZBL(), cutoff=6.0)
|
10 |
+
|
11 |
+
energies = []
|
12 |
+
forces = []
|
13 |
+
stresses = []
|
14 |
+
|
15 |
+
lattice_constants = [1, 3, 5, 7]
|
16 |
+
|
17 |
+
for a in lattice_constants:
|
18 |
+
atoms = bulk("Cu", "fcc", a=a) * (2, 2, 2)
|
19 |
+
atoms.calc = calc
|
20 |
+
|
21 |
+
energies.append(atoms.get_potential_energy())
|
22 |
+
forces.append(atoms.get_forces())
|
23 |
+
stresses.append(atoms.get_stress(voigt=False))
|
24 |
+
|
25 |
+
# test energy monotonicity
|
26 |
+
assert all(np.diff(energies) <= 0), "Energy is not monotonically decreasing with increasing lattice constant"
|
27 |
+
|
28 |
+
# test force vectors are all zeros due to symmetry
|
29 |
+
for f in forces:
|
30 |
+
assert np.allclose(f, 0), "Forces should be zero due to symmetry"
|
31 |
+
|
32 |
+
# test trace of stress is monotonically increasing (less negative) and zero beyond cutoff
|
33 |
+
traces = [np.trace(s) for s in stresses]
|
34 |
+
|
35 |
+
assert all(np.diff(traces) >= 0), "Trace of stress is not monotonically increasing with increasing lattice constant"
|
36 |
+
assert np.allclose(stresses[-1], 0), "Stress should be zero beyond cutoff"
|