habdine commited on
Commit
1fcde9f
1 Parent(s): b010ab3

Upload code

Browse files
Files changed (8) hide show
  1. configuration_prot2text.py +74 -0
  2. conversion.py +473 -0
  3. graphs.py +1144 -0
  4. modeling_prot2text.py +398 -0
  5. pdb2graph.py +178 -0
  6. utils.py +745 -0
  7. utils_convert.py +82 -0
  8. utils_dataset.py +60 -0
configuration_prot2text.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Prot2Text configuration"""
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+ from transformers import AutoConfig
5
+ from transformers.utils import logging
6
+
7
+
8
+ logger = logging.get_logger(__name__)
9
+
10
+
11
class Prot2TextConfig(PretrainedConfig):
    """Configuration container for Prot2Text models.

    In addition to the Prot2Text-specific settings, the instance carries two
    nested configurations stored as plain dictionaries: ``gpt_config`` and
    ``esm_config``. When one of them is not supplied it is built with
    ``AutoConfig.from_pretrained`` from ``gpt_model_name`` /
    ``esm_model_name`` (which may need the Hugging Face Hub or a local cache).

    :param cross_esm_graph: Flag stored unchanged on the config.
    :param decoder_start_token_id: Stored on the config and forwarded to the
        generated GPT configuration.
    :param early_stopping: Stored unchanged on the config.
    :param eos_token_id: End-of-sequence token id; forwarded to the generated
        GPT configuration and to ``PretrainedConfig.__init__``.
    :param bos_token_id: Beginning-of-sequence token id; forwarded to the
        generated GPT configuration and to ``PretrainedConfig.__init__``.
    :param esm: Flag stored unchanged on the config.
    :param esm_model_name: Checkpoint name used to build ``esm_config`` when
        ``esm_config`` is ``None``.
    :param gpt_model_name: Checkpoint name used to build ``gpt_config`` when
        ``gpt_config`` is ``None``.
    :param length_penalty: Stored unchanged on the config.
    :param max_new_tokens: Stored on the config and forwarded to the generated
        GPT configuration.
    :param no_repeat_ngram_size: Stored unchanged on the config.
    :param pad_token_id: Padding token id; stored on the config and forwarded
        to the generated GPT configuration.
    :param prot2text_version: Version string stored unchanged on the config.
    :param rgcn: Flag stored unchanged on the config.
    :param rgc_input_dim: Stored unchanged on the config.
    :param rgcn_n_layers: Stored unchanged on the config.
    :param gpt_config: Ready-made decoder configuration (dict). When ``None``
        one is generated from ``gpt_model_name``.
    :param esm_config: Ready-made ESM configuration (dict). When ``None`` one
        is generated from ``esm_model_name``.
    :param kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "prot2text"
    keys_to_ignore_at_inference = ["past_key_values"]
    _keys_to_ignore_on_load_missing = [r"transformer"]

    def __init__(
        self,
        cross_esm_graph=True,
        decoder_start_token_id=50257,
        early_stopping=True,
        eos_token_id=50258,
        bos_token_id=50257,
        esm=True,
        esm_model_name="facebook/esm2_t6_8M_UR50D",
        gpt_model_name="gpt2",
        length_penalty=2.0,
        max_new_tokens=256,
        no_repeat_ngram_size=3,
        pad_token_id=50256,
        prot2text_version="1.1",
        rgcn=True,
        rgc_input_dim=67,
        rgcn_n_layers=6,
        gpt_config=None,
        esm_config=None,
        **kwargs,
    ):
        self.cross_esm_graph = cross_esm_graph
        self.decoder_start_token_id = decoder_start_token_id
        self.early_stopping = early_stopping
        self.eos_token_id = eos_token_id
        self.esm = esm
        self.esm_model_name = esm_model_name
        self.gpt_model_name = gpt_model_name
        self.length_penalty = length_penalty
        self.max_new_tokens = max_new_tokens
        self.no_repeat_ngram_size = no_repeat_ngram_size
        self.pad_token_id = pad_token_id
        self.prot2text_version = prot2text_version
        self.rgcn = rgcn
        self.rgc_input_dim = rgc_input_dim
        self.rgcn_n_layers = rgcn_n_layers

        if gpt_config is None:
            # Build the decoder configuration from the named checkpoint and
            # override the fields Prot2Text needs (cross-attention, extra
            # special tokens, generation limits).
            self.gpt_config = AutoConfig.from_pretrained(
                gpt_model_name,
                _name_or_path=gpt_model_name,
                is_encoder_decoder=True,
                use_cache=False,
                add_cross_attention=True,
                bos_token_id=bos_token_id,
                decoder_start_token_id=decoder_start_token_id,
                eos_token_id=eos_token_id,
                max_new_tokens=max_new_tokens,
                # Use the argument rather than a literal so a caller-supplied
                # padding id reaches the decoder config (default unchanged).
                pad_token_id=pad_token_id,
                vocab_size=50259,
                num_beams=1,
                max_length=256,
                min_length=1,
            ).to_dict()
        else:
            self.gpt_config = gpt_config

        if esm_config is None:
            self.esm_config = AutoConfig.from_pretrained(esm_model_name).to_dict()
        else:
            # Without this ``else`` the freshly loaded default would be
            # overwritten by ``None``.
            self.esm_config = esm_config

        # NOTE(review): attributes assigned above but not forwarded here
        # (e.g. pad_token_id, decoder_start_token_id) may be re-initialised by
        # PretrainedConfig.__init__ depending on the transformers version -
        # confirm against the pinned release.
        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
