import rootutils
import torch
from torch import nn
from torch.nn import BatchNorm1d, Linear, Module, ReLU, Sequential
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter

# setup root dir and pythonpath
# NOTE: this must run before the `src.*` imports below, which rely on the
# project root having been added to the python path.
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from src.data.components.pinder_dataset import PinderDataset  # noqa: E402
from src.models.components.utils import (  # noqa: E402
    compute_euler_angles_from_rotation_matrices,
    compute_rotation_matrix_from_ortho6d,
)


class EquivariantMPNNLayer(MessagePassing):
    def __init__(self, emb_dim=64, out_dim=128, aggr="add"):
        r"""Message Passing Neural Network Layer

        This layer is equivariant to 3D rotations and translations: node
        features are updated from invariant quantities (features and
        pairwise distances), while coordinates are updated by adding
        scalar-weighted relative position vectors.

        Args:
            emb_dim: (int) - input/hidden node feature dimension d
            out_dim: (int) - node feature dimension produced by the final
                linear projection
            aggr: (str) - aggregation function \oplus (sum/mean/max) used
                for the feature messages
        """
        # Set the aggregation function
        super().__init__(aggr=aggr)

        self.emb_dim = emb_dim

        # Message MLP: [h_i, h_j, ||pos_i - pos_j||] -> message of size d
        self.mlp_msg = Sequential(
            Linear(2 * emb_dim + 1, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
            Linear(emb_dim, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
        )

        # MLP \psi: message -> scalar weight applied to the relative position
        self.mlp_pos = Sequential(
            Linear(emb_dim, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
            Linear(emb_dim, 1),
        )

        # MLP \phi: [h_i, aggregated message] -> updated node feature
        self.mlp_upd = Sequential(
            Linear(2 * emb_dim, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
            Linear(emb_dim, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
        )

        # Final projection from the hidden dimension to the output dimension
        self.lin_out = Linear(emb_dim, out_dim)

    def forward(self, data):
        """Run one round of message passing on features and coordinates.

        The input and output are both 3-tuples so that several layers can be
        chained inside a `torch.nn.Sequential`.

        Args:
            data: tuple `(h, pos, edge_index)` where
                h: (n, emb_dim) - node features
                pos: (n, 3) - node coordinates
                edge_index: edge connectivity forwarded to `propagate`
                    (PyG convention is shape (2, e))

        Returns:
            tuple `(h_out, pos_out, edge_index)` where
                h_out: (n, out_dim) - updated node features
                pos_out: (n, 3) - updated node coordinates
                edge_index: the unchanged input connectivity
        """
        h, pos, edge_index = data
        h_out, pos_out = self.propagate(edge_index=edge_index, h=h, pos=pos)
        h_out = self.lin_out(h_out)
        return h_out, pos_out, edge_index

    def message(self, h_i, h_j, pos_i, pos_j):
        """Build the feature message and the weighted position offset per edge.

        Args:
            h_i, h_j: (e, emb_dim) - features of the two endpoints of each edge
            pos_i, pos_j: (e, 3) - coordinates of the two endpoints

        Returns:
            tuple `(msg, pos_diff)` where
                msg: (e, emb_dim) - feature message
                pos_diff: (e, 3) - relative position scaled by a learned scalar
        """
        # Relative position and its Euclidean length, shape (e, 3) and (e, 1)
        pos_diff = pos_i - pos_j
        dists = torch.norm(pos_diff, dim=-1).unsqueeze(1)

        # Concatenate both endpoint features with the distance: (e, 2d + 1)
        msg = torch.cat([h_i, h_j, dists], dim=-1)
        msg = self.mlp_msg(msg)  # (e, d)

        # Scale the relative position by a per-edge scalar derived from msg
        pos_diff = pos_diff * self.mlp_pos(msg)  # (e, 3)
        return msg, pos_diff

    def aggregate(self, inputs, index, dim_size=None):
        """Aggregate per-edge messages and position offsets onto nodes.

        Feature messages use the layer's configured aggregation; position
        offsets are always averaged.

        Args:
            inputs: tuple `(msgs, pos_diffs)` as returned by `message`
            index: (e,) - for each edge, the node its message is reduced into
            dim_size: (int, optional) - number of nodes n, supplied by
                `propagate`. Passing it to `scatter` guarantees n output rows
                even when the highest-indexed nodes receive no messages;
                without it the outputs could be shorter than the node tensors
                they are combined with in `update`.

        Returns:
            tuple `(msg_aggr, pos_aggr)` of aggregated messages and offsets
        """
        msgs, pos_diffs = inputs
        msg_aggr = scatter(
            msgs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr
        )
        pos_aggr = scatter(
            pos_diffs, index, dim=self.node_dim, dim_size=dim_size, reduce="mean"
        )
        return msg_aggr, pos_aggr

    def update(self, aggr_out, h, pos):
        """Combine aggregated messages with the current node state.

        Args:
            aggr_out: tuple `(msg_aggr, pos_aggr)` from `aggregate`
            h: (n, emb_dim) - current node features
            pos: (n, 3) - current node coordinates

        Returns:
            tuple `(upd_out, upd_pos)` of updated features and coordinates
        """
        msg_aggr, pos_aggr = aggr_out
        upd_out = self.mlp_upd(torch.cat((h, msg_aggr), dim=-1))
        # Coordinates move by the averaged, scalar-weighted relative positions
        upd_pos = pos + pos_aggr
        return upd_out, upd_pos

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})"


class PinderMPNNModel(Module):
    def __init__(self, input_dim=1, emb_dim=64, num_heads=4):
        """Receptor/ligand message passing model with cross-attention.

        Receptor and ligand graphs are each encoded by a stack of
        `EquivariantMPNNLayer`s, exchange information through two
        cross-attention layers, and are then mapped to per-node 3-vectors.

        Args:
            input_dim: (int) - initial node feature dimension d_n
            emb_dim: (int) - hidden dimension d of the input projection and
                of the first MPNN layer of both branches
            num_heads: (int) - number of attention heads. It must divide the
                attention width (256), which is why the default is 4: the
                previous default of 5 made `nn.MultiheadAttention` fail at
                construction time.
        """
        super().__init__()

        hidden_dim = 128  # width after the first MPNN layer
        attn_dim = 256  # width after the second MPNN layer / attention width

        # Linear projection for initial node features
        self.lin_in_rec = Linear(input_dim, emb_dim)
        self.lin_in_lig = Linear(input_dim, emb_dim)

        # Stack of MPNN layers (each consumes and returns (h, pos, edge_index))
        self.receptor_mpnn = Sequential(
            EquivariantMPNNLayer(emb_dim, hidden_dim, aggr="mean"),
            EquivariantMPNNLayer(hidden_dim, attn_dim, aggr="mean"),
        )
        # The ligand branch takes emb_dim-wide features from lin_in_lig, so it
        # must use emb_dim as well (it was previously hard-coded to 64).
        self.ligand_mpnn = Sequential(
            EquivariantMPNNLayer(emb_dim, hidden_dim, aggr="mean"),
            EquivariantMPNNLayer(hidden_dim, attn_dim, aggr="mean"),
        )

        # Cross-attention layers
        self.rec_cross_attention = nn.MultiheadAttention(
            attn_dim, num_heads, batch_first=True
        )
        self.lig_cross_attention = nn.MultiheadAttention(
            attn_dim, num_heads, batch_first=True
        )

        # Linear heads for translation prediction (features + coordinates -> 3)
        self.fc_translation_rec = nn.Linear(attn_dim + 3, 3)
        self.fc_translation_lig = nn.Linear(attn_dim + 3, 3)

    def forward(self, batch):
        """The main forward pass of the model.

        Args:
            batch: object indexable by "receptor" and "ligand" (each exposing
                `.x` and `.pos`) and by ("receptor", "receptor") and
                ("ligand", "ligand") (each exposing `.edge_index`).
                Presumably a PyG heterogeneous batch produced by
                `PinderDataset` - confirm against the dataset.

        Returns:
            tuple `(receptor_coords, ligand_coords)`: for each side, the
            Euler angles (converted to degrees) derived from the attended
            features plus the predicted translation vector. The last
            dimension of the translation is 3; the exact leading shape
            depends on the rotation helpers - TODO confirm.
        """
        # Project raw node features into the embedding space
        h_receptor = self.lin_in_rec(batch["receptor"].x)
        h_ligand = self.lin_in_lig(batch["ligand"].x)

        pos_receptor = batch["receptor"].pos
        pos_ligand = batch["ligand"].pos

        # Encode each graph independently
        h_receptor, pos_receptor, _ = self.receptor_mpnn(
            (h_receptor, pos_receptor, batch["receptor", "receptor"].edge_index)
        )
        h_ligand, pos_ligand, _ = self.ligand_mpnn(
            (h_ligand, pos_ligand, batch["ligand", "ligand"].edge_index)
        )

        # Each side attends over the other side's node features.
        # NOTE(review): the node tensors are passed as-is, so if `batch`
        # concatenates several complexes, attention is not restricted to
        # nodes of the same complex - confirm this is intended.
        attn_output_rec, _ = self.rec_cross_attention(h_receptor, h_ligand, h_ligand)
        attn_output_lig, _ = self.lig_cross_attention(h_ligand, h_receptor, h_receptor)

        # Translation heads see attended features together with coordinates
        emb_features_receptor = torch.cat((attn_output_rec, pos_receptor), dim=-1)
        emb_features_ligand = torch.cat((attn_output_lig, pos_ligand), dim=-1)

        translation_vector_r = self.fc_translation_rec(emb_features_receptor)
        translation_vector_l = self.fc_translation_lig(emb_features_ligand)

        # Rotation matrices from the attended features, then Euler angles
        rot_matrix_rec = compute_rotation_matrix_from_ortho6d(attn_output_rec)
        rot_matrix_lig = compute_rotation_matrix_from_ortho6d(attn_output_lig)

        # Convert radians to degrees
        receptor_coords = (
            compute_euler_angles_from_rotation_matrices(rot_matrix_rec) * 180 / torch.pi
        )
        ligand_coords = (
            compute_euler_angles_from_rotation_matrices(rot_matrix_lig) * 180 / torch.pi
        )

        receptor_coords = receptor_coords + translation_vector_r
        ligand_coords = ligand_coords + translation_vector_l

        return receptor_coords, ligand_coords


if __name__ == "__main__":
    # Smoke test: run one batch of three copies of a single processed sample
    file_paths = ["./data/processed/apo/test/1a19__A1_P11540--1a19__B1_P11540.pt"]

    dataset = PinderDataset(file_paths=file_paths * 3)
    loader = DataLoader(dataset, batch_size=3, shuffle=False)
    batch = next(iter(loader))

    model = PinderMPNNModel()
    print("Number of parameters:", sum(p.numel() for p in model.parameters()))

    receptor_coords, ligand_coords = model(batch)
    print(receptor_coords.shape)
    print(ligand_coords.shape)