from torch import nn
import torch
from .encoder import GotenNet
from .utils import get_symmetric_displacement, BatchedPeriodicDistance, ACT_CLASS_MAPPING


def _scatter_sum(src, index, dim_size=None):
    """Sum the rows of ``src`` into output rows selected by ``index`` (along dim 0).

    Torch-native stand-in for ``torch_scatter.scatter(src, index, dim=0, reduce="sum")``,
    so this module does not depend on ``torch_scatter``. The result is differentiable
    with respect to ``src``.

    Args:
        src: Tensor whose first dimension has the same length as ``index``.
        index: 1-D integer tensor giving, for each row of ``src``, its output row.
        dim_size: Number of output rows. Defaults to ``index.max() + 1`` (or 0 for an
            empty ``index``), matching ``torch_scatter``'s default.

    Returns:
        Tensor of shape ``(dim_size, *src.shape[1:])`` holding the per-bucket sums.
    """
    if dim_size is None:
        dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
    out = src.new_zeros((dim_size, *src.shape[1:]))
    # Out-of-place index_add keeps autograd history on ``src``.
    return out.index_add(0, index, src)


class NodeInvariantReadout(nn.Module):
    """Per-node readout over a stack of invariant embedding slices.

    Each of the first ``num_residues - 1`` slices is projected by its own linear layer.
    The last slice goes through a small MLP (Linear -> activation -> Linear). The
    per-slice outputs are summed.
    """

    def __init__(self, in_channels, num_residues, hidden_channels, out_channels, activation):
        """Build the readout.

        Args:
            in_channels: Feature size of each embedding slice.
            num_residues: Number of slices in the last embedding axis.
            hidden_channels: Hidden width of the MLP applied to the last slice.
            out_channels: Output size of every branch.
            activation: Key into ``ACT_CLASS_MAPPING`` selecting the MLP activation class.
        """
        super().__init__()
        # One linear projection for every slice except the last one.
        self.linears = nn.ModuleList(
            [nn.Linear(in_channels, out_channels) for _ in range(num_residues - 1)]
        )
        # Define the nonlinear layer for the last layer's output
        self.non_linear = nn.Sequential(
            nn.Linear(in_channels, hidden_channels),
            ACT_CLASS_MAPPING[activation](),
            nn.Linear(hidden_channels, out_channels),
        )

    def forward(self, embedding_0):
        """Map node embeddings to per-node outputs.

        Args:
            embedding_0: Node embeddings. After ``squeeze(2)`` they are indexed as
                [n_nodes, in_channels, num_residues]. This layout comes from the original
                inline comment; TODO confirm against the encoder output.

        Returns:
            Sum of all branch outputs with a trailing size-1 dimension removed
            (e.g. shape [n_nodes] when ``out_channels == 1``).
        """
        layer_outputs = embedding_0.squeeze(2)  # [n_nodes, in_channels, num_residues]
        processed_outputs = [
            linear(layer_outputs[:, :, i]) for i, linear in enumerate(self.linears)
        ]
        processed_outputs.append(self.non_linear(layer_outputs[:, :, -1]))
        output = torch.stack(processed_outputs, dim=0).sum(dim=0).squeeze(-1)
        return output


