Sukanyaaa commited on
Commit
c618657
·
1 Parent(s): 5d7c203

Initial commit

Browse files
EquiMPNN-epoch=413-val_loss=9.25-val_acc=0.00.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0196c51cf2e21a93906785c5ec4f3aef72d85b34908825b9c70cf29cc35d4fca
3
+ size 556424
inference_app.py CHANGED
@@ -1,29 +1,849 @@
1
-
2
  import time
3
  import json
4
-
5
  import gradio as gr
6
-
7
  from gradio_molecule3d import Molecule3D
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
 
 
9
 
 
 
 
10
 
 
 
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  def predict (input_seq_1, input_msa_1, input_protein_1, input_seq_2,input_msa_2, input_protein_2):
13
  start_time = time.time()
14
- # Do inference here
 
 
 
 
15
  # return an output pdb file with the protein and two chains A and B.
16
  # also return a JSON with any metrics you want to report
17
  metrics = {"mean_plddt": 80, "binding_affinity": 2}
18
  end_time = time.time()
19
  run_time = end_time - start_time
20
- return "test_out.pdb",json.dumps(metrics), run_time
21
 
22
  with gr.Blocks() as app:
23
 
24
  gr.Markdown("# Template for inference")
25
 
26
- gr.Markdown("Title, description, and other information about the model")
27
  with gr.Row():
28
  with gr.Column():
29
  input_seq_1 = gr.Textbox(lines=3, label="Input Protein 1 sequence (FASTA)")
@@ -94,3 +914,4 @@ with gr.Blocks() as app:
94
  btn.click(predict, inputs=[input_seq_1, input_msa_1, input_protein_1, input_seq_2, input_msa_2, input_protein_2], outputs=[out, metrics, run_time])
95
 
96
  app.launch()
 
 
1
+ from __future__ import annotations
2
  import time
3
  import json
 
4
  import gradio as gr
 
5
  from gradio_molecule3d import Molecule3D
