from typing import Optional

import numpy as np
import torch

# TODO: consider using vesin
from matscipy.neighbours import neighbour_list
from torch_geometric.data import Data
from ase import Atoms
from ase.calculators.singlepoint import SinglePointCalculator


def get_neighbor(
    atoms: Atoms, cutoff: float, self_interaction: bool = False
) -> tuple[np.ndarray, np.ndarray]:
    """Build the directed neighbor list of ``atoms`` within ``cutoff``.

    Neighbors are found with matscipy's ``neighbour_list`` using the
    periodicity (``atoms.pbc``) and cell of ``atoms``.

    Args:
        atoms: Structure whose ``pbc``, ``cell`` and ``positions`` are used.
        cutoff: Neighbor cutoff radius, in the same length unit as
            ``atoms.positions``.
        self_interaction: If False, drop edges that connect an atom to
            itself with a zero cell shift. Edges from an atom to its own
            periodic images are always kept.

    Returns:
        ``(edge_index, edge_shift)``: ``edge_index`` is an int64 array of
        shape ``(2, n_edges)`` holding the (i, j) atom indices, and
        ``edge_shift`` is the Cartesian shift ``S @ cell`` of each edge.
    """
    cell = atoms.cell.array
    i, j, S = neighbour_list(
        quantities="ijS",
        pbc=atoms.pbc,
        cell=cell,
        positions=atoms.positions,
        cutoff=cutoff,
    )

    if not self_interaction:
        # A "true" self-edge links an atom to itself without crossing a
        # periodic boundary (all shift components zero); keep everything else.
        keep_edge = (i != j) | np.any(S != 0, axis=1)
        i, j, S = i[keep_edge], j[keep_edge], S[keep_edge]

    edge_index = np.stack((i, j)).astype(np.int64)
    edge_shift = np.dot(S, cell)
    return edge_index, edge_shift


def _per_atom_or_nan(
    values: np.ndarray, natoms: int, dtype: torch.dtype
) -> torch.Tensor:
    """Convert per-atom ``values`` to a tensor, or a NaN placeholder if all are zero.

    An all-zero array (also what ASE returns for an unset property) is
    treated as "missing" and replaced by ``natoms`` NaNs of ``dtype``.
    """
    if values.any():
        return torch.as_tensor(values, dtype=dtype)
    return torch.full((natoms,), float("nan"), dtype=dtype)