class PosEGNN(nn.Module):
    """GotenNet-based model predicting total energy, forces and (optionally) stress.

    Config keys read here:
        - ``config["encoder"]``: kwargs for ``GotenNet``. Must contain ``"cutoff"``, which
          is also used for the periodic neighbour search.
        - ``config["decoder"]``: kwargs for ``NodeInvariantReadout``.
        - ``"e0_mean"``: per-element reference energies, indexed by atomic number ``data.z``.
        - ``"atomic_res_total_mean"`` / ``"atomic_res_total_std"``: statistics used to
          de-normalize the per-atom residual energy.
    """

    def __init__(self, config):
        super().__init__()
        self.distance = BatchedPeriodicDistance(config["encoder"]["cutoff"])
        self.encoder = GotenNet(**config["encoder"])
        self.readout = NodeInvariantReadout(**config["decoder"])
        # Registered as buffers so they follow .to(device) and are stored in state_dict.
        self.register_buffer("e0_mean", torch.tensor(config["e0_mean"]))
        self.register_buffer("atomic_res_total_mean", torch.tensor(config["atomic_res_total_mean"]))
        self.register_buffer("atomic_res_total_std", torch.tensor(config["atomic_res_total_std"]))

    def forward(self, data):
        """Build the neighbour graph and run the encoder.

        Side effects: enables grad on ``data.pos``. It then overwrites ``data.pos``,
        ``data.box`` and ``data.displacements``, and sets ``data.cutoff_edge_index``,
        ``data.cutoff_edge_distance``, ``data.cutoff_edge_vec`` and
        ``data.cutoff_shifts_idx``.

        Args:
            data: Batched graph object exposing ``pos``, ``box``, ``z``, ``batch`` and
                ``num_graphs``. Presumably a torch_geometric ``Batch``; confirm with the caller.

        Returns:
            The dict produced by the encoder. ``compute_properties`` reads its
            ``"embedding_0"`` entry.
        """
        data.pos.requires_grad_(True)
        # NOTE(review): presumably applies a symmetric strain so that d(energy)/d(displacements)
        # gives the virial used for stress; confirm in .utils.get_symmetric_displacement.
        data.pos, data.box, data.displacements = get_symmetric_displacement(
            data.pos, data.box, data.num_graphs, data.batch
        )
        (
            data.cutoff_edge_index,
            data.cutoff_edge_distance,
            data.cutoff_edge_vec,
            data.cutoff_shifts_idx,
        ) = self.distance(data.pos, data.box, data.batch)
        embedding_dict = self.encoder(
            data.z, data.pos, data.cutoff_edge_index, data.cutoff_edge_distance, data.cutoff_edge_vec
        )
        return embedding_dict

    def compute_properties(self, data, compute_stress = True):
        """Predict per-structure energy and its derivatives.

        Args:
            data: Batched graph object; see ``forward``.
            compute_stress: When True, also differentiate with respect to
                ``data.displacements`` and return a stress tensor.

        Returns:
            Dict with ``"total_energy"`` (one entry per structure in the batch) and
            ``"force"`` (negative gradient of energy w.r.t. ``data.pos``). It also holds
            ``"stress"`` when ``compute_stress`` is True.
        """
        output = {}
        embedding_dict = self.forward(data)
        embedding_0 = embedding_dict["embedding_0"]

        # Compute energy: de-normalized learned per-atom residual + per-element reference energy,
        # each summed per structure.
        node_e_res = self.readout(embedding_0)
        node_e_res = node_e_res * self.atomic_res_total_std + self.atomic_res_total_mean
        total_e_res = _scatter_sum(node_e_res, data["batch"])
        node_e0 = self.e0_mean[data.z]
        total_e0 = _scatter_sum(node_e0, data["batch"])
        total_energy = total_e0 + total_e_res
        output["total_energy"] = total_energy

        # Compute gradients. Keep the graph during training so that losses on forces and
        # stress can be backpropagated.
        inputs = [data.pos, data.displacements] if compute_stress else [data.pos]
        grad_outputs = torch.autograd.grad(
            outputs=[total_energy],
            inputs=inputs,
            grad_outputs=[torch.ones_like(total_energy)],
            retain_graph=self.training,
            create_graph=self.training,
        )

        # Get forces and stresses
        output["force"] = -grad_outputs[0]
        if compute_stress:
            virial = grad_outputs[1]
            stress = virial / torch.det(data.box).abs().view(-1, 1, 1)
            # Zero out huge entries (e.g. from a near-degenerate box volume). NaN also fails
            # the comparison and is therefore zeroed.
            stress = torch.where(torch.abs(stress) < 1e10, stress, torch.zeros_like(stress))
            output["stress"] = -stress
        return output