conversion.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utilities for converting Graphein Networks to Geometric Deep Learning formats.
2
+ """
3
+ # %%
4
+ # Graphein
5
+ # Author: Kexin Huang, Arian Jamasb <[email protected]>
6
+ # License: MIT
7
+ # Project Website: https://github.com/a-r-j/graphein
8
+ # Code Repository: https://github.com/a-r-j/graphein
9
+ from __future__ import annotations
10
+
11
+ from typing import List, Optional
12
+
13
+ import networkx as nx
14
+ import numpy as np
15
+ import torch
16
+
17
+ try:
18
+ from graphein.utils.dependencies import import_message
19
+ except ImportError:
20
+ raise Exception('You need to install graphein from source in addition to DSSP to use this model please refer to https://github.com/a-r-j/graphein and https://ssbio.readthedocs.io/en/latest/instructions/dssp.html')
21
+
22
+ try:
23
+ import torch_geometric
24
+ from torch_geometric.data import Data
25
+ except ImportError:
26
+ import_message(
27
+ submodule="graphein.ml.conversion",
28
+ package="torch_geometric",
29
+ pip_install=True,
30
+ conda_channel="rusty1s",
31
+ )
32
+
33
+ try:
34
+ import dgl
35
+ except ImportError:
36
+ import_message(
37
+ submodule="graphein.ml.conversion",
38
+ package="dgl",
39
+ pip_install=True,
40
+ conda_channel="dglteam",
41
+ )
42
+
43
+ try:
44
+ import jax.numpy as jnp
45
+ except ImportError:
46
+ import_message(
47
+ submodule="graphein.ml.conversion",
48
+ package="jax",
49
+ pip_install=True,
50
+ conda_channel="conda-forge",
51
+ )
52
+ try:
53
+ import jraph
54
+ except ImportError:
55
+ import_message(
56
+ submodule="graphein.ml.conversion",
57
+ package="jraph",
58
+ pip_install=True,
59
+ conda_channel="conda-forge",
60
+ )
61
+
62
+
63
+ SUPPORTED_FORMATS = ["nx", "pyg", "dgl", "jraph"]
64
+ """Supported conversion formats.
65
+
66
+ ``"nx"``: NetworkX graph
67
+
68
+ ``"pyg"``: PyTorch Geometric Data object
69
+
70
+ ``"dgl"``: DGL graph
71
+
72
+ ``"jraph"``: Jraph GraphsTuple
73
+ """
74
+
75
+ SUPPORTED_VERBOSITY = ["gnn", "default", "all_info"]
76
+ """Supported verbosity levels for preserving graph features in conversion."""
77
+
78
+
79
class GraphFormatConvertor:
    """
    Provides conversion utilities between NetworkX Graphs and geometric deep learning library destination formats.
    Conversion is implemented from ``nx.Graph`` to ``dgl.DGLGraph``, ``pytorch_geometric.Data`` and
    ``jraph.GraphsTuple``, and from ``dgl.DGLGraph`` / ``pytorch_geometric.Data`` back to ``nx.Graph``. Supported
    format names can be retrieved from :const:`~graphein.ml.conversion.SUPPORTED_FORMATS`.

    :param src_format: The type of graph you'd like to convert from. Supported formats are available in :const:`~graphein.ml.conversion.SUPPORTED_FORMATS`
    :type src_format: Literal["nx", "pyg", "dgl", "jraph"]
    :param dst_format: The type of graph format you'd like to convert to. Supported formats are available in:
        ``graphein.ml.conversion.SUPPORTED_FORMATS``
    :type dst_format: Literal["nx", "pyg", "dgl", "jraph"]
    :param verbose: Select from ``"gnn"``, ``"default"``, ``"all_info"`` to determine how much information is preserved (features)
        as some are unsupported by various downstream frameworks
    :type verbose: graphein.ml.conversion.SUPPORTED_VERBOSITY
    :param columns: List of columns in the node features to retain
    :type columns: List[str], optional
    :raises ValueError: If a format is unsupported, or if ``columns`` is ``None`` and ``verbose`` is not a
        supported verbosity level.
    """

    def __init__(
        self,
        src_format: str,
        dst_format: str,
        verbose: str = "gnn",
        columns: Optional[List[str]] = None,
    ):
        if (src_format not in SUPPORTED_FORMATS) or (
            dst_format not in SUPPORTED_FORMATS
        ):
            raise ValueError(
                "Please specify from supported format, "
                + "/".join(SUPPORTED_FORMATS)
            )
        self.src_format = src_format
        self.dst_format = dst_format

        # ``verbose`` only matters when no explicit column list is given.
        if (columns is None) and (verbose not in SUPPORTED_VERBOSITY):
            raise ValueError(
                "Please specify the supported verbose mode ("
                + "/".join(SUPPORTED_VERBOSITY)
                + ") or specify column names!"
            )

        if columns is None:
            if verbose == "gnn":
                columns = [
                    "edge_index",
                    "coords",
                    "dist_mat",
                    "name",
                    "node_id",
                ]
            elif verbose == "default":
                columns = [
                    "b_factor",
                    "chain_id",
                    "coords",
                    "dist_mat",
                    "edge_index",
                    "kind",
                    "name",
                    "node_id",
                    "residue_name",
                ]
            elif verbose == "all_info":
                columns = [
                    "atom_type",
                    "b_factor",
                    "chain_id",
                    "chain_ids",
                    "config",
                    "coords",
                    "dist_mat",
                    "edge_index",
                    "element_symbol",
                    "kind",
                    "name",
                    "node_id",
                    "node_type",
                    "pdb_df",
                    "raw_pdb_df",
                    "residue_name",
                    "residue_number",
                    "rgroup_df",
                    "sequence_A",
                    "sequence_B",
                ]
        self.columns = columns

        # Declares how each known feature is encoded; used by the DGL
        # conversion to decide which features can become tensors.
        self.type2form = {
            "atom_type": "str",
            "b_factor": "float",
            "chain_id": "str",
            "coords": "np.array",
            "dist_mat": "np.array",
            "element_symbol": "str",
            "node_id": "str",
            "residue_name": "str",
            "residue_number": "int",
            "edge_index": "torch.tensor",
            "kind": "str",
        }

    def _add_node_features(self, G: nx.Graph, store: dict) -> dict:
        """Append the retained node attributes of ``G`` to ``store`` (one list entry per node).

        An attribute seen on the first node (re)starts its list; on later
        nodes the value is appended, so an attribute missing from the first
        node and absent from ``store`` raises ``KeyError``.
        """
        for i, (_, feat_dict) in enumerate(G.nodes(data=True)):
            for key, value in feat_dict.items():
                name = str(key)
                if name not in self.columns:
                    continue
                if i == 0:
                    store[name] = [value]
                else:
                    store[name].append(value)
        return store

    def _add_edge_features(self, G: nx.Graph, store: dict) -> dict:
        """Append the retained edge attributes of ``G`` to ``store``.

        Each attribute value is expected to be iterable; the values of all
        edges are flattened into one list per attribute.
        """
        for i, (_, _, feat_dict) in enumerate(G.edges(data=True)):
            for key, value in feat_dict.items():
                name = str(key)
                if name not in self.columns:
                    continue
                if i == 0:
                    store[name] = list(value)
                else:
                    store[name].extend(list(value))
        return store

    def _graph_features(self, G: nx.Graph) -> dict:
        """Return the retained graph-level attributes, each wrapped in a one-element list."""
        return {
            str(feat_name): [G.graph[feat_name]]
            for feat_name in G.graph
            if str(feat_name) in self.columns
        }

    def convert_nx_to_dgl(self, G: nx.Graph) -> dgl.DGLGraph:
        """
        Converts ``NetworkX`` graph to ``DGL``

        Only numeric (``float``/``int``) features and ``coords``/``dist_mat``
        are attached as tensors; string features and graph-level features are
        not transferred.

        :param G: ``nx.Graph`` to convert to ``DGLGraph``
        :type G: nx.Graph
        :return: ``DGLGraph`` object version of input ``NetworkX`` graph
        :rtype: dgl.DGLGraph
        """
        g = dgl.DGLGraph()
        node_id = list(G.nodes())
        G = nx.convert_node_labels_to_integers(G)

        # Node-level features
        node_dict = self._add_node_features(G, {})

        node_dict_transformed = {}
        for name, values in node_dict.items():
            if name == "coords":
                node_dict_transformed[name] = torch.Tensor(
                    np.asarray(values)
                ).type("torch.FloatTensor")
            elif name == "dist_mat":
                # presumably the first entry is a DataFrame-like object
                # exposing ``.values`` - verify against the graph builder
                node_dict_transformed[name] = torch.Tensor(
                    np.asarray(values[0].values)
                ).type("torch.FloatTensor")
            elif self.type2form[name] in ["float", "int"]:
                node_dict_transformed[name] = torch.Tensor(np.array(values))
            # any other declared form (e.g. "str") is not attached to the graph
        g.add_nodes(
            len(node_id),
            node_dict_transformed,
        )

        edge_index = torch.LongTensor(list(G.edges)).t().contiguous()

        # Edge-level features
        edge_dict = self._add_edge_features(G, {})

        edge_transform_dict = {}
        # Iterate the *edge* features here; the node features were already
        # attached above and generally do not match the number of edges.
        for name, values in edge_dict.items():
            if self.type2form.get(name) in ["float", "int"]:
                edge_transform_dict[name] = torch.Tensor(np.array(values))
        g.add_edges(edge_index[0], edge_index[1], edge_transform_dict)

        return g

    def convert_nx_to_pyg(self, G: nx.Graph) -> Data:
        """
        Converts ``NetworkX`` graph to ``pytorch_geometric.data.Data`` object. Requires ``PyTorch Geometric`` (https://pytorch-geometric.readthedocs.io/en/latest/) to be installed.

        The original node labels are always kept under ``node_id``; every
        other node, edge and graph attribute is kept only if listed in
        ``self.columns``.

        :param G: ``nx.Graph`` to convert to PyTorch Geometric ``Data`` object
        :type G: nx.Graph
        :return: ``Data`` object containing networkx graph data
        :rtype: pytorch_geometric.data.Data
        """
        # Initialise dict used to construct Data object & assign node ids as a feature
        data = {"node_id": list(G.nodes())}
        G = nx.convert_node_labels_to_integers(G)

        # Construct edge index
        edge_index = torch.LongTensor(list(G.edges)).t().contiguous()

        # Node, edge and graph-level features share one namespace; later
        # stages overwrite earlier ones on a name clash.
        self._add_node_features(G, data)
        self._add_edge_features(G, data)
        data.update(self._graph_features(G))

        if "edge_index" in self.columns:
            data["edge_index"] = edge_index.view(2, -1)

        data = Data.from_dict(data)
        data.num_nodes = G.number_of_nodes()
        return data

    @staticmethod
    def convert_nx_to_nx(G: nx.Graph) -> nx.Graph:
        """
        Converts NetworkX graph (``nx.Graph``) to NetworkX graph (``nx.Graph``) object. Redundant - returns itself.

        :param G: NetworkX Graph
        :type G: nx.Graph
        :return: NetworkX Graph
        :rtype: nx.Graph
        """
        return G

    @staticmethod
    def convert_dgl_to_nx(G: dgl.DGLGraph) -> nx.Graph:
        """
        Converts a DGL Graph (``dgl.DGLGraph``) to a NetworkX (``nx.Graph``) object. Preserves node and edge attributes.

        :param G: ``dgl.DGLGraph`` to convert to ``NetworkX`` graph.
        :type G: dgl.DGLGraph
        :return: NetworkX graph object.
        :rtype: nx.Graph
        """
        node_attrs = G.node_attr_schemes().keys()
        edge_attrs = G.edge_attr_schemes().keys()
        return dgl.to_networkx(G, node_attrs, edge_attrs)

    @staticmethod
    def convert_pyg_to_nx(G: Data) -> nx.Graph:
        """Converts PyTorch Geometric ``Data`` object to NetworkX graph (``nx.Graph``).

        :param G: Pytorch Geometric Data.
        :type G: torch_geometric.data.Data
        :returns: NetworkX graph.
        :rtype: nx.Graph
        """
        return torch_geometric.utils.to_networkx(G)

    def convert_nx_to_jraph(self, G: nx.Graph) -> jraph.GraphsTuple:
        """Converts NetworkX graph (``nx.Graph``) to Jraph GraphsTuple graph. Requires ``jax`` and ``Jraph``.

        Node features that ``torch.tensor`` can convert are stored as tensors;
        the rest are kept as Python lists.

        :param G: Networkx graph to convert.
        :type G: nx.Graph
        :return: Jraph GraphsTuple graph.
        :rtype: jraph.GraphsTuple
        """
        G = nx.convert_node_labels_to_integers(G)

        n_node = len(G)
        n_edge = G.number_of_edges()
        edge_list = list(G.edges())
        senders, receivers = zip(*edge_list)
        senders, receivers = jnp.array(senders), jnp.array(receivers)

        # Gather plain lists first and convert once at the end: converting
        # inside the loop would turn the accumulator into a tensor, which can
        # no longer be extended with ``+ [value]``.
        node_features = self._add_node_features(G, {})
        for name, values in list(node_features.items()):
            try:
                node_features[name] = torch.tensor(values)
            except (TypeError, ValueError, RuntimeError):
                # Not convertible (e.g. strings): keep the raw list.
                pass

        edge_features = self._add_edge_features(G, {})
        global_context = self._graph_features(G)

        return jraph.GraphsTuple(
            nodes=node_features,
            senders=senders,
            receivers=receivers,
            edges=edge_features,
            n_node=n_node,
            n_edge=n_edge,
            globals=global_context,
        )

    def __call__(self, G: nx.Graph):
        """Convert ``G`` from ``self.src_format`` to ``self.dst_format`` via NetworkX.

        Dispatches to ``convert_<src>_to_nx`` and then ``convert_nx_to_<dst>``;
        a missing converter method raises ``AttributeError``.
        """
        # Look the converters up by name instead of evaluating code strings.
        nx_g = getattr(self, "convert_" + self.src_format + "_to_nx")(G)
        dst_g = getattr(self, "convert_nx_to_" + self.dst_format)(nx_g)
        return dst_g
406
+
407
+
408
+ # def convert_nx_to_pyg_data(G: nx.Graph) -> Data:
409
+ # # Initialise dict used to construct Data object
410
+ # data = {"node_id": list(G.nodes())}
411
+
412
+ # G = nx.convert_node_labels_to_integers(G)
413
+
414
+ # # Construct Edge Index
415
+ # edge_index = torch.LongTensor(list(G.edges)).t().contiguous()
416
+
417
+ # # Add node features
418
+ # for i, (_, feat_dict) in enumerate(G.nodes(data=True)):
419
+ # for key, value in feat_dict.items():
420
+ # data[str(key)] = [value] if i == 0 else data[str(key)] + [value]
421
+
422
+ # # Add edge features
423
+ # for i, (_, _, feat_dict) in enumerate(G.edges(data=True)):
424
+ # for key, value in feat_dict.items():
425
+ # data[str(key)] = (
426
+ # list(value) if i == 0 else data[str(key)] + list(value)
427
+ # )
428
+
429
+ # # Add graph-level features
430
+ # for feat_name in G.graph:
431
+ # data[str(feat_name)] = [G.graph[feat_name]]
432
+
433
+ # data["edge_index"] = edge_index.view(2, -1)
434
+ # data = Data.from_dict(data)
435
+ # data.num_nodes = G.number_of_nodes()
436
+
437
+ # return data
438
def convert_nx_to_pyg_data(G: nx.Graph) -> Data:
    """Convert a NetworkX graph to a PyTorch Geometric ``Data`` object, keeping every attribute.

    The original node labels are stored under ``node_id``. Node attributes are
    collected into one list entry per node. Edge attributes are collected into
    one entry per edge: ``distance`` is stored as-is, any other attribute is
    converted with ``list(value)``. Graph-level attributes are wrapped in a
    one-element list. All attributes share one namespace, so on a name clash
    a later stage overwrites an earlier one.

    An attribute must be present on the first node/edge to start its list;
    one that first appears later (and is not already a key) raises
    ``KeyError``.

    :param G: Graph to convert.
    :type G: nx.Graph
    :return: ``Data`` object with ``edge_index`` of shape ``(2, num_edges)``
        and ``num_nodes`` set from ``G``.
    :rtype: torch_geometric.data.Data
    """
    # Initialise dict used to construct Data object
    data = {"node_id": list(G.nodes())}

    G = nx.convert_node_labels_to_integers(G)

    # Construct edge index
    edge_index = torch.LongTensor(list(G.edges)).t().contiguous()

    # Node features: append in place instead of rebuilding the list per node.
    for i, (_, feat_dict) in enumerate(G.nodes(data=True)):
        for key, value in feat_dict.items():
            name = str(key)
            if i == 0:
                data[name] = [value]
            else:
                data[name].append(value)

    # Edge features
    for i, (_, _, feat_dict) in enumerate(G.edges(data=True)):
        for key, value in feat_dict.items():
            name = str(key)
            if i == 0:
                data[name] = [value if key == "distance" else list(value)]
            else:
                data[name].append(value if key == "distance" else list(value))

    # Graph-level features
    for feat_name in G.graph:
        data[str(feat_name)] = [G.graph[feat_name]]

    data["edge_index"] = edge_index.view(2, -1)
    data = Data.from_dict(data)
    data.num_nodes = G.number_of_nodes()

    return data
graphs.py ADDED
@@ -0,0 +1,1144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Functions for working with Protein Structure Graphs."""
2
+ # %%
3
+ # Graphein
4
+ # Author: Arian Jamasb <[email protected]>, Eric Ma, Charlie Harris
5
+ # License: MIT
6
+ # Project Website: https://github.com/a-r-j/graphein
7
+ # Code Repository: https://github.com/a-r-j/graphein
8
+ from __future__ import annotations
9
+
10
+ import logging
11
+ import traceback
12
+ from functools import partial
13
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
14
+
15
+ import networkx as nx
16
+ import numpy as np
17
+ import pandas as pd
18
+
19
+ try:
20
+ from biopandas.pdb import PandasPdb
21
+ from biopandas.mmcif import PandasMmcif
22
+ except ImportError:
23
+ raise Exception('You need to install BioPandas and its dependecies to use this model.')
24
+
25
+ from rich.progress import Progress
26
+ from tqdm.contrib.concurrent import process_map
27
+
28
+ try:
29
+ from graphein.protein.config import (
30
+ DSSPConfig,
31
+ GetContactsConfig,
32
+ ProteinGraphConfig,
33
+ )
34
+ from graphein.protein.edges.distance import (
35
+ add_distance_to_edges,
36
+ compute_distmat,
37
+ )
38
+ from graphein.protein.resi_atoms import BACKBONE_ATOMS, RESI_THREE_TO_1
39
+ from graphein.protein.subgraphs import extract_subgraph_from_chains
40
+ from graphein.protein.utils import (
41
+ ProteinGraphConfigurationError,
42
+ compute_rgroup_dataframe,
43
+ filter_dataframe,
44
+ get_protein_name_from_filename,
45
+ three_to_one_with_mods,
46
+ )
47
+ from graphein.rna.constants import RNA_ATOMS
48
+ from graphein.utils.utils import (
49
+ annotate_edge_metadata,
50
+ annotate_graph_metadata,
51
+ annotate_node_metadata,
52
+ compute_edges,
53
+ )
54
+ except ImportError:
55
+ raise Exception('You need to install graphein from source in addition to DSSP to use this model please refer to https://github.com/a-r-j/graphein and https://ssbio.readthedocs.io/en/latest/instructions/dssp.html')
56
+
57
+ from .utils_convert import biopandas_mmcif2pdb
58
+
59
+ # logging.basicConfig(level="DEBUG")
60
+ log = logging.getLogger(__name__)
61
+
62
+
63
+
64
def subset_structure_to_rna(
    df: pd.DataFrame,
) -> pd.DataFrame:
    """
    Restrict an atomic dataframe to the atom names listed in ``RNA_ATOMS``.

    :param df: Structure dataframe to subset; filtered on its ``atom_name`` column.
    :type df: pd.DataFrame
    :returns: Result of ``filter_dataframe`` for the RNA atom names.
    :rtype: pd.DataFrame
    """
    rna_only = filter_dataframe(
        df,
        by_column="atom_name",
        list_of_values=RNA_ATOMS,
        boolean=True,
    )
    return rna_only
78
+
79
+
80
def read_pdb_to_dataframe(
    pdb_path: Optional[str] = None,
    pdb_code: Optional[str] = None,
    uniprot_id: Optional[str] = None,
    model_index: int = 1,
) -> pd.DataFrame:
    """
    Load a structure and return its ATOM and HETATM records as one dataframe.

    The source is chosen in this order: ``pdb_path`` (local file; a path
    ending in ``cif`` is read as mmCIF and converted with
    ``biopandas_mmcif2pdb``), then ``uniprot_id`` (fetched with source
    ``"alphafold2-v2"``), then ``pdb_code``.

    :param pdb_path: Path to a local structure file. Defaults to ``None``.
    :type pdb_path: str, optional
    :param pdb_code: PDB accession to fetch. Defaults to ``None``.
    :type pdb_code: str, optional
    :param uniprot_id: UniProt ID to fetch from AlphaFoldDB. Defaults to ``None``.
    :type uniprot_id: str, optional
    :param model_index: Index of the model to select. Defaults to ``1``.
    :type model_index: int, optional
    :raises NameError: If none of ``pdb_path``, ``pdb_code`` or ``uniprot_id`` is given.
    :raises ValueError: If the selected model has no ATOM records.
    :returns: ATOM records followed by HETATM records.
    :rtype: pd.DataFrame
    """
    if all(source is None for source in (pdb_code, pdb_path, uniprot_id)):
        raise NameError(
            "One of pdb_code, pdb_path or uniprot_id must be specified!"
        )

    if pdb_path is None:
        # Remote sources: AlphaFoldDB takes precedence over the PDB code.
        if uniprot_id is None:
            atomic_df = PandasPdb().fetch_pdb(pdb_code)
        else:
            atomic_df = PandasPdb().fetch_pdb(
                uniprot_id=uniprot_id, source="alphafold2-v2"
            )
    elif pdb_path.endswith('cif'):
        mmcif = PandasMmcif().read_mmcif(pdb_path)
        atomic_df = biopandas_mmcif2pdb(mmcif, model_index)
    else:
        atomic_df = PandasPdb().read_pdb(pdb_path)

    atomic_df = atomic_df.get_model(model_index)
    if len(atomic_df.df["ATOM"]) == 0:
        raise ValueError(f"No model found for index: {model_index}")

    records = [atomic_df.df["ATOM"], atomic_df.df["HETATM"]]
    return pd.concat(records)
132
+
133
+
134
def label_node_id(df: pd.DataFrame, granularity: str) -> pd.DataFrame:
    """
    Add ``node_id`` and ``residue_id`` columns to ``df`` in place and return it.

    ``residue_id`` is ``<chain_id>:<residue_name>:<residue_number>``.
    ``node_id`` equals ``residue_id``, extended with ``:<atom_name>`` for
    ``"atom"`` granularity, or with ``:<atom_number>:<atom_name>`` for
    ``"rna_atom"`` / ``"rna_centroid"`` granularity.

    :param df: Structure dataframe; modified in place.
    :type df: pd.DataFrame
    :param granularity: Granularity that selects the ``node_id`` format.
    :type granularity: str
    :returns: The same dataframe object with the two columns set.
    :rtype: pd.DataFrame
    """
    chain = df["chain_id"].map(str)
    number = df["residue_number"].map(str)
    df["node_id"] = chain + ":" + df["residue_name"] + ":" + number
    df["residue_id"] = df["node_id"]

    if granularity == "atom":
        df["node_id"] = df["node_id"] + ":" + df["atom_name"]
        return df

    if granularity in ("rna_atom", "rna_centroid"):
        atom_number = df["atom_number"].map(str)
        df["node_id"] = df["node_id"] + ":" + atom_number + ":" + df["atom_name"]

    return df
154
+
155
+
156
def deprotonate_structure(df: pd.DataFrame) -> pd.DataFrame:
    """Filter hydrogen atoms out of an atomic dataframe.

    The filter is applied on the ``element_symbol`` column with the value
    ``"H"`` and ``boolean=False``.

    :param df: Atomic dataframe.
    :type df: pd.DataFrame
    :returns: Result of ``filter_dataframe`` for the hydrogen element symbol.
    :rtype: pd.DataFrame
    """
    log.debug(
        "Deprotonating protein. This removes H atoms from the pdb_df dataframe"
    )
    hydrogen_symbols = ["H"]
    return filter_dataframe(
        df,
        by_column="element_symbol",
        list_of_values=hydrogen_symbols,
        boolean=False,
    )
170
+
171
+
172
def convert_structure_to_centroids(df: pd.DataFrame) -> pd.DataFrame:
    """Replace the coordinates of the ``CA`` rows with residue centroid coordinates.

    Centroids come from ``calculate_centroid_positions``; the dataframe is
    reduced to its ``CA`` rows (index reset) and the three coordinate columns
    are overwritten from the centroid frame, aligned on the index.

    :param df: Structure dataframe to convert.
    :type df: pd.DataFrame
    :return: ``CA`` rows of ``df`` carrying centroid coordinates.
    :rtype: pd.DataFrame
    """
    log.debug(
        "Converting dataframe to centroids. This averages XYZ coords of the atoms in a residue"
    )

    centroids = calculate_centroid_positions(df)
    is_alpha_carbon = df["atom_name"] == "CA"
    df = df.loc[is_alpha_carbon].reset_index(drop=True)
    for axis in ("x_coord", "y_coord", "z_coord"):
        df[axis] = centroids[axis]

    return df
191
+
192
+
193
def subset_structure_to_atom_type(
    df: pd.DataFrame, granularity: str
) -> pd.DataFrame:
    """
    Restrict an atomic dataframe to a single atom name.

    :param df: Structure dataframe to subset; filtered on its ``atom_name`` column.
    :type df: pd.DataFrame
    :param granularity: Atom name to select (e.g. ``"CA"``).
    :type granularity: str
    :returns: Result of ``filter_dataframe`` for that atom name.
    :rtype: pd.DataFrame
    """
    wanted_atoms = [granularity]
    return filter_dataframe(
        df,
        by_column="atom_name",
        list_of_values=wanted_atoms,
        boolean=True,
    )
207
+
208
+
209
def remove_insertions(df: pd.DataFrame, keep: str = "first") -> pd.DataFrame:
    """
    Drop insertions and alternate locations from a structure dataframe.

    Three filters run in sequence: rows duplicated on
    ``(chain_id, residue_number, atom_name)`` are dropped, then the frame is
    filtered on ``insertion == ""``, then on ``alt_loc`` in ``["", "A"]``.

    :param df: Structure dataframe to clean.
    :type df: pd.DataFrame
    :param keep: Passed to ``DataFrame.duplicated`` to pick which duplicate
        survives (``"first"`` or ``"last"``). Default is ``"first"``.
    :type keep: str
    :return: Filtered structure dataframe.
    :rtype: pd.DataFrame
    """
    # Unnamed insertions show up as repeated (chain, residue number, atom) rows.
    identity_columns = ["chain_id", "residue_number", "atom_name"]
    is_repeat = df.duplicated(subset=identity_columns, keep=keep)
    deduplicated = df[~is_repeat]

    # Explicit insertions carry a non-empty insertion code.
    without_insertions = filter_dataframe(
        deduplicated, by_column="insertion", list_of_values=[""], boolean=True
    )

    # Alternate locations: only blank or "A" are retained.
    return filter_dataframe(
        without_insertions,
        by_column="alt_loc",
        list_of_values=["", "A"],
        boolean=True,
    )
238
+
239
+
240
def filter_hetatms(
    df: pd.DataFrame, keep_hets: List[str]
) -> List[pd.DataFrame]:
    """Select the rows of each requested hetero residue.

    :param df: Structure dataframe to select hetero atoms from.
    :type df: pd.DataFrame
    :param keep_hets: Residue names to keep; matched against ``residue_name``.
    :type keep_hets: List[str]
    :returns: One dataframe per entry of ``keep_hets``, in the same order.
    :rtype: List[pd.DataFrame]
    """
    selections = []
    for residue in keep_hets:
        matches = df["residue_name"] == residue
        selections.append(df.loc[matches])
    return selections
252
+
253
+
254
def process_dataframe(
    protein_df: pd.DataFrame,
    atom_df_processing_funcs: Optional[List[Callable]] = None,
    hetatom_df_processing_funcs: Optional[List[Callable]] = None,
    granularity: str = "centroids",
    chain_selection: str = "all",
    insertions: bool = False,
    deprotonate: bool = True,
    keep_hets: Optional[List[str]] = None,
    verbose: bool = False,
) -> pd.DataFrame:
    """
    Process ATOM and HETATM dataframes to produce singular dataframe used for graph construction.

    If custom processing functions are supplied, only they are applied and
    the default workflow (hetatm selection, deprotonation, granularity
    restriction, insertion removal, chain selection, sorting) is skipped.

    :param protein_df: Dataframe to process.
        Should be the object returned from :func:`~graphein.protein.graphs.read_pdb_to_dataframe`.
        Node id columns are added to it in place.
    :type protein_df: pd.DataFrame
    :param atom_df_processing_funcs: List of functions to process dataframe. These must take in a dataframe and return a
        dataframe. Defaults to None.
    :type atom_df_processing_funcs: List[Callable], optional
    :param hetatom_df_processing_funcs: List of functions to process the hetatom dataframe. These must take in a dataframe and return a dataframe
    :type hetatom_df_processing_funcs: List[Callable], optional
    :param granularity: The level of granularity for the graph. This determines the node definition.
        Acceptable values include: ``"centroids"``, ``"atoms"``,
        any of the atom_names in the PDB file (e.g. ``"CA"``, ``"CB"``, ``"OG"``, etc.).
        See: :const:`~graphein.protein.config.GRAPH_ATOMS` and :const:`~graphein.protein.config.GRANULARITY_OPTS`.
    :type granularity: str
    :param chain_selection: Which protein chain to select. Defaults to ``"all"``. Eg can use ``"ACF"``
        to select 3 chains (``A``, ``C`` & ``F``)
    :type chain_selection: str
    :param insertions: Whether or not to keep insertions.
    :type insertions: bool
    :param deprotonate: Whether or not to remove hydrogen atoms (i.e. deprotonation).
    :type deprotonate: bool
    :param keep_hets: Hetatoms to keep. Defaults to ``None`` (keep none).
        To keep a hetatom, pass it inside a list of hetatom names to keep.
    :type keep_hets: List[str], optional
    :param verbose: Verbosity level.
    :type verbose: bool
    :return: A protein dataframe that can be consumed by
        other graph construction functions.
    :rtype: pd.DataFrame
    """
    protein_df = label_node_id(protein_df, granularity=granularity)
    # TODO: Need to properly define what "granularity" is supposed to do.
    atoms = filter_dataframe(
        protein_df,
        by_column="record_name",
        list_of_values=["ATOM"],
        boolean=True,
    )
    hetatms = filter_dataframe(
        protein_df,
        by_column="record_name",
        list_of_values=["HETATM"],
        boolean=True,
    )

    # This block enables processing via a list of supplied functions operating on the atom and hetatom dataframes
    # If these are provided, the dataframe returned will be computed only from these and the default workflow
    # below this block will not execute.
    if atom_df_processing_funcs is not None:
        for func in atom_df_processing_funcs:
            atoms = func(atoms)
        if hetatom_df_processing_funcs is None:
            return atoms

    if hetatom_df_processing_funcs is not None:
        for func in hetatom_df_processing_funcs:
            hetatms = func(hetatms)
        return pd.concat([atoms, hetatms])

    # ``None`` and an empty list both mean "keep no hetero atoms".
    if keep_hets:
        hetatms_to_keep = filter_hetatms(hetatms, keep_hets)
        atoms = pd.concat([atoms] + hetatms_to_keep)

    # Deprotonate structure by removing H atoms
    if deprotonate:
        atoms = deprotonate_structure(atoms)

    # Restrict DF to desired granularity
    if granularity == "atom":
        pass
    elif granularity in {"centroids", "rna_centroid"}:
        atoms = convert_structure_to_centroids(atoms)
    elif granularity == "rna_atom":
        atoms = subset_structure_to_rna(atoms)
    else:
        # Any other value is treated as an atom name to keep.
        atoms = subset_structure_to_atom_type(atoms, granularity)

    protein_df = atoms

    # Remove insertions and alt_loc residues
    if not insertions:
        protein_df = remove_insertions(protein_df)

    # perform chain selection
    protein_df = select_chains(
        protein_df, chain_selection=chain_selection, verbose=verbose
    )

    log.debug(f"Detected {len(protein_df)} total nodes")

    # Sort dataframe to place HETATMs
    protein_df = sort_dataframe(protein_df)

    return protein_df
361
+
362
+
363
def sort_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` ordered by chain, residue number and atom number.

    Sorting this way interleaves HETATM / modified-residue rows with the
    regular ATOM rows of the chain and residue they belong to.

    :param df: Protein dataframe to sort. Must contain the columns
        ``chain_id``, ``residue_number`` and ``atom_number``.
    :type df: pd.DataFrame
    :return: Sorted protein dataframe (the input is not modified).
    :rtype: pd.DataFrame
    """
    sort_keys = ["chain_id", "residue_number", "atom_number"]
    return df.sort_values(by=sort_keys)
374
+
375
+
376
def assign_node_id_to_dataframe(
    protein_df: pd.DataFrame, granularity: str
) -> pd.DataFrame:
    """
    Assigns the node ID back to the ``pdb_df`` dataframe.

    The ID has the form ``chain_id:residue_name:residue_number``; for the
    atom-level granularities (``"atom"`` and ``"rna_atom"``) the atom name is
    appended as a fourth ``:``-separated field. The ``node_id`` column is
    written into ``protein_df`` in place and the same dataframe is returned.

    :param protein_df: Structure Dataframe
    :type protein_df: pd.DataFrame
    :param granularity: Granularity of graph. Atom-level,
        residue (e.g. ``CA``) or ``centroids``.
        See: :const:`~graphein.protein.config.GRAPH_ATOMS`
        and :const:`~graphein.protein.config.GRANULARITY_OPTS`.
    :type granularity: str
    :return: Returns dataframe with added ``node_ids``
    :rtype: pd.DataFrame
    """
    node_id = (
        protein_df["chain_id"].apply(str)
        + ":"
        + protein_df["residue_name"]
        + ":"
        + protein_df["residue_number"].apply(str)
    )
    if granularity in {"atom", "rna_atom"}:
        # Element-wise concatenation. Formatting the Series inside an f-string
        # would write the repr of the whole column into every row.
        node_id = node_id + ":" + protein_df["atom_name"]

    protein_df["node_id"] = node_id
    return protein_df
403
+
404
+
405
def select_chains(
    protein_df: pd.DataFrame, chain_selection: str, verbose: bool = False
) -> pd.DataFrame:
    """
    Restrict ``protein_df`` to the requested chains.

    :param protein_df: pandas dataframe of PDB subsetted to relevant atoms
        (``CA``, ``CB``).
    :type protein_df: pd.DataFrame
    :param chain_selection: ``"all"`` to keep every chain (the dataframe is
        returned untouched), otherwise a string whose individual characters
        are the chain IDs to keep (e.g. ``"ACF"``).
    :type chain_selection: str
    :param verbose: Accepted for interface compatibility; not used by this
        function.
    :type verbose: bool
    :return: Protein structure dataframe containing only entries in the
        chain selection.
    :rtype: pd.DataFrame
    """
    # Nothing to filter when every chain is wanted.
    if chain_selection == "all":
        return protein_df

    wanted_chains = list(chain_selection)
    return filter_dataframe(
        protein_df,
        by_column="chain_id",
        list_of_values=wanted_chains,
        boolean=True,
    )
432
+
433
+
434
def initialise_graph_with_metadata(
    protein_df: pd.DataFrame,
    raw_pdb_df: pd.DataFrame,
    granularity: str,
    name: Optional[str] = None,
    pdb_code: Optional[str] = None,
    pdb_path: Optional[str] = None,
) -> nx.Graph:
    """
    Create an empty ``nx.Graph`` carrying the graph-level metadata.

    The returned graph has no nodes or edges. Its ``graph`` dictionary holds
    ``name``, ``pdb_code``, ``pdb_path``, ``chain_ids``, ``pdb_df``,
    ``raw_pdb_df``, ``rgroup_df``, ``coords``, ``node_type`` and one
    ``sequence_<chain>`` entry per chain found in ``protein_df``.

    :param protein_df: Processed Dataframe of protein structure.
    :type protein_df: pd.DataFrame
    :param raw_pdb_df: Unprocessed dataframe of protein structure for comparison and traceability downstream.
    :type raw_pdb_df: pd.DataFrame
    :param granularity: Granularity of the graph (eg ``"atom"``, ``"CA"``, ``"CB"`` etc or ``"centroid"``).
        Stored as ``node_type``. For ``"rna_atom"`` the chain sequence is the
        plain concatenation of residue names; otherwise each residue name is
        first mapped through ``three_to_one_with_mods``.
    :type granularity: str
    :param name: Name for the graph. If ``None``, the name derived from
        ``pdb_path`` is used when a path is given, otherwise ``pdb_code``.
    :type name: Optional[str], defaults to ``None``
    :param pdb_code: PDB ID / Accession code, if the PDB is available on the PDB database.
    :type pdb_code: Optional[str], defaults to ``None``
    :param pdb_path: path to local PDB file, if constructing a graph from a local file.
    :type pdb_path: Optional[str], defaults to ``None``
    :return: Returns initial protein structure graph with metadata.
    :rtype: nx.Graph
    """
    # Resolve the graph name: explicit name > file name > PDB code.
    if name is None:
        name = (
            get_protein_name_from_filename(pdb_path)
            if pdb_path is not None
            else pdb_code
        )

    chain_ids = list(protein_df["chain_id"].unique())
    # The R-group table is computed from the raw dataframe with insertions
    # removed.
    rgroup_df = compute_rgroup_dataframe(remove_insertions(raw_pdb_df))
    coords = np.asarray(protein_df[["x_coord", "y_coord", "z_coord"]])

    G = nx.Graph(
        name=name,
        pdb_code=pdb_code,
        pdb_path=pdb_path,
        chain_ids=chain_ids,
        pdb_df=protein_df,
        raw_pdb_df=raw_pdb_df,
        rgroup_df=rgroup_df,
        coords=coords,
    )
    G.graph["node_type"] = granularity

    # One sequence string per chain, built from the processed dataframe.
    for chain in G.graph["chain_ids"]:
        residue_names = protein_df.loc[protein_df["chain_id"] == chain][
            "residue_name"
        ]
        if granularity != "rna_atom":
            residue_names = residue_names.apply(three_to_one_with_mods)
        G.graph[f"sequence_{chain}"] = residue_names.str.cat()
    return G
497
+
498
+
499
def add_nodes_to_graph(
    G: nx.Graph,
    protein_df: Optional[pd.DataFrame] = None,
    verbose: bool = False,
) -> nx.Graph:
    """Add nodes into protein graph.

    One node is created per value of the ``node_id`` column, and the node
    attributes ``chain_id``, ``residue_name``, ``residue_number``,
    ``atom_type``, ``element_symbol``, ``coords`` and ``b_factor`` are filled
    from the matching dataframe rows.

    :param G: ``nx.Graph`` with metadata to populate with nodes.
    :type G: nx.Graph
    :param protein_df: DataFrame of protein structure containing nodes & initial node metadata to add to the graph.
        If ``None``, ``G.graph["pdb_df"]`` is used.
    :type protein_df: pd.DataFrame, optional
    :param verbose: If ``True``, print a summary of the graph and its nodes.
    :type verbose: bool
    :returns: nx.Graph with nodes added.
    :rtype: nx.Graph
    """
    # If no protein dataframe is supplied, use the one stored in the Graph object
    if protein_df is None:
        protein_df = G.graph["pdb_df"]

    nodes = protein_df["node_id"]
    # Attribute name -> per-row values; insertion order fixes the order in
    # which the attributes are written onto each node.
    node_attributes = {
        "chain_id": protein_df["chain_id"].apply(str),
        "residue_name": protein_df["residue_name"],
        "residue_number": protein_df["residue_number"],
        "atom_type": protein_df["atom_name"],
        "element_symbol": protein_df["element_symbol"],
        "coords": np.asarray(protein_df[["x_coord", "y_coord", "z_coord"]]),
        "b_factor": protein_df["b_factor"],
    }

    G.add_nodes_from(nodes)
    for attr_name, values in node_attributes.items():
        nx.set_node_attributes(G, dict(zip(nodes, values)), attr_name)

    # TODO: include charge, line_idx for traceability?
    if verbose:
        # ``nx.info`` no longer exists in networkx 3.x; printing the graph
        # itself works on every networkx version.
        print(G)
        print(G.nodes())

    return G
549
+
550
+
551
def calculate_centroid_positions(
    atoms: pd.DataFrame, verbose: bool = False
) -> pd.DataFrame:
    """
    Calculates centroid positions as the mean atom coordinates per residue.

    Rows are grouped by ``residue_number`` only (``chain_id`` is not part of
    the grouping key) and the ``x_coord``/``y_coord``/``z_coord`` columns are
    averaged within each group.

    :param atoms: ATOM df of protein structure.
    :type atoms: pd.DataFrame
    :param verbose: bool controlling verbosity.
    :type verbose: bool
    :return: centroids (df) with columns ``residue_number``, ``x_coord``,
        ``y_coord`` and ``z_coord``.
    :rtype: pd.DataFrame
    """
    coord_columns = ["x_coord", "y_coord", "z_coord"]
    # Select the coordinate columns *before* aggregating: averaging every
    # column raises TypeError on pandas >= 2.0 when string columns such as
    # ``residue_name`` are present.
    centroids = (
        atoms.groupby("residue_number")[coord_columns].mean().reset_index()
    )
    if verbose:
        print(f"Calculated {len(centroids)} centroid nodes")
    log.debug(f"Calculated {len(centroids)} centroid nodes")
    return centroids
573
+
574
+
575
def compute_edges(
    G: nx.Graph,
    funcs: List[Callable],
    get_contacts_config: Optional[GetContactsConfig] = None,
) -> nx.Graph:
    """
    Run the edge construction functions on ``G``.

    When the graph carries a ``"config"`` entry, a distance matrix is first
    computed from ``G.graph["pdb_df"]`` and stored in the graph metadata
    (under ``"atomic_dist_mat"`` for ``"atom"`` granularity, ``"dist_mat"``
    otherwise) so that edge functions can use it.

    :param G: nx.Graph with nodes to add edges to.
    :type G: nx.Graph
    :param funcs: List of edge construction functions. Each is called with
        ``G`` as its only argument.
    :type funcs: List[Callable]
    :param get_contacts_config: Config object for ``GetContacts``. Accepted
        for interface compatibility; not used inside this function.
    :type get_contacts_config: graphein.protein.config.GetContactsConfig
    :return: Graph with added edges, as returned by
        ``add_distance_to_edges``.
    :rtype: nx.Graph
    """
    metadata = G.graph
    # Only graphs that carry a config get a precomputed distance matrix.
    if "config" in metadata:
        if metadata["config"].granularity == "atom":
            matrix_key = "atomic_dist_mat"
        else:
            matrix_key = "dist_mat"
        metadata[matrix_key] = compute_distmat(metadata["pdb_df"])

    for edge_func in funcs:
        edge_func(G)

    return add_distance_to_edges(G)
606
+
607
+
608
def construct_graph(
    config: Optional[ProteinGraphConfig] = None,
    name: Optional[str] = None,
    pdb_path: Optional[str] = None,
    uniprot_id: Optional[str] = None,
    pdb_code: Optional[str] = None,
    chain_selection: str = "all",
    model_index: int = 1,
    df_processing_funcs: Optional[List[Callable]] = None,
    edge_construction_funcs: Optional[List[Callable]] = None,
    edge_annotation_funcs: Optional[List[Callable]] = None,
    node_annotation_funcs: Optional[List[Callable]] = None,
    graph_annotation_funcs: Optional[List[Callable]] = None,
) -> nx.Graph:
    """
    Constructs protein structure graph from a ``pdb_code``, ``pdb_path`` or
    ``uniprot_id``.

    Users can provide a :class:`~graphein.protein.config.ProteinGraphConfig`
    object to specify construction parameters.

    The ``*_funcs`` arguments act as fallbacks: each one is used only when
    the corresponding field of ``config`` is ``None``. The resolved values
    are written back onto ``config``, so a config object passed in by the
    caller is modified in place.

    :param config: :class:`~graphein.protein.config.ProteinGraphConfig` object. If None, a default ``ProteinGraphConfig()`` is created.
    :type config: graphein.protein.config.ProteinGraphConfig, optional
    :param name: an optional given name for the graph. the PDB ID or PDB file name will be used if not specified.
    :type name: str, optional
    :param pdb_path: Path to ``pdb_file`` when constructing a graph from a local pdb file. Default is ``None``.
    :type pdb_path: Optional[str], defaults to ``None``
    :param pdb_code: A 4-character PDB ID / accession to be used to construct the graph, if available. Default is ``None``.
    :type pdb_code: Optional[str], defaults to ``None``
    :param uniprot_id: UniProt accession ID to build graph from AlphaFold2DB. Default is ``None``.
    :type uniprot_id: str, optional
    :param chain_selection: String of polypeptide chains to include in graph. E.g ``"ABDF"`` or ``"all"``. Default is ``"all"``.
    :type chain_selection: str
    :param model_index: Index of model to use in the case of structural ensembles. Default is ``1``.
    :type model_index: int
    :param df_processing_funcs: List of dataframe processing functions. Default is ``None``.
    :type df_processing_funcs: List[Callable], optional
    :param edge_construction_funcs: List of edge construction functions. Default is ``None``.
    :type edge_construction_funcs: List[Callable], optional
    :param edge_annotation_funcs: List of edge annotation functions. Default is ``None``.
    :type edge_annotation_funcs: List[Callable], optional
    :param node_annotation_funcs: List of node annotation functions. Default is ``None``.
    :type node_annotation_funcs: List[Callable], optional
    :param graph_annotation_funcs: List of graph annotation function. Default is ``None``.
    :type graph_annotation_funcs: List[Callable]
    :raises ValueError: If none of ``pdb_code``, ``pdb_path`` and
        ``uniprot_id`` is given.
    :return: Protein Structure Graph
    :rtype: nx.Graph
    """
    if all(source is None for source in (pdb_code, pdb_path, uniprot_id)):
        raise ValueError(
            "Either a PDB ID, UniProt ID or a path to a local PDB file"
            " must be specified to construct a graph"
        )

    # If no config is provided, use default
    if config is None:
        config = ProteinGraphConfig()

    # (config field, fallback argument) pairs, applied in this order.
    fallbacks = (
        ("protein_df_processing_functions", df_processing_funcs),
        ("edge_construction_functions", edge_construction_funcs),
        ("node_metadata_functions", node_annotation_funcs),
        ("graph_metadata_functions", graph_annotation_funcs),
        ("edge_metadata_functions", edge_annotation_funcs),
    )

    with Progress(transient=True) as progress:
        read_task = progress.add_task("Reading PDB file...", total=1)
        progress.advance(read_task)

        # Fill config fields that are unset from the function arguments;
        # fields that already hold a value keep it.
        for field_name, fallback in fallbacks:
            configured = getattr(config, field_name)
            setattr(
                config,
                field_name,
                fallback if configured is None else configured,
            )

        raw_df = read_pdb_to_dataframe(
            pdb_path,
            pdb_code,
            uniprot_id,
            model_index=model_index,
        )

        process_task = progress.add_task(
            "Processing PDB dataframe...", total=1
        )
        raw_df = sort_dataframe(raw_df)
        protein_df = process_dataframe(
            raw_df,
            chain_selection=chain_selection,
            granularity=config.granularity,
            insertions=config.insertions,
            keep_hets=config.keep_hets,
        )
        progress.advance(process_task)

        init_task = progress.add_task("Initializing graph...", total=1)
        # Initialise graph with metadata, then populate its nodes
        g = initialise_graph_with_metadata(
            protein_df=protein_df,
            raw_pdb_df=raw_df,
            name=name,
            pdb_code=pdb_code,
            pdb_path=pdb_path,
            granularity=config.granularity,
        )
        g = add_nodes_to_graph(g)
        # Keep the config (and a ``path`` alias of ``pdb_path``) on the graph
        g.graph["config"] = config
        g.graph["path"] = g.graph["pdb_path"]

        # Annotate additional node metadata
        if config.node_metadata_functions is not None:
            g = annotate_node_metadata(g, config.node_metadata_functions)
        progress.advance(init_task)

        edge_task = progress.add_task("Constructing edges...", total=1)
        # Compute graph edges
        g = compute_edges(
            g,
            funcs=config.edge_construction_functions,
            get_contacts_config=None,
        )
        progress.advance(edge_task)

    # Annotate additional graph metadata
    if config.graph_metadata_functions is not None:
        g = annotate_graph_metadata(g, config.graph_metadata_functions)

    # Annotate additional edge metadata
    if config.edge_metadata_functions is not None:
        g = annotate_edge_metadata(g, config.edge_metadata_functions)

    return g
767
+
768
+
769
def _mp_graph_constructor(
    args: Tuple[str, str, int], source: str, config: ProteinGraphConfig
) -> Union[nx.Graph, None]:
    """
    Protein graph constructor for use in multiprocessing several protein structure graphs.

    :param args: Tuple of ``(identifier, chain selection, model index)``,
        where the identifier is a PDB code, a PDB path or a UniProt ID
        depending on ``source``.
    :type args: Tuple[str, str, int]
    :param source: Which keyword of ``construct_graph`` the identifier is
        passed as: ``"pdb_code"``, ``"pdb_path"`` or ``"uniprot_id"``. Any
        other value yields ``None`` without building a graph.
    :type source: str
    :param config: Protein structure graph construction config (see: :class:`graphein.protein.config.ProteinGraphConfig`).
    :type config: ProteinGraphConfig
    :return: Protein structure graph or ``None`` if an error is encountered.
    :rtype: Union[nx.Graph, None]
    """
    identifier, chain_selection, model_index = args[0], args[1], args[2]
    log.info(
        f"Constructing graph for: {identifier}. Chain selection: {chain_selection}. Model index: {model_index}"
    )
    build = partial(construct_graph, config=config)
    try:
        if source in ("pdb_code", "pdb_path", "uniprot_id"):
            # ``source`` doubles as the keyword name for the identifier.
            return build(
                chain_selection=chain_selection,
                model_index=model_index,
                **{source: identifier},
            )
    except Exception as ex:
        # Best effort: a failure for one structure must not abort the whole
        # batch, so it is logged and reported to the caller as ``None``.
        log.info(
            f"Graph construction error (PDB={identifier})! {traceback.format_exc()}"
        )
        log.info(ex)
        return None
    return None
810
+
811
+
812
def construct_graphs_mp(
    pdb_code_it: Optional[List[str]] = None,
    pdb_path_it: Optional[List[str]] = None,
    uniprot_id_it: Optional[List[str]] = None,
    chain_selections: Optional[List[str]] = None,
    model_indices: Optional[List[int]] = None,
    config: ProteinGraphConfig = ProteinGraphConfig(),
    num_cores: int = 16,
    return_dict: bool = True,
    out_path: Optional[str] = None,
) -> Union[List[nx.Graph], Dict[str, nx.Graph]]:
    """
    Constructs protein graphs for a list of pdb codes, pdb paths or UniProt
    IDs using multiprocessing.

    If more than one of the three iterables is supplied, the last one in the
    order ``pdb_code_it``, ``pdb_path_it``, ``uniprot_id_it`` is used.

    :param pdb_code_it: List of pdb codes to use for protein graph construction
    :type pdb_code_it: Optional[List[str]], defaults to ``None``
    :param pdb_path_it: List of paths to PDB files to use for protein graph construction
    :type pdb_path_it: Optional[List[str]], defaults to ``None``
    :param uniprot_id_it: List of UniProt IDs to use for protein graph construction
    :type uniprot_id_it: Optional[List[str]], defaults to ``None``
    :param chain_selections: List of chains to select from the protein structures (e.g. ``["ABC", "A", "L", "CD"...]``).
        Defaults to ``"all"`` for every structure.
    :type chain_selections: Optional[List[str]], defaults to ``None``
    :param model_indices: List of model indices to use for protein graph construction. Only relevant for structures containing ensembles of models.
        Defaults to ``1`` for every structure.
    :type model_indices: Optional[List[int]], defaults to ``None``
    :param config: ProteinGraphConfig to use.
    :type config: graphein.protein.config.ProteinGraphConfig, defaults to default config params
    :param num_cores: Number of cores to use for multiprocessing. The more the merrier
    :type num_cores: int, defaults to ``16``
    :param return_dict: Whether or not to return a dictionary (indexed by pdb codes/paths) or a list of graphs.
    :type return_dict: bool, default to ``True``
    :param out_path: Directory to pickle the graphs into, one
        ``<graph name>.pickle`` file each. If None, graphs are not saved.
        Graphs that failed to build (``None``) are skipped.
    :type out_path: Optional[str], defaults to ``None``
    :return: Iterable of protein graphs. None values indicate there was a problem in constructing the graph for this particular pdb
    :rtype: Union[List[nx.Graph], Dict[str, nx.Graph]]
    """
    # A UniProt-only call is valid input too, so it is part of the check.
    assert (
        pdb_code_it is not None
        or pdb_path_it is not None
        or uniprot_id_it is not None
    ), "Iterable of pdb codes, pdb paths or uniprot IDs required."

    # Later sources take precedence over earlier ones.
    if pdb_code_it is not None:
        pdbs = pdb_code_it
        source = "pdb_code"

    if pdb_path_it is not None:
        pdbs = pdb_path_it
        source = "pdb_path"

    if uniprot_id_it is not None:
        pdbs = uniprot_id_it
        source = "uniprot_id"

    if chain_selections is None:
        chain_selections = ["all"] * len(pdbs)

    if model_indices is None:
        model_indices = [1] * len(pdbs)

    constructor = partial(_mp_graph_constructor, source=source, config=config)

    graphs = list(
        process_map(
            constructor,
            [
                (pdb, chain_selections[i], model_indices[i])
                for i, pdb in enumerate(pdbs)
            ],
            max_workers=num_cores,
        )
    )

    if out_path is not None:
        # Local import: the standard pickle module replaces
        # ``nx.write_gpickle``, which is gone from networkx 3.x.
        import pickle

        for g in graphs:
            # Failed constructions come back as None (already logged by the
            # worker) and have nothing to save.
            if g is None:
                continue
            with open(f"{out_path}/{g.graph['name']}.pickle", "wb") as fh:
                pickle.dump(g, fh, pickle.HIGHEST_PROTOCOL)

    if return_dict:
        graphs = {pdb: graphs[i] for i, pdb in enumerate(pdbs)}

    return graphs
891
+
892
+
893
def compute_chain_graph(
    g: nx.Graph,
    chain_list: Optional[List[str]] = None,
    remove_self_loops: bool = False,
    return_weighted_graph: bool = False,
) -> Union[nx.Graph, nx.MultiGraph]:
    """Computes a chain-level graph from a protein structure graph.

    Nodes of the result are the chains of the complex; every edge of the
    input graph becomes an edge between the chains of its two endpoints.
    With ``return_weighted_graph=False`` the result is a multigraph holding
    one edge per interaction (so it has as many edges as the input graph,
    unless self-loops are removed). With ``return_weighted_graph=True`` the
    parallel edges are collapsed into a single weighted edge per chain pair
    by ``compute_weighted_graph_from_multigraph``.

    Note: the result shares its ``graph`` metadata dictionary with the input
    graph, so setting ``node_type`` to ``"chain"`` is visible on both.

    :param g: A protein structure graph to compute the chain graph of.
    :type g: nx.Graph
    :param chain_list: A list of chains to extract from the input graph.
        If ``None``, all chains will be used. This is provided as input to
        ``extract_subgraph_from_chains``. Default is ``None``.
    :type chain_list: Optional[List[str]]
    :param remove_self_loops: Whether to remove self-loops from the graph.
        Default is False.
    :type remove_self_loops: bool
    :param return_weighted_graph: Whether to return a weighted graph instead
        of a multigraph. Default is False.
    :type return_weighted_graph: bool
    :return: A chain-level graph.
    :rtype: Union[nx.Graph, nx.MultiGraph]
    """
    # If we are extracting specific chains, do it here.
    if chain_list is not None:
        g = extract_subgraph_from_chains(g, chain_list)

    # Initialise new graph with Metadata
    h = nx.MultiGraph()
    h.graph = g.graph
    h.graph["node_type"] = "chain"

    # Tally residues and build the one-letter sequence of every chain
    chain_ids = g.graph["chain_ids"]
    residue_totals = {chain: 0 for chain in chain_ids}
    chain_sequences = {chain: "" for chain in chain_ids}
    for _, node_data in g.nodes(data=True):
        chain = node_data["chain_id"]
        residue_totals[chain] += 1
        chain_sequences[chain] += RESI_THREE_TO_1[node_data["residue_name"]]

    # One node per chain, annotated with its size and sequence
    h.add_nodes_from(chain_ids)
    for chain, chain_data in h.nodes(data=True):
        chain_data["num_residues"] = residue_totals[chain]
        chain_data["sequence"] = chain_sequences[chain]

    # Lift every residue-level edge to an edge between the owning chains
    for src, dst, edge_data in g.edges(data=True):
        h.add_edge(
            g.nodes[src]["chain_id"],
            g.nodes[dst]["chain_id"],
            kind=edge_data["kind"],
        )

    # Remove self-loops if necessary (edges whose two endpoints are equal)
    if remove_self_loops:
        self_loops: List[Tuple[str]] = [
            (a, b) for a, b in h.edges() if a == b
        ]
        h.remove_edges_from(self_loops)

    # Compute a weighted graph if required.
    return compute_weighted_graph_from_multigraph(h) if return_weighted_graph else h
961
+
962
+
963
def compute_weighted_graph_from_multigraph(g: nx.MultiGraph) -> nx.Graph:
    """Computes a weighted graph from a multigraph.

    Parallel edges between the same pair of nodes are merged into one edge
    whose attributes are:

    * ``weight``: sum of ``len(kind)`` over the merged edges,
    * ``kind``: union of the merged edges' ``kind`` collections,
    * one counter per kind name: number of merged edges carrying that kind.

    The result shares its ``graph`` metadata dictionary with ``g``. The
    ``kind`` collections of the input edges are only read, never modified.

    :param g: A multigraph whose edges carry a ``kind`` collection.
    :type g: nx.MultiGraph
    :return: A weighted graph.
    :rtype: nx.Graph
    """
    H = nx.Graph()
    H.graph = g.graph
    H.add_nodes_from(g.nodes(data=True))
    for u, v, d in g.edges(data=True):
        kinds = d["kind"]
        if not H.has_edge(u, v):
            # Start from a fresh set: storing ``kinds`` itself would make the
            # later ``update`` calls mutate the source graph's edge data.
            H.add_edge(u, v, weight=0, kind=set())
        merged = H[u][v]
        merged["weight"] += len(kinds)
        merged["kind"].update(kinds)
        for kind in list(kinds):
            merged[kind] = merged.get(kind, 0) + 1
    return H
991
+
992
+
993
def number_groups_of_runs(list_of_values: List[Any]) -> List[str]:
    """Label each element with its value and the ordinal of its run.

    A *run* is a maximal stretch of consecutive equal values; runs are
    counted separately for every distinct value.

    E.g. ``["A", "A", "B", "A", "A", "A", "B", "B"] ->
    ["A1", "A1", "B1", "A2", "A2", "A2", "B2", "B2"]``

    :param list_of_values: List of values to number.
    :type list_of_values: List[Any]
    :return: List of numbered values.
    :rtype: List[str]
    """
    runs = pd.DataFrame({"val": list_of_values})
    # True wherever a value differs from its predecessor, i.e. a run starts.
    runs["idx"] = runs["val"].shift() != runs["val"]
    # Counting the run starts seen so far, per value, gives the run ordinal.
    runs["sum"] = runs.groupby("val")["idx"].cumsum()
    labels = runs["val"].astype(str) + runs["sum"].astype(str)
    return list(labels)
1008
+
1009
+
1010
def compute_secondary_structure_graph(
    g: nx.Graph,
    allowable_ss_elements: Optional[List[str]] = None,
    remove_non_ss: bool = True,
    remove_self_loops: bool = False,
    return_weighted_graph: bool = False,
) -> Union[nx.Graph, nx.MultiGraph]:
    """Computes a secondary structure graph from a protein structure graph.

    Each node of the result is one numbered secondary structure element
    (e.g. ``"H1"``, ``"E2"``, as produced by ``number_groups_of_runs``), and
    each edge of the input graph whose two endpoints both survive filtering
    becomes an edge between their elements.

    Note: the result shares its ``graph`` metadata dictionary with the input
    graph, so setting ``node_type`` to ``"secondary_structure"`` is visible
    on both.

    :param g: A protein structure graph to compute the secondary structure
        graph of. Every node must carry an ``ss`` attribute.
    :type g: nx.Graph
    :param allowable_ss_elements: If given, only elements whose label matches
        one of these strings are kept. The strings are joined with ``|`` and
        applied via ``Series.str.contains``. Default is ``None``.
    :type allowable_ss_elements: Optional[List[str]]
    :param remove_non_ss: Whether to remove non-secondary structure nodes from
        the graph. These are denoted as ``"-"`` by DSSP. Default is True.
    :type remove_non_ss: bool
    :param remove_self_loops: Whether to remove self-loops from the graph.
        Default is ``False``.
    :type remove_self_loops: bool
    :param return_weighted_graph: Whether to return a weighted graph.
        Default is False.
    :type return_weighted_graph: bool
    :raises ProteinGraphConfigurationError: If the protein structure graph is
        not configured correctly with secondary structure assignments on all
        nodes.
    :return: A secondary structure graph.
    :rtype: Union[nx.Graph, nx.MultiGraph]
    """
    # Collect the per-residue assignment, insisting that every node has one
    residue_ss: List[str] = []
    for _, node_data in g.nodes(data=True):
        if "ss" not in node_data.keys():
            raise ProteinGraphConfigurationError(
                "Secondary structure not defined for all nodes."
            )
        residue_ss.append(node_data["ss"])

    # Number SS elements; the Series maps residue node id -> element label
    ss_list = pd.Series(number_groups_of_runs(residue_ss))
    ss_list.index = list(g.nodes())

    # Remove unstructured elements if necessary
    if remove_non_ss:
        is_unstructured = ss_list.str.contains("-")
        ss_list = ss_list[~is_unstructured]
    # Subset to only allowable SS elements if necessary
    if allowable_ss_elements:
        allowed_pattern = "|".join(allowable_ss_elements)
        ss_list = ss_list[ss_list.str.contains(allowed_pattern)]

    # Element label -> residues it contains / how many residues it contains
    grouped_residues = ss_list.index.groupby(ss_list.values)
    constituent_residues: Dict[str, List[str]] = {
        element: list(residues)
        for element, residues in grouped_residues.items()
    }
    residue_counts: Dict[str, int] = ss_list.groupby(ss_list).count().to_dict()

    # Add Nodes from secondary structure list
    h = nx.MultiGraph()
    h.add_nodes_from(ss_list)
    nx.set_node_attributes(h, residue_counts, "residue_counts")
    nx.set_node_attributes(h, constituent_residues, "constituent_residues")
    # The first character of the label is the element's SS code
    for element, element_data in h.nodes(data=True):
        element_data["ss"] = element[0]

    # Add graph-level metadata
    h.graph = g.graph
    h.graph["node_type"] = "secondary_structure"

    # Iterate over edges in source graph and add SS-SS edges to new graph.
    for u, v, d in g.edges(data=True):
        try:
            h.add_edge(
                ss_list[u], ss_list[v], kind=d["kind"], source=f"{u}_{v}"
            )
        except KeyError as e:
            # An endpoint was filtered out above, so the edge has no
            # counterpart in the SS graph.
            log.debug(
                f"Edge {u}-{v} not added to secondary structure graph. \
                Reason: {e} not in graph"
            )

    # Remove self-loops if necessary.
    # Checks for equality between nodes in a given edge.
    if remove_self_loops:
        self_loops: List[Tuple[str]] = [
            (a, b) for a, b in h.edges() if a == b
        ]
        h.remove_edges_from(self_loops)

    # Create weighted graph from h
    return compute_weighted_graph_from_multigraph(h) if return_weighted_graph else h
1106
+
1107
+
1108
def compute_line_graph(g: nx.Graph, repopulate_data: bool = True) -> nx.Graph:
    """Computes the line graph of a graph.

    The line graph of a graph G has a node for each edge in G and an edge
    joining those nodes if the two edges in G share a common node. For directed
    graphs, nodes are adjacent exactly when the edges they represent form a
    directed path of length two.

    The nodes of the line graph are 2-tuples of nodes in the original graph (or
    3-tuples for multigraphs, with the key of the edge as the third element).

    The result shares its ``graph`` metadata dictionary with ``g``.

    :param g: Graph to compute the line graph of.
    :type g: nx.Graph
    :param repopulate_data: Whether or not to map node and edge data to edges
        and nodes of the line graph, defaults to True. When set, each line
        graph node receives the attributes of the edge of ``g`` it stands
        for, and each line graph edge receives the attributes of a node of
        ``g`` that its two endpoints have in common.
    :type repopulate_data: bool, optional
    :return: Line graph of g.
    :rtype: nx.Graph
    """
    line_graph = nx.generators.line_graph(g)
    line_graph.graph = g.graph

    if not repopulate_data:
        return line_graph

    # Edge data of g -> node data of the line graph
    edge_data_by_pair = {(u, v): d for u, v, d in g.edges(data=True)}
    nx.set_node_attributes(line_graph, edge_data_by_pair)

    # For every line graph edge, find the first element that occurs more
    # than once in the concatenation of its two endpoint tuples.
    shared_node = {}
    for first, second, _ in line_graph.edges(data=True):
        members = first + second
        for candidate in members:
            if members.count(candidate) > 1:
                shared_node[(first, second)] = candidate
                break

    # Node data of g -> edge data of the line graph
    node_data_by_edge = {
        edge: g.nodes[node] for edge, node in shared_node.items()
    }
    nx.set_edge_attributes(line_graph, node_data_by_edge)
    return line_graph
modeling_prot2text.py ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import GPT2Config, AutoTokenizer, GPT2Config
2
+ from transformers import PretrainedConfig, PreTrainedModel
3
+ import transformers
4
+ from typing import Optional, Tuple, Callable
5
+ import torch
6
+ import torch.nn as nn
7
+ from transformers.modeling_utils import PreTrainedModel, PretrainedConfig
8
+ from .utils import CABlock, _GPT2LMHeadModel
9
+ from .configuration_prot2text import Prot2TextConfig
10
+ import os
11
+ import numpy as np
12
+ from transformers.generation.configuration_utils import GenerationConfig
13
+ from transformers.generation.logits_process import LogitsProcessorList
14
+ from transformers.generation.stopping_criteria import StoppingCriteriaList
15
+
16
+ from .pdb2graph import PDB2Graph, download_alphafold_structure
17
+ from .graphs import *
18
+ from .utils_dataset import *
19
+
20
+ try:
21
+ from graphein.protein.config import ProteinGraphConfig, DSSPConfig
22
+ from graphein.protein.features.nodes.amino_acid import amino_acid_one_hot, meiler_embedding, expasy_protein_scale, hydrogen_bond_acceptor, hydrogen_bond_donor
23
+ from graphein.protein.features.nodes.dssp import phi, psi, asa, rsa, secondary_structure
24
+ from graphein.protein.edges.distance import (add_peptide_bonds,
25
+ add_hydrogen_bond_interactions,
26
+ add_distance_threshold,
27
+ )
28
+ except ImportError:
29
+ raise Exception('You need to install graphein from source in addition to DSSP to use this model please refer to https://github.com/a-r-j/graphein and https://ssbio.readthedocs.io/en/latest/instructions/dssp.html')
30
+
31
+ try:
32
+ from torch_geometric.nn import RGCNConv, global_mean_pool
33
+ except ImportError:
34
+ raise Exception('You need to install torch geometric and its dependecies to use this model please refer to https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html')
35
+
36
+
37
+
38
class EncoderRGCN(PreTrainedModel):
    '''
    Relational GCN encoder that turns a protein structure graph into a single
    embedding vector per graph.

    Node features are projected to ``hidden_dim``, passed through a stack of
    ``RGCNConv`` layers, mean-pooled per graph and finally mapped to
    ``emb_dim`` by two linear layers.
    '''
    def __init__(self, input_dim, hidden_dim=512, n_layers=6, emb_dim=512, dropout=0.2, num_relation=7, prot2text_version='1.0'):
        super().__init__(PretrainedConfig(name='RGCN'))
        self.n_layers = n_layers
        self.output_dim = emb_dim
        self.prot2text_version = prot2text_version

        # Input projection.
        self.fc0 = nn.Linear(input_dim, hidden_dim)
        # The normalisation modules below are registered but not called in
        # forward(); they are kept so the module's parameter layout is unchanged.
        self.batchnorm_final = nn.BatchNorm1d(hidden_dim)
        self.batch_norms = nn.ModuleList([nn.BatchNorm1d(hidden_dim)])

        # One convolution is always created, plus (n_layers - 1) more.
        n_convs = 1 + max(n_layers - 1, 0)
        self.conv = nn.ModuleList(
            [RGCNConv(hidden_dim, hidden_dim, num_relations=num_relation) for _ in range(n_convs)]
        )

        # Output head applied to the pooled graph representation.
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, self.output_dim)

        self.dropout = nn.Dropout(p=dropout)
        self.relu = nn.LeakyReLU()
        self.batchnorm = nn.BatchNorm1d(hidden_dim)
        self.main_input_name = 'nothing'

    def forward(self, x:Optional[torch.FloatTensor] = None,
                edge_index:Optional[torch.LongTensor] = None,
                edge_type:Optional[torch.LongTensor] = None,
                batch:Optional[torch.LongTensor] = None,
                **kargs):
        '''
        Encode the graph(s) described by ``x``, ``edge_index`` and
        ``edge_type``; ``batch`` is handed to ``global_mean_pool`` to pool
        node states per graph. Extra keyword arguments are ignored.

        Returns the pooled embedding with an extra dimension inserted at
        position 1.
        '''
        hidden = self.relu(self.fc0(x))

        # The convolutions are applied back to back, exactly n_layers of them.
        for layer_idx in range(self.n_layers):
            rgcn_layer = self.conv[layer_idx]
            hidden = rgcn_layer(hidden, edge_index, edge_type)

        pooled = global_mean_pool(hidden, batch)
        pooled = self.relu(self.fc1(pooled))
        pooled = self.relu(self.fc2(pooled))

        return pooled.unsqueeze(1)
86
+
87
class Prot2TextModel(PreTrainedModel):
    """Protein-to-text model: structure/sequence encoders feeding a GPT-2 decoder.

    Depending on the configuration the encoder side is an RGCN over the
    protein graph (``config.rgcn``), an ESM model over the amino-acid
    sequence (``config.esm``), or both. When both are enabled and
    ``config.cross_esm_graph`` is set, the ESM token states attend to the
    graph embedding through a stack of ``CABlock`` layers. The resulting
    states are passed to the decoder as ``encoder_hidden_states``.
    """
    config_class = Prot2TextConfig
    _keys_to_ignore_on_load_missing = [r"transformer"]
    base_model_prefix = "decoder"

    def __init__(self, config):
        super().__init__(config)

        self.gpt_config = GPT2Config.from_dict(config.gpt_config)

        # if we are using RGCN to encode the protein's structure, define the RGCN encoder
        if config.rgcn:
            # NOTE(review): the config constructor parameter is spelled
            # `rgc_input_dim`; confirm it is stored as `rgcn_input_dim`.
            self.encoder = EncoderRGCN(input_dim=config.rgcn_input_dim, hidden_dim=self.gpt_config.n_embd, n_layers=config.rgcn_n_layers, emb_dim=self.gpt_config.n_embd, prot2text_version=self.config.prot2text_version)

        # define the GPT2 decoder
        self.decoder = _GPT2LMHeadModel(self.gpt_config)

        # if using ESM to encode protein's sequence, define the ESM layer, the Projection layer and the fusion layer
        if config.esm:
            self.esm_config = PretrainedConfig.from_dict(config.esm_config)
            self.esm = transformers.EsmModel(self.esm_config)
            self.to_embedding = nn.Linear(self.esm_config.hidden_size, self.gpt_config.n_embd)
            if config.cross_esm_graph and config.rgcn:
                self.h = nn.ModuleList([CABlock(self.gpt_config, layer_idx=i) for i in range(4)])
                self.ln_f = nn.LayerNorm(self.gpt_config.n_embd, eps=self.gpt_config.layer_norm_epsilon)

        self.config = config

    def get_encoder(self):
        """Return the RGCN encoder (only defined when ``config.rgcn`` is set)."""
        return self.encoder

    def get_decoder(self):
        """Return the GPT-2 decoder."""
        return self.decoder

    def get_input_embeddings(self):
        """Return the decoder's token embedding layer."""
        if hasattr(self, "transformer"):
            return self.transformer.wte
        return self.decoder.transformer.wte

    def warm_up(self, gpt_model=None, esm_model=None):
        """Replace the ESM and/or decoder sub-models with pretrained checkpoints.

        :param gpt_model: name or path given to ``_GPT2LMHeadModel.from_pretrained``.
        :param esm_model: name or path given to ``EsmModel.from_pretrained``.
        """
        if esm_model is not None:
            self.esm = transformers.EsmModel.from_pretrained(esm_model)
        if gpt_model is not None:
            self.decoder = _GPT2LMHeadModel.from_pretrained(gpt_model, add_cross_attention=True, use_cache=False)
            self.decoder.resize_token_embeddings(self.gpt_config.vocab_size)
            self.decoder.config = self.gpt_config

    def forward(self,
                encoder_input_ids: Optional[torch.LongTensor] = None,
                edge_index: Optional[torch.LongTensor] = None,
                batch: Optional[torch.LongTensor] = None,
                x: Optional[torch.FloatTensor] = None,
                edge_type: Optional[torch.LongTensor] = None,
                decoder_input_ids: Optional[torch.LongTensor] = None,
                past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
                past_key_values_graph_esm: Optional[Tuple[Tuple[torch.Tensor]]] = None,
                decoder_attention_mask: Optional[torch.FloatTensor] = None,
                attention_mask: Optional[torch.FloatTensor] = None,
                token_type_ids: Optional[torch.LongTensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                head_mask: Optional[torch.FloatTensor] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                encoder_hidden_states: Optional[torch.Tensor] = None,
                encoder_attention_mask: Optional[torch.FloatTensor] = None,
                labels: Optional[torch.LongTensor] = None,
                use_cache: Optional[bool] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                get_graph_emb: Optional[bool] = False,
                **delete_args,
                ):
        """Encode the protein and run the decoder.

        ``x``/``edge_index``/``edge_type``/``batch`` describe the protein
        graph, ``encoder_input_ids``/``attention_mask`` the tokenized
        amino-acid sequence, and the ``decoder_*`` arguments the text side.
        With ``get_graph_emb=True`` the encoder states are returned directly
        and the decoder is not run. Unknown keyword arguments are ignored.
        """
        use_cache = use_cache if use_cache is not None else self.gpt_config.use_cache
        return_dict = return_dict if return_dict is not None else self.gpt_config.use_return_dict

        if decoder_input_ids is not None and len(decoder_input_ids.size()) == 3:
            decoder_input_ids = decoder_input_ids.squeeze(0)

        if x is not None and self.config.rgcn:
            graph_emb = self.encoder(x, edge_index, edge_type, batch)
            graph_mask = None

        if self.config.esm:
            if self.config.prot2text_version=='1.0':
                if encoder_input_ids.size()[1] != 1021:
                    raise ValueError("For this version of the model you need to PAD/Truncate the amino acid sequence for the ESM model to 1021")

            # Always request a dict-style output here: `.last_hidden_state`
            # does not exist on the tuple returned when return_dict is False.
            esm_emb = self.esm(input_ids=encoder_input_ids, attention_mask=attention_mask, return_dict=True).last_hidden_state
            esm_emb = self.to_embedding(esm_emb)
            if not self.config.cross_esm_graph and self.config.rgcn:
                # Prepend the graph embedding as one extra "token" and extend
                # the mask accordingly. Use `.device` rather than
                # `.get_device()`, which is -1 for CPU tensors and cannot be
                # passed to `.to()`.
                graph_emb = torch.cat((graph_emb, esm_emb), dim=1)
                t_add = torch.ones((attention_mask.size(0), 1), device=attention_mask.device)
                attention_mask = torch.cat((t_add, attention_mask), dim=1)
            elif self.config.cross_esm_graph and self.config.rgcn:
                if past_key_values_graph_esm is None:
                    past_key_values_graph_esm = tuple([None] * len(self.h))
                output_shape = esm_emb.size()

                # Let the sequence states cross-attend to the graph embedding.
                for block, layer_past in zip(self.h, past_key_values_graph_esm):
                    outputs = block(
                        esm_emb,
                        layer_past=layer_past,
                        attention_mask=attention_mask,
                        encoder_hidden_states=graph_emb,
                        encoder_attention_mask=graph_mask,
                        use_cache=use_cache,
                        output_attentions=False,
                    )
                    esm_emb = outputs[0]

                esm_emb = self.ln_f(esm_emb)
                esm_emb = esm_emb.view(output_shape)
                graph_emb = esm_emb
            else:
                graph_emb = esm_emb
        else:
            attention_mask = None
        if self.config.prot2text_version=='1.0':
            attention_mask = None
        if get_graph_emb:
            return graph_emb

        transformer_outputs = self.decoder(input_ids=decoder_input_ids,
                                           past_key_values=past_key_values,
                                           attention_mask=decoder_attention_mask,
                                           token_type_ids=token_type_ids,
                                           position_ids=position_ids,
                                           head_mask=head_mask,
                                           inputs_embeds=inputs_embeds,
                                           encoder_hidden_states=graph_emb,
                                           encoder_attention_mask=attention_mask,
                                           labels=labels,
                                           use_cache=use_cache,
                                           output_attentions=output_attentions,
                                           output_hidden_states=output_hidden_states,
                                           return_dict=return_dict,
                                           )

        return transformer_outputs

    @torch.no_grad()
    def generate_protein_description(self,
                                     protein_pdbID=None,
                                     protein_sequence=None,
                                     edge_index: Optional[torch.LongTensor] = None,
                                     x: Optional[torch.FloatTensor] = None,
                                     edge_type: Optional[torch.LongTensor] = None,
                                     tokenizer=None,
                                     device='cpu'
                                     ):
        """Generate a textual description for one protein.

        :param protein_pdbID: AlphaFold/UniProt identifier; when given, the
            structure is downloaded and converted to a graph.
        :param protein_sequence: amino-acid sequence, used when no identifier
            is given.
        :param edge_index, x, edge_type: pre-computed graph inputs; only used
            for input validation in this method.
        :param tokenizer: decoder tokenizer (provides ``bos_token_id`` and
            ``batch_decode``).
        :param device: device on which the model and inputs are placed.
        :return: the decoded description string.
        :raises ValueError: on missing inputs, a failed download, or a failed
            graph construction.
        """
        if self.config.esm and not self.config.rgcn and protein_sequence is None:
            raise ValueError(
                "The model you are trying to use is based only on protein sequence, please provide an amino-acid protein_sequence"
            )
        if self.config.rgcn and protein_pdbID is None and (x is None or edge_index is None or edge_type is None):
            raise ValueError(
                "The model you are trying to use is based on protein structure, please provide a AlphaFold ID (you must have to have internet connection using protein_pdbID, or provide the triplet inputs: x (node features), edge_index and edge_type"
            )
        if self.config.esm:
            esmtokenizer = AutoTokenizer.from_pretrained(self.config.esm_model_name)

        if protein_pdbID is None and protein_sequence is None:
            raise ValueError(
                "you need to provide either a protein AlphaFold Id or an amino-acid sequence"
            )
        if tokenizer is None:
            # Fail before any download/graph work instead of with an
            # AttributeError on `tokenizer.bos_token_id` later on.
            raise ValueError("you need to provide the decoder tokenizer")

        if protein_pdbID is not None:
            # Imported locally so the method does not rely on `partial`
            # leaking in through a star import.
            from functools import partial

            config = {"node_metadata_functions": [amino_acid_one_hot,
                                                  expasy_protein_scale,
                                                  meiler_embedding,
                                                  hydrogen_bond_acceptor, hydrogen_bond_donor
                                                  ],
                      "edge_construction_functions": [add_peptide_bonds,
                                                      add_hydrogen_bond_interactions,
                                                      partial(add_distance_threshold, long_interaction_threshold=3, threshold=10.),],
                      "graph_metadata_functions": [asa, phi, psi, secondary_structure, rsa],
                      "dssp_config": DSSPConfig()}
            config = ProteinGraphConfig(**config)

            # Expand "~" explicitly: os.makedirs/os.path do not, and would
            # otherwise create a literal "~" directory in the working dir.
            cache_root = os.path.expanduser("~/.tmp/pdb")
            PATH_TO_DATA = os.path.join(cache_root, "pdb")
            OUTPUT_FOLDER = os.path.join(cache_root, "raw")
            processed_dir = os.path.join(cache_root, "processed")
            for folder in (PATH_TO_DATA, OUTPUT_FOLDER, processed_dir):
                os.makedirs(folder, exist_ok=True)

            structure_filename = download_alphafold_structure(uniprot_id=protein_pdbID, out_dir=PATH_TO_DATA)
            if structure_filename is None:
                raise ValueError("Error! the ID does not exist in AlphaFoldDB or you do not have internet connection")
            pt_name = os.path.basename(structure_filename).replace('.pdb', '.pt')
            graph_filename = os.path.join(OUTPUT_FOLDER, pt_name)
            process_filename = os.path.join(processed_dir, pt_name)

            try:
                try:
                    gpdb = PDB2Graph(root = PATH_TO_DATA, output_folder = OUTPUT_FOLDER, config=config, n_processors=1).create_pyg_graph(structure_filename)
                    seq = esmtokenizer(gpdb.sequence, add_special_tokens=True, truncation=True, max_length=1021, padding='max_length',return_tensors="pt")
                    torch.save(gpdb, graph_filename)
                    gpdb.edge_type = [np.array(gpdb.edge_type.transpose(0,1))]
                    gpdb.encoder_input_ids = seq['input_ids']
                    gpdb.attention_mask = seq['attention_mask']
                    torch.save(gpdb, process_filename)
                except Exception as err:
                    # Keep the original message but chain the real cause so
                    # it is not lost.
                    raise ValueError('creating graphs did not work, probably the pdb file of alphaFold is damaged') from err

                self.eval()
                inputs = gpdb.to_dict()

                inputs['edge_type'] = torch.cat([torch.tensor(et) for et in inputs['edge_type']], dim=0)
                # One-hot edge kinds -> integer relation ids for the RGCN.
                inputs['edge_type'] = torch.argmax(inputs['edge_type'], dim=1)
                for key in ['num_nodes', 'node_id', 'name', 'sequence', 'distance_matrix', 'distance', 'coordinates']:
                    inputs.pop(key, None)
                # Start decoding from a single BOS token per sequence.
                inputs['decoder_input_ids'] = inputs['encoder_input_ids'][:,0:1].clone()
                inputs['decoder_input_ids'][:,0] = tokenizer.bos_token_id
                inputs["decoder_attention_mask"] = torch.ones(inputs['decoder_input_ids'].shape[0], 1)
                self.to(device)
                inputs = {k: v.to(device=device, non_blocking=True) if hasattr(v, 'to') else v for k, v in inputs.items()}
                encoder_state = dict()
                encoder_state['hidden_states'] = self(**inputs, get_graph_emb=True, output_attentions=True)
                encoder_state['attentions'] = inputs['attention_mask']
                for key in ['edge_index', 'edge_type', 'x', 'encoder_input_ids']:
                    inputs.pop(key, None)
                tok_ids = self.decoder.generate(input_ids=inputs['decoder_input_ids'],
                                                encoder_outputs=encoder_state,
                                                use_cache=True,
                                                output_attentions=False,
                                                output_scores=False,
                                                return_dict_in_generate=True,
                                                encoder_attention_mask=inputs['attention_mask'],
                                                length_penalty=1.0,
                                                no_repeat_ngram_size=None,
                                                early_stopping=False,
                                                num_beams=1)

                generated = tokenizer.batch_decode(tok_ids.get('sequences'), skip_special_tokens=True)
            finally:
                # Remove the temporary files whether or not generation succeeded.
                for path in (structure_filename, graph_filename, process_filename):
                    if os.path.exists(path):
                        os.remove(path)

            return generated[0].replace('<|stop_token|>', '').replace('<|graph_token|>', '')

        else:
            # Same mode as the structure branch: inference, no dropout.
            self.eval()
            seq = esmtokenizer([protein_sequence], add_special_tokens=True, truncation=True, max_length=1021, padding='max_length', return_tensors="pt")
            inputs={}
            inputs['encoder_input_ids'] = seq['input_ids']
            inputs['attention_mask'] = seq['attention_mask']
            inputs['decoder_input_ids'] = inputs['encoder_input_ids'][:,0:1].clone()
            inputs['decoder_input_ids'][:,0] = tokenizer.bos_token_id

            self.to(device)
            inputs = {k: v.to(device=device, non_blocking=True) if hasattr(v, 'to') else v for k, v in inputs.items()}
            encoder_state = dict()
            encoder_state['hidden_states'] = self(**inputs, get_graph_emb=True, output_attentions=True)
            generated = tokenizer.batch_decode(self.decoder.generate(input_ids=inputs['decoder_input_ids'], encoder_outputs=encoder_state, use_cache=True), skip_special_tokens=True)

            return generated[0].replace('<|stop_token|>', '').replace('<|graph_token|>', '')

    @torch.no_grad()
    def generate(self,
                 inputs: Optional[torch.Tensor] = None,
                 generation_config: Optional[GenerationConfig] = None,
                 logits_processor: Optional[LogitsProcessorList] = None,
                 stopping_criteria: Optional[StoppingCriteriaList] = None,
                 prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
                 synced_gpus: Optional[bool] = None,
                 assistant_model: Optional["PreTrainedModel"] = None,
                 streamer: Optional["BaseStreamer"] = None,
                 **kwargs,
                 ):
        """Encode the protein once, then delegate generation to the decoder.

        ``kwargs`` must contain the forward() inputs, including
        ``decoder_input_ids`` and ``attention_mask``. Encoder-only and
        dataset bookkeeping keys are removed before the remaining keyword
        arguments are forwarded to ``self.decoder.generate``.
        """
        encoder_state = self(**kwargs, get_graph_emb=True)
        input_ids = kwargs['decoder_input_ids']
        kwargs['encoder_attention_mask'] = kwargs['attention_mask']
        if not self.config.cross_esm_graph and self.config.rgcn and self.config.esm:
            # Account for the graph embedding prepended to the ESM states.
            # `.device` also works for CPU tensors, unlike `.get_device()`.
            enc_mask = kwargs['encoder_attention_mask']
            t_add = torch.ones((enc_mask.size(0), 1), device=enc_mask.device)
            kwargs['encoder_attention_mask'] = torch.cat((t_add, enc_mask), dim=1)
        for key in ['edge_index', 'edge_type', 'x', 'encoder_input_ids', 'decoder_input_ids', 'decoder_attention_mask', 'batch', 'attention_mask', 'max_length',
                    '_num_nodes', 'node_id', 'name', 'sequence', 'distance_matrix', 'distance', 'coordinates', 'ptr', 'num_nodes',]:
            kwargs.pop(key, None)
        return self.decoder.generate(input_ids=input_ids,
                                     generation_config=generation_config,
                                     logits_processor=logits_processor,
                                     stopping_criteria=stopping_criteria,
                                     prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                                     synced_gpus=synced_gpus,
                                     assistant_model=assistant_model,
                                     streamer=streamer,
                                     encoder_outputs={'hidden_states': encoder_state, 'attentions':0},
                                     **kwargs
                                     )
pdb2graph.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import multiprocessing
2
+ import os
3
+ from tqdm import tqdm
4
+ from sklearn.preprocessing import MultiLabelBinarizer
5
+
6
+ try:
7
+ from torch_geometric.data import Data
8
+ except ImportError:
9
+ raise Exception('You need to install torch geometric and its dependecies to use this model please refer to https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html')
10
+ import torch
11
+
12
+ import numpy as np
13
+
14
+ from .conversion import convert_nx_to_pyg_data
15
+
16
+ try:
17
+ from graphein.protein.config import ProteinGraphConfig, DSSPConfig
18
+ from graphein.protein.features.nodes.amino_acid import amino_acid_one_hot, meiler_embedding, expasy_protein_scale, hydrogen_bond_acceptor, hydrogen_bond_donor
19
+ from graphein.protein.features.nodes.dssp import phi, psi, asa, rsa, secondary_structure
20
+ from graphein.protein.edges.distance import (add_peptide_bonds,
21
+ add_hydrogen_bond_interactions,
22
+ add_disulfide_interactions,
23
+ add_ionic_interactions,
24
+ add_delaunay_triangulation,
25
+ add_distance_threshold,
26
+ add_sequence_distance_edges,
27
+ add_k_nn_edges)
28
+ except ImportError:
29
+ raise Exception('You need to install graphein from source in addition to DSSP to use this model please refer to https://github.com/a-r-j/graphein and https://ssbio.readthedocs.io/en/latest/instructions/dssp.html')
30
+
31
+ from functools import partial
32
+ from .graphs import *
33
+ from .utils_dataset import *
34
+ import os
35
+ import sys
36
+ import subprocess
37
+ import wget
38
+
39
+
40
class PDB2Graph():
    """Convert PDB structure files into torch-geometric ``Data`` graphs.

    :param root: directory containing the raw ``.pdb`` files.
    :param output_folder: directory in which the ``.pt`` graphs are saved
        (created if missing).
    :param config: graph construction config passed to ``construct_graph``.
    :param n_processors: maximum number of worker processes that
        ``process()`` runs at the same time.
    """
    def __init__(self, root, output_folder, config, n_processors=int(multiprocessing.cpu_count())):
        self.root = root
        self.output_folder = output_folder
        # Secondary-structure letter -> integer code used as node feature.
        self.map_secondary_structure = {'-':0, 'H':1, 'B':2, 'E':3, 'G':4, 'I':5, 'T':6, 'S':7}
        self.init_ohe_edge_type()
        self.config = config
        self.features = ['phi', 'psi', 'rsa', 'asa', 'ss', 'expasy']
        self.n_processors = n_processors
        self.raw_dir = root
        self.processed_dir = self._processed_dir()
        self.raw_file_names = self._raw_file_names()
        self.processed_file_names = self._processed_file_names()

    def _processed_dir(self):
        """Create the output folder if needed and return its path."""
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.output_folder, exist_ok=True)
        return self.output_folder

    def _raw_file_names(self):
        """Return the names of all entries in the raw directory."""
        return os.listdir(self.raw_dir)

    def _processed_file_names(self):
        """Return the ``.pt`` file name expected for every raw ``.pdb`` file.

        Only ``.pdb`` files are considered, because ``pdb2pathstructure``
        always appends ``.pdb``. ``os.path.splitext`` strips just the final
        extension, so names containing extra dots are not truncated.
        """
        return [
            self.pdb2pathdata(os.path.splitext(raw_name)[0])
            for raw_name in self.raw_file_names
            if raw_name.endswith('.pdb')
        ]

    def create_nx_graph(self, path_to_structure):
        """Build the networkx protein graph for one structure file."""
        return construct_graph(self.config, pdb_path = path_to_structure)

    def create_pyg_graph(self, path_to_structure):
        """Build the torch-geometric graph for one structure file.

        The node feature matrix ``x`` is the amino-acid index followed by the
        features listed in ``self.features``; ``edge_type`` is the multi-hot
        encoding of each edge's kinds.
        """
        pyg_graph = convert_nx_to_pyg_data(self.create_nx_graph(path_to_structure))

        graph = Data(edge_index = pyg_graph.edge_index,
                     num_nodes = len(pyg_graph.node_id),
                     node_id = pyg_graph.node_id,
                     name = pyg_graph.name[0],
                     sequence = getattr(pyg_graph, f"sequence_{pyg_graph.chain_id[0]}"),
                     distance_matrix = pyg_graph.dist_mat,
                     distance = pyg_graph.distance,
                     coordinates = torch.FloatTensor(np.array(pyg_graph.coords[0])))
        # create the features: first column is the amino-acid class index
        x = np.array([np.argmax(pyg_graph.amino_acid_one_hot, axis=1)]).reshape(-1,1)
        for feat in self.features:
            if feat == "ss":
                # Unknown secondary-structure symbols fall back to code 0.
                feature = np.array([[self.map_secondary_structure.get(feat_node, 0)]
                                    for feat_node in pyg_graph[feat]])
            else:
                feature = np.array(pyg_graph[feat])
                if len(feature.shape) == 1:
                    feature = feature.reshape(-1,1)
            x = np.concatenate((x, feature), axis = 1)
        graph.edge_type = self.mlb.transform(pyg_graph.kind)
        graph.x = torch.FloatTensor(x)
        return graph

    def init_ohe_edge_type(self):
        """Create the binarizer that multi-hot encodes edge kinds."""
        edge_kinds = ['peptide_bond', 'sequence_distance_2', 'sequence_distance_3',
                      'distance_threshold', 'delaunay', 'hbond', 'k_nn']
        self.mlb = MultiLabelBinarizer(classes = edge_kinds)
        self.mlb.fit([edge_kinds])

    def process(self):
        """Convert the PDB files into torch geometric graphs"""
        to_be_processed = self.get_files_to_process()

        # Run the conversions in batches of at most n_processors processes.
        # Starting one process per structure all at once can exhaust memory
        # and process limits on large datasets.
        batch_size = max(1, int(self.n_processors))
        with tqdm(total=len(to_be_processed)) as progress:
            for start in range(0, len(to_be_processed), batch_size):
                workers = []
                for prot in to_be_processed[start:start + batch_size]:
                    p = multiprocessing.Process(target=self.graph_creation, args=(prot,))
                    p.start()
                    workers.append(p)
                for p in workers:
                    p.join()
                    progress.update(1)

    def graph_creation(self, pdb):
        """Create a graph from the PDB file"""

        # Define the path_to_structure from the pdb name file
        path_to_structure = self.pdb2pathstructure(pdb)

        # Convert the structure into a graph
        g = self.create_pyg_graph(path_to_structure)
        # Save the graph
        torch.save(g, os.path.join(self.output_folder, self.pdb2pathdata(pdb)))

        return None

    def pdb2pathdata(self, pdb):
        """Return the graph file name for a structure stem."""
        return pdb+'.pt'

    def pdb2pathstructure(self, pdb):
        """Return the path of the raw ``.pdb`` file for a structure stem."""
        return os.path.join(self.raw_dir, pdb+'.pdb')

    def get_files_to_process(self):
        """Return the stems of structures that have no saved graph yet.

        The result is sorted so the processing order is deterministic.
        """
        expected = set(self.processed_file_names)
        already_done = set(os.listdir(self.processed_dir))
        return sorted(os.path.splitext(name)[0] for name in expected - already_done)
157
+
158
def download_alphafold_structure(
    uniprot_id: str,
    out_dir: str,
    version: int = 4
    ):
    """Fetch the AlphaFold model of a UniProt entry as a PDB file.

    :param uniprot_id: UniProt accession (upper-cased before use).
    :param out_dir: directory in which the file is looked up and stored.
    :param version: AlphaFold DB model version used in the file name.
    :return: path of the structure file, or ``None`` when the download fails.
        If the file already exists in ``out_dir`` it is returned without
        downloading.
    """
    BASE_URL = "https://alphafold.ebi.ac.uk/files/"
    uniprot_id = uniprot_id.upper()

    model_name = f"AF-{uniprot_id}-F1-model_v{version}.pdb"
    query_url = f"{BASE_URL}{model_name}"
    structure_filename = os.path.join(out_dir, model_name)
    if os.path.exists(structure_filename):
        return structure_filename
    try:
        structure_filename = wget.download(query_url, out=out_dir)
    except Exception:
        # Best-effort download: report and signal failure with None, but do
        # not swallow KeyboardInterrupt/SystemExit like a bare `except:` does.
        print('Error.. could not download: ', model_name)
        return None
    return structure_filename
177
+
178
+
utils.py ADDED
@@ -0,0 +1,745 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2MLP
3
+ from typing import Optional, Tuple, Union, Any, Dict, List
4
+ from transformers import Seq2SeqTrainer, GPT2LMHeadModel
5
+ from torch.utils.data.distributed import DistributedSampler
6
+ import torch
7
+ from transformers.deepspeed import is_deepspeed_zero3_enabled
8
+ from transformers.generation.logits_process import LogitsProcessorList
9
+ from transformers.generation.stopping_criteria import StoppingCriteriaList
10
+ from transformers.generation.utils import GreedySearchOutput, GreedySearchEncoderDecoderOutput, BeamSearchOutput, BeamSearchEncoderDecoderOutput
11
+ from transformers.generation.beam_search import BeamScorer
12
+
13
+ try:
14
+ from torch_geometric.loader import DataLoader
15
+ from torch_geometric.data import Dataset
16
+ except ImportError:
17
+ raise Exception('You need to install torch geometric and its dependecies to use this model please refer to https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html')
18
+
19
+ class _GPT2LMHeadModel(GPT2LMHeadModel):
20
+ def _init_(self, config):
21
+ super(GPT2LMHeadModel, self).init_(config)
22
+ self.config = config
23
+
24
+
25
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, encoder_outputs=None, **kwargs):
26
+ '''
27
+ This function is an edited version of the prepare_inputs_for_generation function from HuggingFace's transformers
28
+ https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
29
+ '''
30
+ token_type_ids = kwargs.get("token_type_ids", None)
31
+ # only last token for inputs_ids if past is defined in kwargs
32
+ if past_key_values:
33
+ input_ids = input_ids[:, -1].unsqueeze(-1)
34
+ if token_type_ids is not None:
35
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
36
+
37
+ attention_mask = kwargs.get("attention_mask", None)
38
+ position_ids = kwargs.get("position_ids", None)
39
+ if self.config.prot2text_version=="1.1" or self.config.prot2text_version=="1.2":
40
+ encoder_attention_mask = kwargs.get("encoder_attention_mask", None)
41
+ elif self.config.prot2text_version=="1.0":
42
+ encoder_attention_mask = None
43
+
44
+ if attention_mask is not None and position_ids is None:
45
+ position_ids = attention_mask.long().cumsum(-1) - 1
46
+ position_ids.masked_fill_(attention_mask == 0, 1)
47
+ if past_key_values:
48
+ position_ids = position_ids[:, -1].unsqueeze(-1)
49
+ else:
50
+ position_ids = None
51
+
52
+ model_specific_kwargs = {
53
+ "encoder_hidden_states": encoder_outputs['hidden_states'],
54
+ }
55
+
56
+ return {
57
+ "input_ids": input_ids,
58
+ "past_key_values": past_key_values,
59
+ "use_cache": kwargs.get("use_cache"),
60
+ "position_ids": position_ids,
61
+ "attention_mask": attention_mask,
62
+ "token_type_ids": token_type_ids,
63
+ "encoder_attention_mask": encoder_attention_mask,
64
+ **model_specific_kwargs
65
+ }
66
+
67
+
68
def greedy_search(
    self,
    input_ids: torch.LongTensor,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: bool = False,
    streamer: Optional["BaseStreamer"] = None,
    **model_kwargs,
) -> Union[GreedySearchOutput, torch.LongTensor]:
    """
    Generate token ids by greedy decoding: at every step the arg-max token is appended.

    This function is an edited version of the greedy_search function from HuggingFace's transformers
    https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py

    Args:
        input_ids: Prompt token ids; one row per sequence. New tokens are concatenated on the last axis.
        logits_processor: Callable applied to `(input_ids, next_token_logits)` before the arg-max.
            Defaults to an empty `LogitsProcessorList`.
        stopping_criteria: Callable applied to `(input_ids, scores)` after every step. It may return
            either a plain bool or a boolean tensor (one flag per sequence); generation stops once
            every flag is true. Defaults to an empty `StoppingCriteriaList`.
        max_length: Deprecated; when given it is folded into `stopping_criteria` and a warning is issued.
        pad_token_id: Token written for sequences that already finished. Falls back to
            `self.generation_config.pad_token_id`.
        eos_token_id: One id or a list of ids that mark a sequence as finished. Falls back to
            `self.generation_config.eos_token_id`.
        output_attentions / output_hidden_states / output_scores / return_dict_in_generate:
            Each falls back to the attribute of the same name on `self.generation_config`.
        synced_gpus: When true, keep running forward passes until every distributed peer is done.
        streamer: Optional object whose `put()` receives each step's tokens and whose `end()` is
            called once generation is over.
        **model_kwargs: Forwarded to `self.prepare_inputs_for_generation`.

    Returns:
        The generated ids tensor, or a `GreedySearchEncoderDecoderOutput` /
        `GreedySearchDecoderOnlyOutput` when `return_dict_in_generate` is true.

    Raises:
        ValueError: If an `eos_token_id` is set but no `pad_token_id` is available.
    """
    # init values
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use"
            " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
    pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
    output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
    output_attentions = (
        output_attentions if output_attentions is not None else self.generation_config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate
        if return_dict_in_generate is not None
        else self.generation_config.return_dict_in_generate
    )

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # keep track of which sequences are already finished
    unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

    this_peer_finished = False  # used by synced_gpus only
    while True:
        if synced_gpus:
            # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
            # The following logic allows an early break if all peers finished generating their sequence
            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
            # send 0.0 if we finished, 1.0 otherwise
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            # did all peers finish? the reduced sum will be 0.0 then
            if this_peer_finished_flag.item() == 0.0:
                break

        # prepare model inputs
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # forward pass to get next token
        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if synced_gpus and this_peer_finished:
            continue  # don't waste resources running the code we don't need

        next_token_logits = outputs.logits[:, -1, :]

        # pre-process distribution
        next_tokens_scores = logits_processor(input_ids, next_token_logits)

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_tokens_scores,)
            if output_attentions:
                # NOTE(review): this condition is negated relative to the hidden-states branch
                # below (`if self.config.is_encoder_decoder`) -- confirm this is intentional.
                decoder_attentions += (
                    (outputs.decoder_attentions,) if not self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # argmax
        next_tokens = torch.argmax(next_tokens_scores, dim=-1)

        # finished sentences should have their next token be a padding token
        if eos_token_id is not None:
            if pad_token_id is None:
                raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )

        # if eos_token was found in one sentence, set sentence to finished
        if eos_token_id_tensor is not None:
            unfinished_sequences = unfinished_sequences.mul(
                next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
            )

            # stop when each sentence is finished
            if unfinished_sequences.max() == 0:
                this_peer_finished = True

        # stop if we exceed the maximum length
        # The criteria may answer with a bool or with one flag per sequence; evaluate them once
        # and reduce a tensor answer with `all()` instead of relying on a bare `except`.
        criteria_met = stopping_criteria(input_ids, scores)
        if isinstance(criteria_met, torch.Tensor):
            criteria_met = bool(criteria_met.all())
        if criteria_met:
            this_peer_finished = True

        if this_peer_finished and not synced_gpus:
            break

    if streamer is not None:
        streamer.end()

    if return_dict_in_generate:
        if self.config.is_encoder_decoder:
            return GreedySearchEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
            )
        else:
            return GreedySearchDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
            )
    else:
        return input_ids
243
+
244
+ def _greedy_search(
245
+ self,
246
+ input_ids: torch.LongTensor,
247
+ logits_processor: Optional[LogitsProcessorList] = None,
248
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
249
+ max_length: Optional[int] = None,
250
+ pad_token_id: Optional[int] = None,
251
+ eos_token_id: Optional[Union[int, List[int]]] = None,
252
+ output_attentions: Optional[bool] = None,
253
+ output_hidden_states: Optional[bool] = None,
254
+ output_scores: Optional[bool] = None,
255
+ return_dict_in_generate: Optional[bool] = None,
256
+ synced_gpus: bool = False,
257
+ streamer: Optional["BaseStreamer"] = None,
258
+ **model_kwargs,
259
+ ) -> Union[GreedySearchOutput, torch.LongTensor]:
260
+
261
+ return self.greedy_search(
262
+ input_ids,
263
+ logits_processor,
264
+ stopping_criteria,
265
+ max_length,
266
+ pad_token_id,
267
+ eos_token_id,
268
+ output_attentions,
269
+ output_hidden_states,
270
+ output_scores,
271
+ return_dict_in_generate,
272
+ synced_gpus,
273
+ streamer,
274
+ **model_kwargs,
275
+ )
276
+ def _beam_search(
277
+ self,
278
+ input_ids: torch.LongTensor,
279
+ beam_scorer: BeamScorer,
280
+ logits_processor: Optional[LogitsProcessorList] = None,
281
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
282
+ max_length: Optional[int] = None,
283
+ pad_token_id: Optional[int] = None,
284
+ eos_token_id: Optional[Union[int, List[int]]] = None,
285
+ output_attentions: Optional[bool] = None,
286
+ output_hidden_states: Optional[bool] = None,
287
+ output_scores: Optional[bool] = None,
288
+ return_dict_in_generate: Optional[bool] = None,
289
+ synced_gpus: bool = False,
290
+ **model_kwargs,
291
+ ) -> Union[BeamSearchOutput, torch.LongTensor]:
292
+
293
+ return self.beam_search(
294
+ input_ids,
295
+ beam_scorer,
296
+ logits_processor,
297
+ stopping_criteria,
298
+ max_length,
299
+ pad_token_id,
300
+ eos_token_id,
301
+ output_attentions,
302
+ output_hidden_states,
303
+ output_scores,
304
+ return_dict_in_generate,
305
+ synced_gpus,
306
+ **model_kwargs,
307
+ )
308
+
309
def beam_search(
    self,
    input_ids: torch.LongTensor,
    beam_scorer: BeamScorer,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: bool = False,
    **model_kwargs,
) -> Union[BeamSearchOutput, torch.LongTensor]:
    """
    Generate token ids with beam search driven by `beam_scorer`.

    This function is an edited version of the beam_search function from HuggingFace's transformers
    https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py

    Args:
        input_ids: Prompt token ids with `batch_size * num_beams` rows.
        beam_scorer: Object providing `num_beams`, `_beam_hyps`, `is_done`, `process()` and
            `finalize()`; it selects the surviving beams at every step.
        logits_processor: Callable applied to `(input_ids, log_probs)` before beam scores are added.
            Defaults to an empty `LogitsProcessorList`.
        stopping_criteria: Callable applied to `(input_ids, scores)` after every step. It may return
            either a plain bool or a boolean tensor; generation stops once every flag is true.
            Its `max_length` attribute is passed to `beam_scorer.finalize`.
        max_length: Deprecated; when given it is folded into `stopping_criteria` and a warning is issued.
        pad_token_id: Falls back to `self.generation_config.pad_token_id`.
        eos_token_id: One id or a list of ids; falls back to `self.generation_config.eos_token_id`.
        output_attentions / output_hidden_states / output_scores / return_dict_in_generate:
            Each falls back to the attribute of the same name on `self.generation_config`.
        synced_gpus: When true, keep running forward passes until every distributed peer is done.
        **model_kwargs: Forwarded to `self.prepare_inputs_for_generation`.

    Returns:
        The finalized sequences tensor, or a `BeamSearchEncoderDecoderOutput` /
        `BeamSearchDecoderOnlyOutput` when `return_dict_in_generate` is true.

    Raises:
        ValueError: If the number of rows in `input_ids` is not `num_beams * batch_size`.
    """
    # init values
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use"
            " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
    if len(stopping_criteria) == 0:
        warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
    pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
    output_attentions = (
        output_attentions if output_attentions is not None else self.generation_config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate
        if return_dict_in_generate is not None
        else self.generation_config.return_dict_in_generate
    )

    batch_size = len(beam_scorer._beam_hyps)
    num_beams = beam_scorer.num_beams

    batch_beam_size, cur_len = input_ids.shape

    if num_beams * batch_size != batch_beam_size:
        raise ValueError(
            f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
        )

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    beam_indices = (
        tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
    )
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
    # of the first beam are considered to avoid sampling the exact same tokens across all beams.
    beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
    beam_scores[:, 1:] = -1e9
    beam_scores = beam_scores.view((batch_size * num_beams,))

    this_peer_finished = False  # used by synced_gpus only
    while True:
        if synced_gpus:
            # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
            # The following logic allows an early break if all peers finished generating their sequence
            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
            # send 0.0 if we finished, 1.0 otherwise
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            # did all peers finish? the reduced sum will be 0.0 then
            if this_peer_finished_flag.item() == 0.0:
                break

        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if synced_gpus and this_peer_finished:
            cur_len = cur_len + 1
            continue  # don't waste resources running the code we don't need

        next_token_logits = outputs.logits[:, -1, :]
        next_token_scores = nn.functional.log_softmax(
            next_token_logits, dim=-1
        )  # (batch_size * num_beams, vocab_size)

        next_token_scores_processed = logits_processor(input_ids, next_token_scores)
        next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
            next_token_scores_processed
        )

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores_processed,)
            if output_attentions:
                # NOTE(review): this condition is negated relative to the hidden-states branch
                # below (`if self.config.is_encoder_decoder`) -- confirm this is intentional.
                decoder_attentions += (
                    (outputs.decoder_attentions,) if not self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # reshape for beam search
        vocab_size = next_token_scores.shape[-1]
        next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

        # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
        next_token_scores, next_tokens = torch.topk(
            next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
        )

        # split the flattened (beam, token) index back into its two components
        next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
        next_tokens = next_tokens % vocab_size

        # stateless
        beam_outputs = beam_scorer.process(
            input_ids,
            next_token_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            beam_indices=beam_indices,
        )

        beam_scores = beam_outputs["next_beam_scores"]
        beam_next_tokens = beam_outputs["next_beam_tokens"]
        beam_idx = beam_outputs["next_beam_indices"]

        input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        if model_kwargs["past_key_values"] is not None:
            model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)

        if return_dict_in_generate and output_scores:
            beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))

        # increase cur_len
        cur_len = cur_len + 1

        # The criteria may answer with a bool or with one flag per sequence; evaluate them once
        # (and only when the scorer is not already done) instead of relying on a bare `except`.
        should_stop = beam_scorer.is_done
        if not should_stop:
            criteria_met = stopping_criteria(input_ids, scores)
            if isinstance(criteria_met, torch.Tensor):
                criteria_met = bool(criteria_met.all())
            should_stop = criteria_met
        if should_stop:
            if not synced_gpus:
                break
            else:
                this_peer_finished = True

    sequence_outputs = beam_scorer.finalize(
        input_ids,
        beam_scores,
        next_tokens,
        next_indices,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        max_length=stopping_criteria.max_length,
        beam_indices=beam_indices,
    )

    if return_dict_in_generate:
        if not output_scores:
            sequence_outputs["sequence_scores"] = None

        if self.config.is_encoder_decoder:
            return BeamSearchEncoderDecoderOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                beam_indices=sequence_outputs["beam_indices"],
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
            )
        else:
            return BeamSearchDecoderOnlyOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                beam_indices=sequence_outputs["beam_indices"],
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
            )
    else:
        return sequence_outputs["sequences"]
542
+
543
+
544
class CABlock(nn.Module):
    """
    Pre-norm cross-attention block followed by a feed-forward layer, each wrapped in a residual
    connection.

    This class is an edited version of the gpt2 decoder block from HuggingFace's transformers
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
    """

    def __init__(self, config, layer_idx=None):
        super().__init__()
        width = config.hidden_size
        # fall back to four times the hidden size when the config leaves n_inner unset
        mlp_width = 4 * width if config.n_inner is None else config.n_inner
        norm_eps = config.layer_norm_epsilon

        # submodules are created in the same order as before so that parameter
        # initialisation consumes the random generator identically
        self.ln_2 = nn.LayerNorm(width, eps=norm_eps)

        self.crossattention = GPT2Attention(config, is_cross_attention=True, layer_idx=layer_idx)
        self.ln_cross_attn = nn.LayerNorm(width, eps=norm_eps)

        self.mlp = GPT2MLP(mlp_width, config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
        """
        Attend from `hidden_states` to `encoder_hidden_states`, then apply the MLP.

        `layer_past` and `use_cache` are accepted but not used by this block. Only the first
        element of the attention module's output is consumed.

        Returns:
            A one-element tuple holding the updated hidden states.
        """
        # cross-attention sub-layer (normalise first, add the input back afterwards)
        attended = self.crossattention(
            self.ln_cross_attn(hidden_states),
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
        )[0]
        after_attention = hidden_states + attended

        # feed-forward sub-layer with its own residual connection
        projected = self.mlp(self.ln_2(after_attention))
        return (after_attention + projected,)
595
+
596
class Prot2TextTrainer(Seq2SeqTrainer):
    """
    Seq2SeqTrainer variant that builds its own data loaders and forwards graph inputs to `generate`.

    This class is an edited version of the Seq2SeqTrainer from HuggingFace's transformers
    """

    def _prot2text_dataloader(self, dataset, batch_size: int) -> DataLoader:
        """Build a `DataLoader` over `dataset`, sharded with a `DistributedSampler` when world_size > 1."""
        if self.args.world_size > 1:
            sampler = DistributedSampler(
                dataset, num_replicas=self.args.world_size, rank=self.args.process_index
            )
        else:
            sampler = None
        return DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=None,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            sampler=sampler,
        )

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Return the evaluation data loader.

        Args:
            eval_dataset: Dataset to evaluate on. When omitted, `self.eval_dataset` is used.
        """
        # honour an explicitly supplied dataset instead of always using self.eval_dataset
        dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        return self._prot2text_dataloader(dataset, self.args.eval_batch_size)

    def get_train_dataloader(self) -> DataLoader:
        """Return the training data loader over `self.train_dataset`."""
        return self._prot2text_dataloader(self.train_dataset, self.args.per_device_train_batch_size)

    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
        handling potential state.

        The batch is turned into a plain dict, its `edge_type` entries are concatenated and reduced
        with `argmax` over dim 1, and every value exposing `.to` is moved to `self.args.device`.

        Raises:
            ValueError: If the prepared batch is empty.
        """
        inputs = self._prepare_input(inputs)
        if len(inputs) == 0:
            raise ValueError(
                "The batch received was empty, your model won't be able to train on it. Double-check that your "
                f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}."
            )
        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past

        # assumes the batch object offers `to_dict()` (e.g. a graph batch) -- TODO confirm against the dataset
        inputs = inputs.to_dict()
        inputs['edge_type'] = torch.cat([torch.tensor(edges) for edges in inputs['edge_type']], dim=0)
        # presumably edge types arrive one-hot encoded; argmax turns each row into a class index
        inputs['edge_type'] = torch.argmax(inputs['edge_type'], dim=1)
        inputs = {k: v.to(device=self.args.device, non_blocking=True) if hasattr(v, 'to') else v for k, v in inputs.items()}
        return inputs

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on `model` using `inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
            ignore_keys (`List[str]`, *optional*):
                Only forwarded to the parent implementation when generation is not used.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
            labels (each being optional).
        """

        if not self.args.predict_with_generate or prediction_loss_only:
            return super().prediction_step(
                model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
            )

        has_labels = "labels" in inputs
        inputs = self._prepare_inputs(inputs)

        # XXX: adapt synced_gpus for fairscale as well
        gen_kwargs = self._gen_kwargs.copy()
        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
            gen_kwargs["max_length"] = self.model.config.max_length
        gen_kwargs["num_beams"] = (
            gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
        )
        default_synced_gpus = bool(is_deepspeed_zero3_enabled())
        gen_kwargs["synced_gpus"] = (
            gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
        )

        if "attention_mask" in inputs:
            gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
        if "global_attention_mask" in inputs:
            gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)

        # the graph / sequence encoder inputs are passed to `generate` as keyword arguments
        generation_inputs = None
        for key in ('x', 'edge_index', 'edge_type', 'batch', 'encoder_input_ids'):
            gen_kwargs[key] = inputs.get(key, None)
        # generation starts from the first decoder token only
        gen_kwargs['decoder_input_ids'] = inputs.get('decoder_input_ids', None)[:,0:1]
        gen_kwargs["decoder_attention_mask"] = torch.ones(gen_kwargs['decoder_input_ids'].shape[0], 1).to(self.args.device)

        generated_tokens = self.model.generate(
            generation_inputs,
            **gen_kwargs,
        )
        # in case the batch is shorter than max length, the output should be padded
        if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]:
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
        elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < (
            gen_kwargs["max_new_tokens"] + 1
        ):
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1)

        with torch.no_grad():
            if has_labels:
                with self.compute_loss_context_manager():
                    outputs = model(**inputs)
                if self.label_smoother is not None:
                    loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
                else:
                    loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
            else:
                loss = None

        if self.args.prediction_loss_only:
            return (loss, None, None)

        if has_labels:
            labels = inputs["labels"]
            if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]:
                labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
            elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < (
                gen_kwargs["max_new_tokens"] + 1
            ):
                labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1))
        else:
            labels = None

        return (loss, generated_tokens, labels)
utils_convert.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from biopandas.pdb import PandasPdb
3
+
4
# Column order of a PandasPdb ATOM/HETATM dataframe.
pdb_order = [
    "record_name", "atom_number", "blank_1", "atom_name", "alt_loc",
    "residue_name", "blank_2", "chain_id", "residue_number", "insertion",
    "blank_3", "x_coord", "y_coord", "z_coord", "occupancy",
    "b_factor", "blank_4", "segment_id", "element_symbol", "charge",
    "line_idx",
]

# mmCIF column name -> PandasPdb column name.
mmcif_read = dict(
    group_PDB="record_name",
    id="atom_number",
    auth_atom_id="atom_name",
    auth_comp_id="residue_name",
    auth_asym_id="chain_id",
    auth_seq_id="residue_number",
    Cartn_x="x_coord",
    Cartn_y="y_coord",
    Cartn_z="z_coord",
    occupancy="occupancy",
    B_iso_or_equiv="b_factor",
    type_symbol="element_symbol",
)

# PandasPdb columns that are not mapped from mmCIF and are filled with placeholders.
nonefields = [
    "blank_1", "alt_loc", "blank_2", "insertion", "blank_3",
    "blank_4", "segment_id", "charge", "line_idx",
]
53
+
54
+
55
def biopandas_mmcif2pdb(pandasmmcif, model_index = 1):
    """
    Build a PandasPdb() whose ATOM and HETATM dataframes are derived from a PandasMmcif() object.

    Rows are restricted to `pdbx_PDB_model_num == model_index`, the mmCIF columns listed in
    `mmcif_read` are renamed to their PDB names, the columns in `nonefields` are added as empty
    strings (`charge` as NaN), and the result is ordered as `pdb_order`.

    Raises:
        ValueError: If no ATOM rows belong to `model_index`.
    """
    converted = PandasPdb()
    for record in ("ATOM", "HETATM"):
        frame = pandasmmcif.df[record]
        frame = frame.loc[frame.pdbx_PDB_model_num == model_index]
        # only the ATOM table is required to be non-empty
        if record == "ATOM" and len(frame) == 0:
            raise ValueError(f"No model found for index: {model_index}")
        # keep the mmCIF fields that exist in PDB and give them their PDB names
        frame = frame[list(mmcif_read)].rename(columns=mmcif_read)
        # placeholder columns that have no mmCIF counterpart
        for placeholder in nonefields:
            frame[placeholder] = ""
        frame["charge"] = np.nan
        # PandasPdb column order
        converted.df[record] = frame[pdb_order]

    # line_idx mirrors each dataframe's index
    atoms = converted.df["ATOM"]
    atoms["line_idx"] = atoms.index.values
    hetatms = converted.df["HETATM"]
    hetatms["line_idx"] = hetatms.index

    return converted
utils_dataset.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import csv
3
+
4
def load_GO_annot(filename):
    """
    Load GO annotations from a tab-separated file.

    The file is read as: for each ontology in ('mf', 'bp', 'cc') a header line, a line of GO terms,
    a header line and a line of GO names; then one more header line; then one row per protein with
    the protein id followed by three comma-separated GO-term fields (mf, bp, cc).

    Returns:
        (prot2annot, goterms, gonames, counts) where `prot2annot[prot][ont]` is a 0/1 vector over
        `goterms[ont]` and `counts[ont]` holds, per term, the number of proteins annotated with it.

    Raises:
        ValueError: If a protein row mentions a GO term that is not listed for that ontology.
    """
    onts = ['mf', 'bp', 'cc']
    prot2annot = {}
    goterms = {ont: [] for ont in onts}
    gonames = {ont: [] for ont in onts}
    # newline='' is what the csv module expects from the file object
    with open(filename, mode='r', newline='') as tsvfile:
        reader = csv.reader(tsvfile, delimiter='\t')

        # molecular function, biological process, cellular component (in that order)
        for ont in onts:
            next(reader, None)  # skip the headers
            goterms[ont] = next(reader)
            next(reader, None)  # skip the headers
            gonames[ont] = next(reader)

        next(reader, None)  # skip the headers

        # term -> position, built once so every lookup is O(1); setdefault keeps the first
        # occurrence, which is what list.index would have returned
        term_position = {}
        for ont in onts:
            positions = {}
            for pos, term in enumerate(goterms[ont]):
                positions.setdefault(term, pos)
            term_position[ont] = positions

        counts = {ont: np.zeros(len(goterms[ont]), dtype=float) for ont in onts}
        for row in reader:
            if not row:
                continue  # blank line (e.g. at the end of the file)
            prot, prot_goterms = row[0], row[1:]
            prot2annot[prot] = {ont: [] for ont in onts}
            for i, ont in enumerate(onts):
                goterm_indices = []
                for goterm in prot_goterms[i].split(','):
                    if goterm == '':
                        continue
                    if goterm not in term_position[ont]:
                        raise ValueError(f"{goterm!r} is not in list")
                    goterm_indices.append(term_position[ont][goterm])
                prot2annot[prot][ont] = np.zeros(len(goterms[ont]))
                prot2annot[prot][ont][goterm_indices] = 1.0
                counts[ont][goterm_indices] += 1.0
    return prot2annot, goterms, gonames, counts
42
+
43
+
44
def load_EC_annot(filename):
    """
    Load EC annotations from a tab-separated file.

    The file is read as: a header line, a line listing the EC numbers, another header line, then
    one row per protein with the protein id followed by a comma-separated field of EC numbers.

    Returns:
        (prot2annot, ec_numbers, ec_numbers, counts), the same 4-tuple shape that `load_GO_annot`
        returns; the EC-number dict fills both the "terms" and the "names" slot.
        `prot2annot[prot]['ec']` is a 0/1 int64 vector over `ec_numbers['ec']` and `counts['ec']`
        holds, per EC number, how many proteins carry it.

    Raises:
        ValueError: If a protein row mentions an EC number that is not in the listed EC numbers.
    """
    prot2annot = {}
    # newline='' is what the csv module expects from the file object
    with open(filename, mode='r', newline='') as tsvfile:
        reader = csv.reader(tsvfile, delimiter='\t')

        next(reader, None)  # skip the headers
        ec_numbers = {'ec': next(reader)}
        next(reader, None)  # skip the headers

        # EC number -> position, built once so every lookup is O(1); setdefault keeps the first
        # occurrence, which is what list.index would have returned
        ec_position = {}
        for pos, ec_num in enumerate(ec_numbers['ec']):
            ec_position.setdefault(ec_num, pos)

        counts = {'ec': np.zeros(len(ec_numbers['ec']), dtype=float)}
        for row in reader:
            if not row:
                continue  # blank line (e.g. at the end of the file)
            prot, prot_ec_numbers = row[0], row[1]
            ec_indices = []
            for ec_num in prot_ec_numbers.split(','):
                if ec_num == '':
                    continue  # same treatment of empty entries as load_GO_annot
                if ec_num not in ec_position:
                    raise ValueError(f"{ec_num!r} is not in list")
                ec_indices.append(ec_position[ec_num])
            prot2annot[prot] = {'ec': np.zeros(len(ec_numbers['ec']), dtype=np.int64)}
            prot2annot[prot]['ec'][ec_indices] = 1.0
            counts['ec'][ec_indices] += 1
    # previously the function fell off the end and returned None, discarding everything it parsed
    return prot2annot, ec_numbers, ec_numbers, counts