Spaces:
Sleeping
Sleeping
File size: 8,307 Bytes
b38c7b5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 |
import rootutils
import torch
from torch import nn
from torch.nn import BatchNorm1d, Linear, Module, ReLU, Sequential
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter
# setup root dir and pythonpath
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from src.data.components.pinder_dataset import PinderDataset
from src.models.components.utils import (
compute_euler_angles_from_rotation_matrices,
compute_rotation_matrix_from_ortho6d,
)
class EquivariantMPNNLayer(MessagePassing):
    """EGNN-style message-passing layer that updates node features and coordinates.

    Feature messages are built only from invariant quantities (endpoint
    features and the pairwise distance). Coordinate updates are relative
    vectors ``pos_i - pos_j`` scaled by a learned scalar computed from the
    (invariant) message. Hence the feature output is invariant and the
    coordinate output equivariant under rotations and translations of ``pos``.
    """

    def __init__(self, emb_dim=64, out_dim=128, aggr="add"):
        r"""Build the message, coordinate and update MLPs.

        Args:
            emb_dim: (int) - input/hidden node feature dimension d
            out_dim: (int) - node feature dimension produced by ``lin_out``
            aggr: (str) - aggregation \oplus for feature messages
                (e.g. "add"/"mean"/"max"); coordinate updates are always
                mean-aggregated (see ``aggregate``).
        """
        super().__init__(aggr=aggr)
        self.emb_dim = emb_dim
        self.out_dim = out_dim

        # MLP over [h_i, h_j, ||pos_i - pos_j||] -> message m_ij
        self.mlp_msg = Sequential(
            Linear(2 * emb_dim + 1, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
            Linear(emb_dim, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
        )
        # MLP \psi: message m_ij -> scalar weight for the relative position vector
        self.mlp_pos = Sequential(
            Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(), Linear(emb_dim, 1)
        )
        # MLP \phi: [h_i, aggregated message] -> updated node feature
        self.mlp_upd = Sequential(
            Linear(2 * emb_dim, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
            Linear(emb_dim, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
        )
        # Final projection emb_dim -> out_dim applied after message passing
        self.lin_out = Linear(emb_dim, out_dim)

    def forward(self, data):
        """Run one round of message passing.

        Input and output are both ``(h, pos, edge_index)`` tuples so that layers
        can be chained inside ``torch.nn.Sequential``.

        Args:
            data: tuple ``(h, pos, edge_index)``:
                h: (n, emb_dim) - node features
                pos: (n, 3) - node coordinates
                edge_index: (2, e) - edge indices in PyG's COO layout
        Returns:
            ``(h_out, pos_out, edge_index)`` with h_out: (n, out_dim),
            pos_out: (n, 3), and edge_index passed through unchanged.
        """
        h, pos, edge_index = data
        h_out, pos_out = self.propagate(edge_index=edge_index, h=h, pos=pos)
        return self.lin_out(h_out), pos_out, edge_index

    def message(self, h_i, h_j, pos_i, pos_j):
        """Compute the per-edge feature message and coordinate update.

        Returns:
            msg: (e, emb_dim) - ``mlp_msg([h_i, h_j, ||pos_i - pos_j||])``
            pos_update: (e, 3) - ``pos_i - pos_j`` scaled by ``mlp_pos(msg)``
        """
        rel_pos = pos_i - pos_j  # (e, 3)
        dist = torch.norm(rel_pos, dim=-1, keepdim=True)  # (e, 1), rotation-invariant
        msg = self.mlp_msg(torch.cat([h_i, h_j, dist], dim=-1))  # (e, emb_dim)
        pos_update = rel_pos * self.mlp_pos(msg)  # (e, 3) * (e, 1)
        return msg, pos_update

    def aggregate(self, inputs, index, dim_size=None):
        """Aggregate per-edge outputs onto their target nodes.

        Args:
            inputs: ``(msgs, pos_updates)`` as returned by ``message``
            index: (e,) - target node index of each edge
            dim_size: (int | None) - number of nodes n, supplied by
                ``propagate``. Passing it guarantees n output rows even when
                the last nodes receive no edges (they get zeros), so the
                concatenation with ``h`` in ``update`` stays aligned.
        Returns:
            ``(msg_aggr, pos_aggr)`` of shapes (n, emb_dim) and (n, 3).
        """
        msgs, pos_updates = inputs
        msg_aggr = scatter(
            msgs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr
        )
        # Coordinates are always mean-aggregated, independent of self.aggr
        pos_aggr = scatter(
            pos_updates, index, dim=self.node_dim, dim_size=dim_size, reduce="mean"
        )
        return msg_aggr, pos_aggr

    def update(self, aggr_out, h, pos):
        """Apply \\phi to [h, aggregated message] and shift ``pos`` by the mean update."""
        msg_aggr, pos_aggr = aggr_out
        h_new = self.mlp_upd(torch.cat((h, msg_aggr), dim=-1))
        return h_new, pos + pos_aggr

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(emb_dim={self.emb_dim}, "
            f"out_dim={self.out_dim}, aggr={self.aggr})"
        )
class PinderMPNNModel(Module):
    """Two-branch MPNN with receptor/ligand cross-attention.

    Receptor and ligand graphs are encoded by separate stacks of
    ``EquivariantMPNNLayer``. Each side then cross-attends to the other, and
    per-node outputs are decoded from the attended features.

    NOTE(review): the translation heads consume raw coordinates through a
    Linear layer, so the model output is not rotation-invariant/equivariant
    even though the MPNN layers are.
    """

    _HIDDEN_DIM = 128  # output dim of the first MPNN layer in each branch
    _ATTN_DIM = 256  # output dim of the second MPNN layer = attention embed dim

    def __init__(self, input_dim=1, emb_dim=64, num_heads=8):
        """
        Args:
            input_dim: (int) - initial node feature dimension d_n
            emb_dim: (int) - dimension of the input projection, fed to the
                first MPNN layer of both the receptor and ligand branches
            num_heads: (int) - cross-attention heads; must divide the 256-dim
                attention embedding (so e.g. 5 is invalid, hence default 8)
        Raises:
            ValueError: if ``num_heads`` is not a positive divisor of 256.
        """
        super().__init__()
        hidden_dim, attn_dim = self._HIDDEN_DIM, self._ATTN_DIM
        if num_heads <= 0 or attn_dim % num_heads != 0:
            raise ValueError(
                f"num_heads={num_heads} must be a positive divisor of {attn_dim}"
            )

        # Linear projection for initial node features
        self.lin_in_rec = Linear(input_dim, emb_dim)
        self.lin_in_lig = Linear(input_dim, emb_dim)

        # Stacks of MPNN layers; both branches start from emb_dim features
        self.receptor_mpnn = Sequential(
            EquivariantMPNNLayer(emb_dim, hidden_dim, aggr="mean"),
            EquivariantMPNNLayer(hidden_dim, attn_dim, aggr="mean"),
        )
        self.ligand_mpnn = Sequential(
            EquivariantMPNNLayer(emb_dim, hidden_dim, aggr="mean"),
            EquivariantMPNNLayer(hidden_dim, attn_dim, aggr="mean"),
        )

        # Cross-attention: receptor queries ligand, ligand queries receptor
        self.rec_cross_attention = nn.MultiheadAttention(
            attn_dim, num_heads, batch_first=True
        )
        self.lig_cross_attention = nn.MultiheadAttention(
            attn_dim, num_heads, batch_first=True
        )

        # Per-node translation heads over [attended features, refined pos]
        self.fc_translation_rec = nn.Linear(attn_dim + 3, 3)
        self.fc_translation_lig = nn.Linear(attn_dim + 3, 3)

    @staticmethod
    def _cross_attention_mask(query_store, key_store):
        """Return a (n_query, n_key) bool mask that blocks attention across complexes.

        ``True`` marks a disallowed pair (``nn.MultiheadAttention`` boolean
        ``attn_mask`` convention). Node tensors from a batched PyG loader are
        flattened across complexes, so without this mask every node would attend
        to nodes of other complexes in the batch. Returns ``None`` (no masking)
        when either store lacks a ``batch`` vector, e.g. an un-batched input.
        A complex with zero key nodes would produce fully masked (NaN) rows.
        """
        query_batch = getattr(query_store, "batch", None)
        key_batch = getattr(key_store, "batch", None)
        if query_batch is None or key_batch is None:
            return None
        return query_batch.unsqueeze(1) != key_batch.unsqueeze(0)

    def forward(self, batch):
        """Encode both partners, cross-attend, and decode per-node outputs.

        Args:
            batch: PyG heterogeneous (batched) data with "receptor"/"ligand"
                node stores providing ``x`` and ``pos``, and
                ("receptor", "receptor") / ("ligand", "ligand") edge stores
                providing ``edge_index``.
        Returns:
            ``(receptor_coords, ligand_coords)``, one row per receptor/ligand
            node. Each row is the Euler angles (in degrees) of a rotation matrix
            decoded from that node's attended features, plus a per-node (n, 3)
            translation from the translation head.
            NOTE(review): this is angles + translation, not rotated/translated
            coordinates; the sum assumes the Euler helper returns (n, 3) —
            confirm against the loss that consumes it.
        """
        rec_store, lig_store = batch["receptor"], batch["ligand"]

        h_receptor, pos_receptor, _ = self.receptor_mpnn(
            (
                self.lin_in_rec(rec_store.x),
                rec_store.pos,
                batch["receptor", "receptor"].edge_index,
            )
        )
        h_ligand, pos_ligand, _ = self.ligand_mpnn(
            (
                self.lin_in_lig(lig_store.x),
                lig_store.pos,
                batch["ligand", "ligand"].edge_index,
            )
        )

        # Restrict attention to nodes belonging to the same complex
        attn_output_rec, _ = self.rec_cross_attention(
            h_receptor,
            h_ligand,
            h_ligand,
            attn_mask=self._cross_attention_mask(rec_store, lig_store),
        )
        attn_output_lig, _ = self.lig_cross_attention(
            h_ligand,
            h_receptor,
            h_receptor,
            attn_mask=self._cross_attention_mask(lig_store, rec_store),
        )

        translation_vector_r = self.fc_translation_rec(
            torch.cat((attn_output_rec, pos_receptor), dim=-1)
        )
        translation_vector_l = self.fc_translation_lig(
            torch.cat((attn_output_lig, pos_ligand), dim=-1)
        )

        rot_mat_rec = compute_rotation_matrix_from_ortho6d(attn_output_rec)
        rot_mat_lig = compute_rotation_matrix_from_ortho6d(attn_output_lig)
        # Radians -> degrees
        receptor_coords = (
            compute_euler_angles_from_rotation_matrices(rot_mat_rec) * 180 / torch.pi
        )
        ligand_coords = (
            compute_euler_angles_from_rotation_matrices(rot_mat_lig) * 180 / torch.pi
        )

        receptor_coords = receptor_coords + translation_vector_r
        ligand_coords = ligand_coords + translation_vector_l
        return receptor_coords, ligand_coords
if __name__ == "__main__":
    # Smoke test: push one batch of three copies of the same complex through
    # the model and print the parameter count and output shapes.
    file_paths = ["./data/processed/apo/test/1a19__A1_P11540--1a19__B1_P11540.pt"]
    dataset = PinderDataset(file_paths=file_paths * 3)
    loader = DataLoader(dataset, batch_size=3, shuffle=False)
    batch = next(iter(loader))
    # num_heads must divide the 256-dim attention embedding (5 does not)
    model = PinderMPNNModel(num_heads=8)
    print("Number of parameters:", sum(p.numel() for p in model.parameters()))
    with torch.no_grad():  # shape check only; no autograd graph needed
        receptor_coords, ligand_coords = model(batch)
    print(receptor_coords.shape)
    print(ligand_coords.shape)
|