def collate_fn(batch: list[Atoms], cutoff: float) -> Data:
    """Collate a list of Atoms objects into a single batched ``Data`` object.

    Every structure becomes a disconnected graph: its edge indices are
    offset by the total number of atoms in the preceding structures.

    Args:
        batch: Structures to batch together; must be non-empty.
        cutoff: Neighbor cutoff radius passed to :func:`get_neighbor`.

    Returns:
        A ``Data`` object with, for ``B = len(batch)`` structures and
        ``N`` atoms in total:

        - ``cell`` (B, 3, 3), ``pbc`` (B, 3) bool, ``natoms`` (B,)
        - ``positions``, ``numbers``, ``batch`` (graph index of each atom),
          ``charges`` and ``magmoms``, concatenated over atoms. Charges and
          magmoms are NaN for structures whose values are all zero.
        - ``edge_index`` (2, n_edges) and ``edge_shift``, concatenated
          over edges
        - ``num_nodes`` (= N) and ``cutoff`` (scalar tensor)

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collate_fn requires at least one Atoms object")

    # One dtype for both real values and NaN placeholders so torch.cat works.
    float_dtype = torch.get_default_dtype()

    offset = 0  # number of atoms in all structures processed so far
    node_batch = []
    numbers_batch = []
    positions_batch = []
    charges_batch = []
    magmoms_batch = []
    edge_index_batch = []
    edge_shift_batch = []
    cell_batch = []
    pbc_batch = []
    natoms_batch = []

    for graph_idx, atoms in enumerate(batch):
        natoms = len(atoms)

        edge_index, edge_shift = get_neighbor(
            atoms, cutoff=cutoff, self_interaction=False
        )
        # Shift node indices so this graph's nodes follow the previous ones.
        edge_index_batch.append(torch.tensor(edge_index + offset))
        edge_shift_batch.append(torch.tensor(edge_shift))
        offset += natoms

        node_batch.append(torch.full((natoms,), graph_idx, dtype=torch.long))
        natoms_batch.append(natoms)
        cell_batch.append(torch.tensor(atoms.cell.array))
        pbc_batch.append(torch.tensor(atoms.pbc, dtype=torch.bool))
        numbers_batch.append(torch.tensor(atoms.numbers))
        positions_batch.append(torch.tensor(atoms.positions))
        charges_batch.append(
            _per_atom_or_nan(atoms.get_initial_charges(), natoms, float_dtype)
        )
        magmoms_batch.append(
            _per_atom_or_nan(
                atoms.get_initial_magnetic_moments(), natoms, float_dtype
            )
        )
        # TODO: electronic encoding ("ec") per atom.

    arrays_batch_concatenated = {
        "cell": torch.stack(cell_batch, dim=0),
        "pbc": torch.stack(pbc_batch, dim=0),
        "positions": torch.cat(positions_batch, dim=0),
        "edge_index": torch.cat(edge_index_batch, dim=1),
        "edge_shift": torch.cat(edge_shift_batch, dim=0),
        "numbers": torch.cat(numbers_batch, dim=0),
        "num_nodes": offset,
        "batch": torch.cat(node_batch, dim=0),
        "charges": torch.cat(charges_batch, dim=0),
        "magmoms": torch.cat(magmoms_batch, dim=0),
        "natoms": torch.tensor(natoms_batch, dtype=torch.long),
        "cutoff": torch.tensor(cutoff),
    }
    # TODO: custom fields
    return Data.from_dict(arrays_batch_concatenated)


def _nan_to_none(value: Optional[torch.Tensor]) -> Optional[np.ndarray]:
    """Convert a CPU tensor to NumPy, mapping absent or NaN-containing values to None.

    NaN is the placeholder used for missing per-structure properties, so a
    value containing NaN is treated as unavailable.
    """
    if value is None or torch.isnan(value).any():
        return None
    return value.numpy()


def decollate_fn(batch_data: Data) -> list[Atoms]:
    """Split a batched ``Data`` object back into individual Atoms objects.

    Structure fields (``cell``, ``numbers``, ``positions``, ``batch``) are
    required. ``pbc``, ``charges`` and ``magmoms`` are restored on the Atoms
    when present. ``energy``, ``forces`` and ``stress`` are attached through
    a ``SinglePointCalculator`` when present. Any value containing NaN is
    treated as missing.

    The input is not modified: only the needed fields are copied to host
    memory, detached from autograd.

    Args:
        batch_data: Batched graph, e.g. as produced by :func:`collate_fn`.
            ``cell`` and per-structure fields (``pbc``, ``energy``,
            ``stress``) are indexed by graph; per-atom fields are selected
            with ``batch == graph_index``.

    Returns:
        One Atoms object per graph (``batch_data.cell.shape[0]`` of them),
        in graph order.
    """
    # PyG's Data.detach()/.cpu() reassign attributes on the object itself,
    # which would silently move the caller's batch off its device. Copy only
    # the fields we read instead.
    # TODO: create a new Cell class using torch tensor to handle device placement.
    wanted_keys = (
        "cell", "numbers", "positions", "batch", "pbc",
        "charges", "magmoms", "energy", "forces", "stress",
    )
    fields = {
        key: batch_data[key].detach().cpu()
        for key in wanted_keys
        if key in batch_data
    }

    def field(key: str, index) -> Optional[torch.Tensor]:
        # Optional fields may be missing from the batch entirely.
        value = fields.get(key)
        return None if value is None else value[index]

    individual_entries = []
    num_graphs = fields["cell"].shape[0]
    for graph_idx in range(num_graphs):
        atom_mask = fields["batch"] == graph_idx
        pbc = field("pbc", graph_idx)

        atoms = Atoms(
            numbers=fields["numbers"][atom_mask].numpy(),
            positions=fields["positions"][atom_mask].numpy(),
            cell=fields["cell"][graph_idx].numpy(),
            pbc=False if pbc is None else pbc.numpy(),
            charges=_nan_to_none(field("charges", atom_mask)),
            magmoms=_nan_to_none(field("magmoms", atom_mask)),
        )

        energy = _nan_to_none(field("energy", graph_idx))
        # SinglePointCalculator requires the Atoms as its first argument and
        # skips properties passed as None.
        atoms.calc = SinglePointCalculator(
            atoms,
            energy=None if energy is None else energy.item(),
            forces=_nan_to_none(field("forces", atom_mask)),
            stress=_nan_to_none(field("stress", graph_idx)),
        )
        # TODO: custom fields
        individual_entries.append(atoms)

    return individual_entries