6
+ import torch
7
+ from pinder.core import get_pinder_location
8
+ get_pinder_location()
9
+ from pytorch_lightning import LightningModule
10
+
11
+ import torch
12
+ import lightning.pytorch as pl
13
+ import torch.nn.functional as F
14
+
15
+ import torch.nn as nn
16
+ import torchmetrics
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from torch_geometric.nn import MessagePassing
20
+ from torch_geometric.nn import global_mean_pool
21
+ from torch.nn import Sequential, Linear, BatchNorm1d, ReLU
22
+ from torch_scatter import scatter
23
+ from torch.nn import Module
24
+
25
+
26
+ import pinder.core as pinder
27
+ pinder.__version__
28
+ from torch_geometric.loader import DataLoader
29
+ from pinder.core.loader.dataset import get_geo_loader
30
+ from pinder.core import download_dataset
31
+ from pinder.core import get_index
32
+ from pinder.core import get_metadata
33
+ from pathlib import Path
34
+ import pandas as pd
35
+ from pinder.core import PinderSystem
36
+ import torch
37
+ from pinder.core.loader.dataset import PPIDataset
38
+ from pinder.core.loader.geodata import NodeRepresentation
39
+ import pickle
40
+ from pinder.core import get_index, PinderSystem
41
+ from torch_geometric.data import HeteroData
42
+ import os
43
+
44
+ from enum import Enum
45
+
46
+ import numpy as np
47
+ import torch
48
+ import lightning.pytorch as pl
49
+ from numpy.typing import NDArray
50
+ from torch_geometric.data import HeteroData
51
+
52
+ from pinder.core.index.system import PinderSystem
53
+ from pinder.core.loader.structure import Structure
54
+ from pinder.core.utils import constants as pc
55
+ from pinder.core.utils.log import setup_logger
56
+ from pinder.core.index.system import _align_monomers_with_mask
57
+ from pinder.core.loader.structure import Structure
58
+
59
+ import torch
60
+ import torch.nn as nn
61
+ import torch.nn.functional as F
62
+ from torch_geometric.nn import MessagePassing
63
+ from torch_geometric.nn import global_mean_pool
64
+ from torch.nn import Sequential, Linear, BatchNorm1d, ReLU
65
+ from torch_scatter import scatter
66
+ from torch.nn import Module
67
+ import time
68
+ from torch_geometric.nn import global_max_pool
69
+ import copy
70
+ import inspect
71
+ import warnings
72
+ from typing import Optional, Tuple, Union
73
+
74
+ import torch
75
+ from torch import Tensor
76
+
77
+ from torch_geometric.data import Data, Dataset, HeteroData
78
+ from torch_geometric.data.feature_store import FeatureStore
79
+ from torch_geometric.data.graph_store import GraphStore
80
+ from torch_geometric.loader import (
81
+ LinkLoader,
82
+ LinkNeighborLoader,
83
+ NeighborLoader,
84
+ NodeLoader,
85
+ )
86
+ from torch_geometric.loader.dataloader import DataLoader
87
+ from torch_geometric.loader.utils import get_edge_label_index, get_input_nodes
88
+ from torch_geometric.sampler import BaseSampler, NeighborSampler
89
+ from torch_geometric.typing import InputEdges, InputNodes
90
+
91
+ try:
92
+ from lightning.pytorch import LightningDataModule as PLLightningDataModule
93
+ no_pytorch_lightning = False
94
+ except (ImportError, ModuleNotFoundError):
95
+ PLLightningDataModule = object
96
+ no_pytorch_lightning = True
97
+
98
+ from lightning.pytorch.callbacks import ModelCheckpoint
99
+ from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
100
+ from lightning.pytorch.callbacks.early_stopping import EarlyStopping
101
+ from torch_geometric.data.lightning.datamodule import LightningDataset
102
+ from pytorch_lightning.loggers.wandb import WandbLogger
103
+ def get_system(system_id: str) -> PinderSystem:
104
+ return PinderSystem(system_id)
105
+ from Bio import PDB
106
+
107
+ def extract_coordinates_from_pdb(filename):
108
+ """
109
+ Extracts atom coordinates from a PDB file and returns them as a list of tuples.
110
+ Each tuple contains (x, y, z) coordinates of an atom.
111
+ """
112
+ parser = PDB.PDBParser(QUIET=True)
113
+ structure = parser.get_structure("structure", filename)
114
+
115
+ coordinates = []
116
+
117
+ # Loop through each model, chain, residue, and atom to collect coordinates
118
+ for model in structure:
119
+ for chain in model:
120
+ for residue in chain:
121
+ # Retrieve atoms and their coordinates
122
+ for atom in residue:
123
+ xyz = atom.coord # Coordinates are in a numpy array
124
+ # Append the coordinates (x, y, z) as a tuple
125
+ coordinates.append((xyz[0], xyz[1], xyz[2]))
126
+
127
+ return coordinates
128
+ log = setup_logger(__name__)
129
+
130
+ try:
131
+ from torch_cluster import knn_graph
132
+
133
+ torch_cluster_installed = True
134
+ except ImportError as e:
135
+ log.warning(
136
+ "torch-cluster is not installed!"
137
+ "Please install the appropriate library for your pytorch installation."
138
+ "See https://github.com/rusty1s/pytorch_cluster/issues/185 for background."
139
+ )
140
+ torch_cluster_installed = False
141
+
142
+
143
+ def structure2tensor(
144
+ atom_coordinates: NDArray[np.double] | None = None,
145
+ atom_types: NDArray[np.str_] | None = None,
146
+ element_types: NDArray[np.str_] | None = None,
147
+ residue_coordinates: NDArray[np.double] | None = None,
148
+ residue_ids: NDArray[np.int_] | None = None,
149
+ residue_types: NDArray[np.str_] | None = None,
150
+ chain_ids: NDArray[np.str_] | None = None,
151
+ dtype: torch.dtype = torch.float32,
152
+ ) -> dict[str, torch.Tensor]:
153
+ property_dict = {}
154
+ if atom_types is not None:
155
+ unknown_name_idx = max(pc.ALL_ATOM_POSNS.values()) + 1
156
+ types_array_at = np.zeros((len(atom_types), 1))
157
+ for i, name in enumerate(atom_types):
158
+ types_array_at[i] = pc.ALL_ATOM_POSNS.get(name, unknown_name_idx)
159
+ property_dict["atom_types"] = torch.tensor(types_array_at).type(dtype)
160
+ if element_types is not None:
161
+ types_array_ele = np.zeros((len(element_types), 1))
162
+ for i, name in enumerate(element_types):
163
+ types_array_ele[i] = pc.ELE2NUM.get(name, pc.ELE2NUM["other"])
164
+ property_dict["element_types"] = torch.tensor(types_array_ele).type(dtype)
165
+ if residue_types is not None:
166
+ unknown_name_idx = max(pc.AA_TO_INDEX.values()) + 1
167
+ types_array_res = np.zeros((len(residue_types), 1))
168
+ for i, name in enumerate(residue_types):
169
+ types_array_res[i] = pc.AA_TO_INDEX.get(name, unknown_name_idx)
170
+ property_dict["residue_types"] = torch.tensor(types_array_res).type(dtype)
171
+
172
+ if atom_coordinates is not None:
173
+ property_dict["atom_coordinates"] = torch.tensor(atom_coordinates, dtype=dtype)
174
+
175
+ if residue_coordinates is not None:
176
+ property_dict["residue_coordinates"] = torch.tensor(
177
+ residue_coordinates, dtype=dtype
178
+ )
179
+ if residue_ids is not None:
180
+ property_dict["residue_ids"] = torch.tensor(residue_ids, dtype=dtype)
181
+ if chain_ids is not None:
182
+ property_dict["chain_ids"] = torch.zeros(len(chain_ids), dtype=dtype)
183
+ property_dict["chain_ids"][chain_ids == "L"] = 1
184
+ return property_dict
185
+
186
+
187
+ class NodeRepresentation(Enum):
188
+ Surface = "surface"
189
+ Atom = "atom"
190
+ Residue = "residue"
191
+
192
+
193
+ class PairedPDB(HeteroData): # type: ignore
194
+ @classmethod
195
+ def from_tuple_system(
196
+ cls,
197
+
198
+ tupal: tuple = (Structure , Structure , Structure),
199
+
200
+ add_edges: bool = True,
201
+ k: int = 10,
202
+
203
+ ) -> PairedPDB:
204
+ return cls.from_structure_pair(
205
+
206
+ holo=tupal[0],
207
+ apo=tupal[1],
208
+ add_edges=add_edges,
209
+ k=k,
210
+ )
211
+
212
+ @classmethod
213
+ def from_structure_pair(
214
+ cls,
215
+
216
+ holo: Structure,
217
+ apo: Structure,
218
+
219
+ add_edges: bool = True,
220
+ k: int = 10,
221
+ ) -> PairedPDB:
222
+ graph = cls()
223
+ holo_calpha = holo.filter("atom_name", mask=["CA"])
224
+ apo_calpha = apo.filter("atom_name", mask=["CA"])
225
+ r_h = (holo.dataframe['chain_id'] == 'R').sum()
226
+ r_a = (apo.dataframe['chain_id'] == 'R').sum()
227
+
228
+ holo_r_props = structure2tensor(
229
+ atom_coordinates=holo.coords[:r_h],
230
+ atom_types=holo.atom_array.atom_name[:r_h],
231
+ element_types=holo.atom_array.element[:r_h],
232
+ residue_coordinates=holo_calpha.coords[:r_h],
233
+ residue_types=holo_calpha.atom_array.res_name[:r_h],
234
+ residue_ids=holo_calpha.atom_array.res_id[:r_h],
235
+ )
236
+ holo_l_props = structure2tensor(
237
+ atom_coordinates=holo.coords[r_h:],
238
+
239
+ atom_types=holo.atom_array.atom_name[r_h:],
240
+ element_types=holo.atom_array.element[r_h:],
241
+ residue_coordinates=holo_calpha.coords[r_h:],
242
+ residue_types=holo_calpha.atom_array.res_name[r_h:],
243
+ residue_ids=holo_calpha.atom_array.res_id[r_h:],
244
+ )
245
+ apo_r_props = structure2tensor(
246
+ atom_coordinates=apo.coords[:r_a],
247
+ atom_types=apo.atom_array.atom_name[:r_a],
248
+ element_types=apo.atom_array.element[:r_a],
249
+ residue_coordinates=apo_calpha.coords[:r_a],
250
+ residue_types=apo_calpha.atom_array.res_name[:r_a],
251
+ residue_ids=apo_calpha.atom_array.res_id[:r_a],
252
+ )
253
+ apo_l_props = structure2tensor(
254
+ atom_coordinates=apo.coords[r_a:],
255
+ atom_types=apo.atom_array.atom_name[r_a:],
256
+ element_types=apo.atom_array.element[r_a:],
257
+ residue_coordinates=apo_calpha.coords[r_a:],
258
+ residue_types=apo_calpha.atom_array.res_name[r_a:],
259
+ residue_ids=apo_calpha.atom_array.res_id[r_a:],
260
+ )
261
+
262
+
263
+
264
+ graph["ligand"].x = apo_l_props["atom_types"]
265
+ graph["ligand"].pos = apo_l_props["atom_coordinates"]
266
+ graph["receptor"].x = apo_r_props["atom_types"]
267
+ graph["receptor"].pos = apo_r_props["atom_coordinates"]
268
+ graph["ligand"].y = holo_l_props["atom_coordinates"]
269
+ # graph["ligand"].pos = holo_l_props["atom_coordinates"]
270
+ graph["receptor"].y = holo_r_props["atom_coordinates"]
271
+ # graph["receptor"].pos = holo_r_props["atom_coordinates"]
272
+ if add_edges and torch_cluster_installed:
273
+ graph["ligand"].edge_index = knn_graph(
274
+ graph["ligand"].pos, k=k
275
+ )
276
+ graph["receptor"].edge_index = knn_graph(
277
+ graph["receptor"].pos, k=k
278
+ )
279
+ # graph["ligand"].edge_index = knn_graph(
280
+ # graph["ligand"].pos, k=k
281
+ # )
282
+ # graph["receptor"].edge_index = knn_graph(
283
+ # graph["receptor"].pos, k=k
284
+ # )
285
+
286
+ return graph
287
+ def create_graph(pdb1, pdb2, pdb3='/home/sukanya/iitm_bisect_pinder_submission/test_out.pdb', k=5):
288
+ """
289
+ Create a heterogeneous graph from two PDB files, with the ligand and receptor
290
+ as separate nodes, and their respective features and edges.
291
+
292
+ Args:
293
+ pdb1 (str): PDB file path for ligand.
294
+ pdb2 (str): PDB file path for receptor.
295
+ coords3 (list): List of coordinates used for `y` values (e.g., binding affinity, etc.).
296
+ k (int): Number of nearest neighbors for constructing the knn graph.
297
+
298
+ Returns:
299
+ HeteroData: A PyG HeteroData object containing ligand and receptor data.
300
+ """
301
+ # Extract coordinates from PDB files
302
+ coords1 = torch.tensor(extract_coordinates_from_pdb(pdb1),dtype=torch.float)
303
+ coords2 = torch.tensor(extract_coordinates_from_pdb(pdb2),dtype=torch.float)
304
+ coords3 = torch.tensor(extract_coordinates_from_pdb(pdb3),dtype=torch.float)
305
+ # Create the HeteroData object
306
+ data = HeteroData()
307
+
308
+ # Define ligand node features
309
+ data["ligand"].x = torch.tensor(coords1, dtype=torch.float)
310
+ data["ligand"].pos = coords1
311
+ data["ligand"].y = torch.tensor(coords3[:len(coords1)], dtype=torch.float)
312
+
313
+ # Define receptor node features
314
+ data["receptor"].x = torch.tensor(coords2, dtype=torch.float)
315
+ data["receptor"].pos = coords2
316
+ data["receptor"].y = torch.tensor(coords3[len(coords1):], dtype=torch.float)
317
+
318
+ # Construct k-NN graph for ligand
319
+ ligand_edge_index = knn_graph(data["ligand"].pos, k=k)
320
+ data["ligand"].edge_index = ligand_edge_index
321
+
322
+ # Construct k-NN graph for receptor
323
+ receptor_edge_index = knn_graph(data["receptor"].pos, k=k)
324
+ data["receptor"].edge_index = receptor_edge_index
325
+
326
+ # Convert edge index to SparseTensor for ligand
327
+ data["ligand", "ligand"].edge_index = ligand_edge_index
328
+
329
+ # Convert edge index to SparseTensor for receptor
330
+ data["receptor", "receptor"].edge_index = receptor_edge_index
331
+
332
+ return data
333
+
334
+
335
+ def tensor_to_pdb(tensor, pdb_filename="test_out.pdb", chain_id="L"):
336
+ """
337
+ Convert a tensor of coordinates to PDB format, handling an extra dimension if present.
338
+
339
+ Args:
340
+ tensor (torch.Tensor): Tensor of shape (1, N, 3) or (N, 3), where each entry is
341
+ (x, y, z) coordinates for atoms.
342
+ pdb_filename (str): Output filename for the PDB file.
343
+ chain_id (str): Chain identifier for the PDB structure.
344
+ """
345
+ # Remove the first dimension if it’s 1 (e.g., shape is (1, N, 3))
346
+ if tensor.dim() == 3 and tensor.size(0) == 1:
347
+ tensor = tensor.squeeze(0)
348
+
349
+ # Open the PDB file for writing
350
+ with open(pdb_filename, 'w') as pdb_file:
351
+ pdb_file.write("REMARK Generated by tensor_to_pdb function\n")
352
+
353
+ # Iterate over each atom in the tensor
354
+ for atom_idx, (x, y, z) in enumerate(tensor):
355
+ pdb_line = (
356
+ f"ATOM {atom_idx + 1:5d} C LIG {chain_id} {atom_idx + 1:4d} "
357
+ f"{x.item():8.3f}{y.item():8.3f}{z.item():8.3f} 1.00 0.00 C\n"
358
+ )
359
+ pdb_file.write(pdb_line)
360
+
361
+ pdb_file.write("END\n")
362
+ class MPNNLayer(MessagePassing):
363
+ def __init__(self, emb_dim=64, edge_dim=4, aggr='add'):
364
+ """Message Passing Neural Network Layer
365
+
366
+ Args:
367
+ emb_dim: (int) - hidden dimension d
368
+ edge_dim: (int) - edge feature dimension d_e
369
+ aggr: (str) - aggregation function \oplus (sum/mean/max)
370
+ """
371
+ # Set the aggregation function
372
+ super().__init__(aggr=aggr)
373
+
374
+ self.emb_dim = emb_dim
375
+ self.edge_dim = edge_dim
376
+
377
+ # MLP \psi for computing messages m_ij
378
+ # Implemented as a stack of Linear->BN->ReLU->Linear->BN->ReLU
379
+ # dims: (2d + d_e) -> d
380
+ self.mlp_msg = Sequential(
381
+ Linear(2*emb_dim + edge_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(),
382
+ Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU()
383
+ )
384
+
385
+ # MLP \phi for computing updated node features h_i^{l+1}
386
+ # Implemented as a stack of Linear->BN->ReLU->Linear->BN->ReLU
387
+ # dims: 2d -> d
388
+ self.mlp_upd = Sequential(
389
+ Linear(2*emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(),
390
+ Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU()
391
+ )
392
+
393
+ def forward(self, h, edge_index, edge_attr):
394
+ """
395
+ The forward pass updates node features h via one round of message passing.
396
+
397
+ As our MPNNLayer class inherits from the PyG MessagePassing parent class,
398
+ we simply need to call the propagate() function which starts the
399
+ message passing procedure: message() -> aggregate() -> update().
400
+
401
+ The MessagePassing class handles most of the logic for the implementation.
402
+ To build custom GNNs, we only need to define our own message(),
403
+ aggregate(), and update() functions (defined subsequently).
404
+
405
+ Args:
406
+ h: (n, d) - initial node features
407
+ edge_index: (e, 2) - pairs of edges (i, j)
408
+ edge_attr: (e, d_e) - edge features
409
+
410
+ Returns:
411
+ out: (n, d) - updated node features
412
+ """
413
+ out = self.propagate(edge_index, h=h, edge_attr=edge_attr)
414
+ return out
415
+
416
+ def message(self, h_i, h_j, edge_attr):
417
+ """Step (1) Message
418
+
419
+ The message() function constructs messages from source nodes j
420
+ to destination nodes i for each edge (i, j) in edge_index.
421
+
422
+ The arguments can be a bit tricky to understand: message() can take
423
+ any arguments that were initially passed to propagate. Additionally,
424
+ we can differentiate destination nodes and source nodes by appending
425
+ _i or _j to the variable name, e.g. for the node features h, we
426
+ can use h_i and h_j.
427
+
428
+ This part is critical to understand as the message() function
429
+ constructs messages for each edge in the graph. The indexing of the
430
+ original node features h (or other node variables) is handled under
431
+ the hood by PyG.
432
+
433
+ Args:
434
+ h_i: (e, d) - destination node features
435
+ h_j: (e, d) - source node features
436
+ edge_attr: (e, d_e) - edge features
437
+
438
+ Returns:
439
+ msg: (e, d) - messages m_ij passed through MLP \psi
440
+ """
441
+ msg = torch.cat([h_i, h_j, edge_attr], dim=-1)
442
+ return self.mlp_msg(msg)
443
+
444
+ def aggregate(self, inputs, index):
445
+ """Step (2) Aggregate
446
+
447
+ The aggregate function aggregates the messages from neighboring nodes,
448
+ according to the chosen aggregation function ('sum' by default).
449
+
450
+ Args:
451
+ inputs: (e, d) - messages m_ij from destination to source nodes
452
+ index: (e, 1) - list of source nodes for each edge/message in input
453
+
454
+ Returns:
455
+ aggr_out: (n, d) - aggregated messages m_i
456
+ """
457
+ return scatter(inputs, index, dim=self.node_dim, reduce=self.aggr)
458
+
459
+ def update(self, aggr_out, h):
460
+ """
461
+ Step (3) Update
462
 
463
+ The update() function computes the final node features by combining the
464
+ aggregated messages with the initial node features.
465
 
466
+ update() takes the first argument aggr_out, the result of aggregate(),
467
+ as well as any optional arguments that were initially passed to
468
+ propagate(). E.g. in this case, we additionally pass h.
469
 
470
+ Args:
471
+ aggr_out: (n, d) - aggregated messages m_i
472
+ h: (n, d) - initial node features
473
 
474
+ Returns:
475
+ upd_out: (n, d) - updated node features passed through MLP \phi
476
+ """
477
+ upd_out = torch.cat([h, aggr_out], dim=-1)
478
+ return self.mlp_upd(upd_out)
479
+
480
+ def __repr__(self) -> str:
481
+ return (f'{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})')
482
+ class MPNNModel(Module):
483
+ def __init__(self, num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1):
484
+ """Message Passing Neural Network model for graph property prediction
485
+
486
+ Args:
487
+ num_layers: (int) - number of message passing layers L
488
+ emb_dim: (int) - hidden dimension d
489
+ in_dim: (int) - initial node feature dimension d_n
490
+ edge_dim: (int) - edge feature dimension d_e
491
+ out_dim: (int) - output dimension (fixed to 1)
492
+ """
493
+ super().__init__()
494
+
495
+ # Linear projection for initial node features
496
+ # dim: d_n -> d
497
+ self.lin_in = Linear(in_dim, emb_dim)
498
+
499
+ # Stack of MPNN layers
500
+ self.convs = torch.nn.ModuleList()
501
+ for layer in range(num_layers):
502
+ self.convs.append(MPNNLayer(emb_dim, edge_dim, aggr='add'))
503
+
504
+ # Global pooling/readout function R (mean pooling)
505
+ # PyG handles the underlying logic via global_mean_pool()
506
+ self.pool = global_mean_pool
507
+
508
+ # Linear prediction head
509
+ # dim: d -> out_dim
510
+ self.lin_pred = Linear(emb_dim, out_dim)
511
+
512
+ def forward(self, data):
513
+ """
514
+ Args:
515
+ data: (PyG.Data) - batch of PyG graphs
516
+
517
+ Returns:
518
+ out: (batch_size, out_dim) - prediction for each graph
519
+ """
520
+ h = self.lin_in(data.x) # (n, d_n) -> (n, d)
521
+
522
+ for conv in self.convs:
523
+ h = h + conv(h, data.edge_index, data.edge_attr) # (n, d) -> (n, d)
524
+ # Note that we add a residual connection after each MPNN layer
525
+
526
+ h_graph = self.pool(h, data.batch) # (n, d) -> (batch_size, d)
527
+
528
+ out = self.lin_pred(h_graph) # (batch_size, d) -> (batch_size, 1)
529
+
530
+ return out.view(-1)
531
+
532
+
533
+ class EquivariantMPNNLayer(MessagePassing):
534
+ def __init__(self, emb_dim=64, aggr='add'):
535
+ """Message Passing Neural Network Layer
536
+
537
+ This layer is equivariant to 3D rotations and translations.
538
+
539
+ Args:
540
+ emb_dim: (int) - hidden dimension d
541
+ edge_dim: (int) - edge feature dimension d_e
542
+ aggr: (str) - aggregation function \oplus (sum/mean/max)
543
+ """
544
+ # Set the aggregation function
545
+ super().__init__(aggr=aggr)
546
+
547
+ self.emb_dim = emb_dim
548
+
549
+
550
+ #
551
+ self.mlp_msg = Sequential(
552
+ Linear(2 * emb_dim + 1, emb_dim),
553
+ BatchNorm1d(emb_dim),
554
+ ReLU(),
555
+ Linear(emb_dim, emb_dim),
556
+ BatchNorm1d(emb_dim),
557
+ ReLU()
558
+ )
559
+
560
+
561
+ self.mlp_pos = Sequential(
562
+ Linear(emb_dim, emb_dim),
563
+ BatchNorm1d(emb_dim),
564
+ ReLU(),
565
+ Linear(emb_dim,1)
566
+ ) # MLP \psi
567
+ self.mlp_upd = Sequential(
568
+ Linear(2*emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(), Linear(emb_dim,emb_dim), BatchNorm1d(emb_dim), ReLU()
569
+ ) # MLP \phi
570
+ # ===========================================
571
+
572
+ def forward(self, h, pos, edge_index):
573
+ """
574
+ The forward pass updates node features h via one round of message passing.
575
+
576
+ Args:
577
+ h: (n, d) - initial node features
578
+ pos: (n, 3) - initial node coordinates
579
+ edge_index: (e, 2) - pairs of edges (i, j)
580
+ edge_attr: (e, d_e) - edge features
581
+
582
+ Returns:
583
+ out: [(n, d),(n,3)] - updated node features
584
+ """
585
+
586
+ #
587
+ out = self.propagate(edge_index=edge_index, h=h, pos=pos)
588
+ return out
589
+ # ==========================================
590
+
591
+
592
+ #
593
+ def message(self, h_i,h_j,pos_i,pos_j):
594
+ # Compute distance between nodes i and j (Euclidean distance)
595
+ #distance_ij = torch.norm(pos_i - pos_j, dim=-1, keepdim=True) # (e, 1)
596
+ pos_diff = pos_i - pos_j
597
+ dists = torch.norm(pos_diff,dim=-1).unsqueeze(1)
598
+
599
+ # Concatenate node features, edge features, and distance
600
+ msg = torch.cat([h_i , h_j, dists], dim=-1)
601
+ msg = self.mlp_msg(msg)
602
+ pos_diff = pos_diff * self.mlp_pos(msg) # (e, 2d + d_e + 1)
603
+
604
+
605
+ # (e, d)
606
+ return msg , pos_diff
607
+ # ...
608
+ #
609
+ def aggregate(self, inputs, index):
610
+ """The aggregate function aggregates the messages from neighboring nodes,
611
+ according to the chosen aggregation function ('sum' by default).
612
+
613
+ Args:
614
+ inputs: (e, d) - messages m_ij from destination to source nodes
615
+ index: (e, 1) - list of source nodes for each edge/message in input
616
+
617
+ Returns:
618
+ aggr_out: (n, d) - aggregated messages m_i
619
+ """
620
+ msgs , pos_diffs = inputs
621
+
622
+ msg_aggr = scatter(msgs, index , dim = self.node_dim , reduce = self.aggr)
623
+
624
+ pos_aggr = scatter(pos_diffs, index, dim = self.node_dim , reduce = "mean")
625
+
626
+
627
+ return msg_aggr , pos_aggr
628
+
629
+ def update(self, aggr_out, h , pos):
630
+ msg_aggr , pos_aggr = aggr_out
631
+
632
+ upd_out = self.mlp_upd(torch.cat((h, msg_aggr), dim=-1))
633
+
634
+ upd_pos = pos + pos_aggr
635
+
636
+ return upd_out , upd_pos
637
+
638
+
639
+ def __repr__(self) -> str:
640
+ return (f'{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})')
641
+
642
+ class FinalMPNNModel(MPNNModel):
643
+ def __init__(self, num_layers=4, emb_dim=64, in_dim=3, num_heads = 2):
644
+ """Message Passing Neural Network model for graph property prediction
645
+
646
+ This model uses both node features and coordinates as inputs, and
647
+ is invariant to 3D rotations and translations (the constituent MPNN layers
648
+ are equivariant to 3D rotations and translations).
649
+
650
+ Args:
651
+ num_layers: (int) - number of message passing layers L
652
+ emb_dim: (int) - hidden dimension d
653
+ in_dim: (int) - initial node feature dimension d_n
654
+ edge_dim: (int) - edge feature dimension d_e
655
+ out_dim: (int) - output dimension (fixed to 1)
656
+ """
657
+ super().__init__()
658
+
659
+ # Linear projection for initial node features
660
+ # dim: d_n -> d
661
+ self.lin_in = Linear(in_dim, emb_dim)
662
+ self.equiv_layer = EquivariantMPNNLayer(emb_dim=emb_dim)
663
+ # Stack of MPNN layers
664
+ self.convs = torch.nn.ModuleList()
665
+ for layer in range(num_layers):
666
+ self.convs.append(EquivariantMPNNLayer(emb_dim, aggr='add'))
667
+
668
+
669
+ self.cross_attention = nn.MultiheadAttention(emb_dim, num_heads, batch_first=True)
670
+ self.fc_rotation = nn.Linear(emb_dim, 9)
671
+ self.fc_translation = nn.Linear(emb_dim, 3)
672
+ # Global pooling/readout function R (mean pooling)
673
+ # PyG handles the underlying logic via global_mean_pool()
674
+ # self.pool = global_mean_pool
675
+
676
+ def naive_single(self, receptor, ligand , receptor_edge_index , ligand_edge_index):
677
+ """
678
+ Processes a single receptor-ligand pair.
679
+
680
+ Args:
681
+ receptor: Tensor of shape (1, num_receptor_atoms, 3) (receptor coordinates)
682
+ ligand: Tensor of shape (1, num_ligand_atoms, 3) (ligand coordinates)
683
+
684
+ Returns:
685
+ rotation_matrix: Tensor of shape (1, 3, 3) predicted rotation matrix for the ligand.
686
+ translation_vector: Tensor of shape (1, 3) predicted translation vector for the ligand.
687
+
688
+ """
689
+
690
+
691
+ # h_receptor = receptor # Initial node features for the receptor
692
+ # h_ligand = ligand
693
+ h_receptor = self.lin_in(receptor)
694
+ h_ligand = self.lin_in(ligand) # Initial node features for the ligand
695
+ pos_receptor = receptor # Initial positions
696
+ pos_ligand = ligand
697
+
698
+ for layer in self.convs:
699
+ # Apply the equivariant message-passing layer for both receptor and ligand
700
+ h_receptor, pos_receptor = layer(h_receptor, pos_receptor,receptor_edge_index )
701
+ h_ligand, pos_ligand = layer(h_ligand, pos_ligand, ligand_edge_index)
702
+ # print("Shape of h_receptor:", h_receptor.shape)
703
+ # print("Shape of h_ligand:", h_ligand.shape)
704
+ # Pass the layer outputs through MLPs for embeddings
705
+ emb_features_receptor = h_receptor
706
+ emb_features_ligand = h_ligand
707
+
708
+ attn_output, _ = self.cross_attention(emb_features_receptor, emb_features_ligand, emb_features_ligand)
709
+ rotation_matrix = self.fc_rotation(attn_output.mean(dim=0))
710
+ rotation_matrix = rotation_matrix.view(-1, 3, 3)
711
+ translation_vector = self.fc_translation(attn_output.mean(dim=0))
712
+ return rotation_matrix, translation_vector
713
+
714
+
715
+
716
+
717
+ def forward(self, data):
718
+ """
719
+ The main forward pass of the model.
720
+
721
+ Args:
722
+ batch: Same as in forward_rot_trans.
723
+
724
+ Returns:
725
+ transformed_ligands: List of tensors, each of shape (1, num_ligand_atoms, 3)
726
+ representing the transformed ligand coordinates after applying the predicted
727
+ rotation and translation.
728
+ """
729
+ receptor = data['receptor']['pos']
730
+ ligand = data['ligand']['pos']
731
+ receptor_edge_index = data['receptor']['edge_index']
732
+ ligand_edge_index = data['ligand']['edge_index']
733
+
734
+ rotation_matrix, translation_vector = self.naive_single(receptor, ligand,receptor_edge_index , ligand_edge_index)
735
+ # for i in range(len(ligands)):
736
+ # ligands[i] = ligands[i] @ rotation_matrix[i] + translation_vector[i]
737
+ ligands = data['ligand']['pos'] @ rotation_matrix + translation_vector
738
+ return ligands
739
+
740
+ class FinalMPNNModelight(pl.LightningModule):
741
+ def __init__(self, num_layers=4, emb_dim=32, in_dim=3, num_heads=1, lr=1e-4):
742
+ super().__init__()
743
+
744
+ self.lin_in = nn.Linear(in_dim, emb_dim)
745
+ self.convs = nn.ModuleList([EquivariantMPNNLayer(emb_dim, aggr='add') for _ in range(num_layers)])
746
+ self.cross_attention = nn.MultiheadAttention(emb_dim, num_heads, batch_first=True)
747
+ self.fc_rotation = nn.Linear(emb_dim, 9)
748
+ self.fc_translation = nn.Linear(emb_dim, 3)
749
+ self.lr = lr
750
+
751
+
752
+ def naive_single(self, receptor, ligand, receptor_edge_index, ligand_edge_index):
753
+ h_receptor = self.lin_in(receptor)
754
+ h_ligand = self.lin_in(ligand)
755
+ pos_receptor, pos_ligand = receptor, ligand
756
+
757
+ for layer in self.convs:
758
+ h_receptor, pos_receptor = layer(h_receptor, pos_receptor, receptor_edge_index)
759
+ h_ligand, pos_ligand = layer(h_ligand, pos_ligand, ligand_edge_index)
760
+
761
+ attn_output, _ = self.cross_attention(h_receptor, h_ligand, h_ligand)
762
+ rotation_matrix = self.fc_rotation(attn_output.mean(dim=0)).view(-1, 3, 3)
763
+ translation_vector = self.fc_translation(attn_output.mean(dim=0))
764
+ return rotation_matrix, translation_vector
765
+
766
+ def forward(self, data):
767
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
768
+ receptor = data['receptor']['pos'].to(device)
769
+ ligand = data['ligand']['pos'].to(device)
770
+ receptor_edge_index = data['receptor', 'receptor']['edge_index'].to(device)
771
+ ligand_edge_index = data['ligand', 'ligand']['edge_index'].to(device)
772
+
773
+ rotation_matrix, translation_vector = self.naive_single(receptor, ligand, receptor_edge_index, ligand_edge_index)
774
+ transformed_ligand = torch.matmul(ligand ,rotation_matrix) + translation_vector
775
+ return transformed_ligand
776
+
777
+
778
+ def training_step(self, batch, batch_idx):
779
+ ligand_pred = self(batch)
780
+ ligand_true = batch['ligand']['y']
781
+ loss = F.mse_loss(ligand_pred.squeeze(0), ligand_true)
782
+ self.log('train_loss', loss, batch_size=8)
783
+ return loss
784
+
785
+
786
+ def validation_step(self, batch, batch_idx):
787
+ ligand_pred = self(batch)
788
+ ligand_true = batch['ligand']['y']
789
+ loss = F.l1_loss(ligand_pred.squeeze(0), ligand_true)
790
+
791
+ self.log('val_loss', loss, prog_bar=True, batch_size=8)
792
+
793
+ return loss
794
+
795
+
796
+ def test_step(self, batch, batch_idx):
797
+ ligand_pred = self(batch)
798
+ ligand_true = batch['ligand']['y']
799
+ loss = F.l1_loss(ligand_pred.squeeze(0), ligand_true)
800
+ self.log('test_loss', loss, prog_bar=True, batch_size=8)
801
+ return loss
802
+
803
+ def configure_optimizers(self):
804
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
805
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
806
+ optimizer, mode="min", factor=0.1, patience=5
807
+ )
808
+ return {
809
+ "optimizer": optimizer,
810
+ "lr_scheduler": {
811
+ "scheduler": scheduler,
812
+ "monitor": "val_loss", # Monitor validation loss to adjust the learning rate
813
+ },
814
+ }
815
+
816
+ model_path = "/home/sukanya/iitm_bisect_pinder_submission/EquiMPNN-epoch=413-val_loss=9.25-val_acc=0.00.ckpt"
817
+ model = FinalMPNNModelight.load_from_checkpoint(model_path)
818
+ trainer = pl.Trainer(
819
+
820
+
821
+ fast_dev_run=False,
822
+ accelerator="gpu" if torch.cuda.is_available() else "cpu",
823
+ precision="bf16-mixed",
824
+
825
+ devices=1,
826
+ )
827
+ model.eval()
828
  def predict (input_seq_1, input_msa_1, input_protein_1, input_seq_2,input_msa_2, input_protein_2):
829
  start_time = time.time()
830
+ data = create_graph(input_protein_1, input_protein_2, '/home/sukanya/iitm_bisect_pinder_submission/test_out.pdb', k=10)
831
+
832
+ with torch.no_grad():
833
+ output = model(data)
834
+ file = tensor_to_pdb(output)
835
  # return an output pdb file with the protein and two chains A and B.
836
  # also return a JSON with any metrics you want to report
837
  metrics = {"mean_plddt": 80, "binding_affinity": 2}
838
  end_time = time.time()
839
  run_time = end_time - start_time
840
+ return file,json.dumps(metrics), run_time
841
 
842
  with gr.Blocks() as app:
843
 
844
  gr.Markdown("# Template for inference")
845
 
846
+ gr.Markdown("EquiMPNN MOdel")
847
  with gr.Row():
848
  with gr.Column():
849
  input_seq_1 = gr.Textbox(lines=3, label="Input Protein 1 sequence (FASTA)")
 
914
  btn.click(predict, inputs=[input_seq_1, input_msa_1, input_protein_1, input_seq_2, input_msa_2, input_protein_2], outputs=[out, metrics, run_time])
915
 
916
  app.launch()
917
+
lightning_logs/version_0/hparams.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ emb_dim: 32
2
+ in_dim: 3
3
+ lr: 0.0001
4
+ num_heads: 1
5
+ num_layers: 4
requirements.txt CHANGED
@@ -1,2 +1,216 @@
1
- gradio
2
- gradio_molecule3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.1.0
2
+ aiofiles==23.2.aiohappyeyeballs==2.4.3
3
+ aiohttp==3.10.10
4
+ aiosignal==1.3.1
5
+ annotated-types==0.7.0
6
+ anyio==4.6.2.post1
7
+ argon2-cffi==23.1.0
8
+ argon2-cffi-bindings==21.2.0
9
+ arrow==1.3.0
10
+ asttokens==2.4.1
11
+ async-lru==2.0.4
12
+ async-timeout==4.0.3
13
+ attrs==24.2.0
14
+ babel==2.16.0
15
+ beautifulsoup4==4.12.3
16
+ bio==1.7.1
17
+ biopython==1.84
18
+ biothings-client==0.3.1
19
+ biotite==0.41.2
20
+ bleach==6.2.0
21
+ cachetools==5.5.0
22
+ certifi==2024.8.30
23
+ cffi==1.17.1
24
+ charset-normalizer==3.4.0
25
+ click==8.1.7
26
+ comm==0.2.2
27
+ debugpy==1.8.7
28
+ decorator==5.1.1
29
+ defusedxml==0.7.1
30
+ docker-pycreds==0.4.0
31
+ exceptiongroup==1.2.2
32
+ executing==2.1.0
33
+ fastapi==0.115.4
34
+ fastjsonschema==2.20.0
35
+ fastpdb==1.3.1
36
+ ffmpy==0.4.0
37
+ filelock==3.16.1
38
+ fqdn==1.5.1
39
+ frozenlist==1.5.0
40
+ fsspec==2024.10.0
41
+ gcsfs==2024.10.0
42
+ gitdb==4.0.11
43
+ GitPython==3.1.43
44
+ google-api-core==2.22.0
45
+ google-auth==2.35.0
46
+ google-auth-oauthlib==1.2.1
47
+ google-cloud-core==2.4.1
48
+ google-cloud-storage==2.18.2
49
+ google-crc32c==1.6.0
50
+ google-resumable-media==2.7.2
51
+ googleapis-common-protos==1.65.0
52
+ gprofiler-official==1.0.0
53
+ gradio==5.5.0
54
+ gradio_client==1.4.2
55
+ gradio_molecule3d==0.0.6
56
+ grpcio==1.67.1
57
+ h11==0.14.0
58
+ httpcore==1.0.6
59
+ httpx==0.27.2
60
+ huggingface-hub==0.26.2
61
+ idna==3.10
62
+ ipykernel==6.29.5
63
+ ipython==8.29.0
64
+ ipywidgets==8.1.5
65
+ isoduration==20.11.0
66
+ jedi==0.19.1
67
+ Jinja2==3.1.4
68
+ joblib==1.4.2
69
+ json5==0.9.25
70
+ jsonpointer==3.0.0
71
+ jsonschema==4.23.0
72
+ jsonschema-specifications==2024.10.1
73
+ jupyter-events==0.10.0
74
+ jupyter-lsp==2.2.5
75
+ jupyter_client==8.6.3
76
+ jupyter_core==5.7.2
77
+ jupyter_server==2.14.2
78
+ jupyter_server_terminals==0.5.3
79
+ jupyterlab==4.3.0
80
+ jupyterlab_pygments==0.3.0
81
+ jupyterlab_server==2.27.3
82
+ jupyterlab_widgets==3.0.13
83
+ lightning==2.4.0
84
+ lightning-utilities==0.11.8
85
+ Markdown==3.7
86
+ markdown-it-py==3.0.0
87
+ MarkupSafe==2.1.5
88
+ matplotlib-inline==0.1.7
89
+ mdurl==0.1.2
90
+ mistune==3.0.2
91
+ mpmath==1.3.0
92
+ msgpack==1.1.0
93
+ multidict==6.1.0
94
+ mygene==3.2.2
95
+ nbclient==0.10.0
96
+ nbconvert==7.16.4
97
+ nbformat==5.10.4
98
+ nest-asyncio==1.6.0
99
+ networkx==3.4.2
100
+ notebook_shim==0.2.4
101
+ numpy==1.26.4
102
+ nvidia-cublas-cu12==12.4.5.8
103
+ nvidia-cuda-cupti-cu12==12.4.127
104
+ nvidia-cuda-nvrtc-cu12==12.4.127
105
+ nvidia-cuda-runtime-cu12==12.4.127
106
+ nvidia-cudnn-cu12==9.1.0.70
107
+ nvidia-cufft-cu12==11.2.1.3
108
+ nvidia-curand-cu12==10.3.5.147
109
+ nvidia-cusolver-cu12==11.6.1.9
110
+ nvidia-cusparse-cu12==12.3.1.170
111
+ nvidia-nccl-cu12==2.21.5
112
+ nvidia-nvjitlink-cu12==12.4.127
113
+ nvidia-nvtx-cu12==12.4.127
114
+ oauthlib==3.2.2
115
+ orjson==3.10.11
116
+ overrides==7.7.0
117
+ packaging==24.1
118
+ pandas==2.2.3
119
+ pandocfilters==1.5.1
120
+ parso==0.8.4
121
+ pexpect==4.9.0
122
+ pillow==11.0.0
123
+ pinder==0.4.1
124
+ platformdirs==4.3.6
125
+ plotly==5.24.1
126
+ pooch==1.8.2
127
+ prometheus_client==0.21.0
128
+ prompt_toolkit==3.0.48
129
+ propcache==0.2.0
130
+ proto-plus==1.25.0
131
+ protobuf==5.28.3
132
+ psutil==6.1.0
133
+ ptyprocess==0.7.0
134
+ pure_eval==0.2.3
135
+ pyarrow==18.0.0
136
+ pyasn1==0.6.1
137
+ pyasn1_modules==0.4.1
138
+ pycparser==2.22
139
+ pydantic==2.9.2
140
+ pydantic_core==2.23.4
141
+ pydub==0.25.1
142
+ pyg-lib==0.4.0+pt24cu124
143
+ Pygments==2.18.0
144
+ pyparsing==3.2.0
145
+ python-dateutil==2.9.0.post0
146
+ python-dotenv==1.0.1
147
+ python-json-logger==2.0.7
148
+ python-multipart==0.0.12
149
+ pytorch-lightning==2.4.0
150
+ pytz==2024.2
151
+ PyYAML==6.0.2
152
+ pyzmq==26.2.0
153
+ referencing==0.35.1
154
+ requests==2.32.3
155
+ requests-oauthlib==2.0.0
156
+ rfc3339-validator==0.1.4
157
+ rfc3986-validator==0.1.1
158
+ rich==13.9.4
159
+ rootutils==1.0.7
160
+ rpds-py==0.20.1
161
+ rsa==4.9
162
+ ruff==0.7.2
163
+ safehttpx==0.1.1
164
+ scikit-learn==1.5.2
165
+ scipy==1.14.1
166
+ semantic-version==2.10.0
167
+ Send2Trash==1.8.3
168
+ sentry-sdk==2.18.0
169
+ setproctitle==1.3.3
170
+ shellingham==1.5.4
171
+ six==1.16.0
172
+ smmap==5.0.1
173
+ sniffio==1.3.1
174
+ soupsieve==2.6
175
+ stack-data==0.6.3
176
+ starlette==0.41.2
177
+ sympy==1.13.1
178
+ tabulate==0.9.0
179
+ tenacity==9.0.0
180
+ tensorboard==2.18.0
181
+ tensorboard-data-server==0.7.2
182
+ tensorboardX==2.6.2.2
183
+ terminado==0.18.1
184
+ threadpoolctl==3.5.0
185
+ tinycss2==1.4.0
186
+ tomli==2.0.2
187
+ tomlkit==0.12.0
188
+ torch==2.5.1
189
+ torch-geometric==2.6.1
190
+ torch_cluster==1.6.3+pt24cu124
191
+ torch_scatter==2.1.2+pt24cu124
192
+ torch_sparse==0.6.18+pt24cu124
193
+ torch_spline_conv==1.2.2+pt24cu124
194
+ torchmetrics==1.5.1
195
+ torchtyping==0.1.5
196
+ tornado==6.4.1
197
+ tqdm==4.66.6
198
+ traitlets==5.14.3
199
+ triton==3.1.0
200
+ typeguard==2.13.3
201
+ typer==0.13.0
202
+ types-python-dateutil==2.9.0.20241003
203
+ typing_extensions==4.12.2
204
+ tzdata==2024.2
205
+ uri-template==1.3.0
206
+ urllib3==2.2.3
207
+ uvicorn==0.32.0
208
+ wandb==0.18.5
209
+ wcwidth==0.2.13
210
+ webcolors==24.8.0
211
+ webencodings==0.5.1
212
+ websocket-client==1.8.0
213
+ websockets==12.0
214
+ Werkzeug==3.1.2
215
+ widgetsnbextension==4.0.13
216
+ yarl==1.17.1