Huhujingjing commited on
Commit
3717306
·
1 Parent(s): 9462766

Upload 2 files

Browse files
Files changed (2) hide show
  1. configuration_transmxm.py +48 -0
  2. modeling_transmxm.py +1291 -0
configuration_transmxm.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from typing import List
3
+
4
+
5
+ class TransmxmConfig(PretrainedConfig):
6
+ model_type = "transmxm"
7
+
8
+ def __init__(
9
+ self,
10
+ dim: int = 128,
11
+ n_layer: int = 6,
12
+ cutoff: float = 5.0,
13
+ num_spherical: int = 7,
14
+ num_radial: int = 6,
15
+ envelope_exponent: int = 5,
16
+
17
+ smiles: List[str] = None,
18
+ processor_class: str = "SmilesProcessor",
19
+ **kwargs,
20
+ ):
21
+ self.dim = dim # the dimension of input feature
22
+ self.n_layer = n_layer # the number of GCN layers
23
+ self.cutoff = cutoff # the cutoff distance for neighbor searching
24
+ self.num_spherical = num_spherical # the number of spherical harmonics
25
+ self.num_radial = num_radial # the number of radial basis
26
+ self.envelope_exponent = envelope_exponent # the envelope exponent
27
+
28
+ self.smiles = smiles # process smiles
29
+ self.processor_class = processor_class
30
+
31
+
32
+ super().__init__(**kwargs)
33
+
34
+
35
+ if __name__ == "__main__":
36
+ transmxm_config = TransmxmConfig(
37
+ dim=128,
38
+ n_layer=6,
39
+ cutoff=5.0,
40
+ num_spherical=7,
41
+ num_radial=6,
42
+ envelope_exponent=5,
43
+ smiles=["C", "CC", "CCC"],
44
+ processor_class="SmilesProcessor"
45
+ )
46
+ transmxm_config.save_pretrained("custom-transmxm")
47
+
48
+
modeling_transmxm.py ADDED
@@ -0,0 +1,1291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel
2
+ from timm.models.resnet import BasicBlock, Bottleneck, ResNet
3
+ from transmxm_model.configuration_transmxm import TransmxmConfig
4
+ import torch
5
+
6
+ import os
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.nn import Parameter, Sequential, ModuleList, Linear
11
+
12
+ from rdkit import Chem
13
+ from rdkit.Chem import AllChem
14
+
15
+ from transformers import PretrainedConfig
16
+ from transformers import PreTrainedModel
17
+ from transformers import AutoModel
18
+
19
+ from torch_geometric.data import Data
20
+ from torch_geometric.loader import DataLoader
21
+ from torch_geometric.utils import remove_self_loops, add_self_loops, sort_edge_index
22
+ from torch_scatter import scatter
23
+ from torch_geometric.nn import global_add_pool, radius
24
+ from torch_sparse import SparseTensor
25
+
26
+ from transmxm_model.configuration_transmxm import TransmxmConfig
27
+
28
+ from tqdm import tqdm
29
+ import numpy as np
30
+ import pandas as pd
31
+ from typing import List
32
+ import math
33
+ import inspect
34
+ from operator import itemgetter
35
+ from collections import OrderedDict
36
+ from math import sqrt, pi as PI
37
+ from scipy.optimize import brentq
38
+ from scipy import special as sp
39
+
40
+ try:
41
+ import sympy as sym
42
+ except ImportError:
43
+ sym = None
44
+
45
+
46
+
47
+ class SmilesDataset(torch.utils.data.Dataset):
48
+ def __init__(self, smiles):
49
+ self.smiles_list = smiles
50
+ self.data_list = []
51
+
52
+
53
+ def __len__(self):
54
+ return len(self.data_list)
55
+
56
+ def __getitem__(self, idx):
57
+ return self.data_list[idx]
58
+
59
+ def get_data(self, smiles):
60
+ self.smiles_list = smiles
61
+ # self.data_list = []
62
+ # bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}
63
+ types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'S': 4}
64
+
65
+ for i in range(len(self.smiles_list)):
66
+ # 将 SMILES 表示转换为 RDKit 的分子对象
67
+ # print(self.smiles_list[i])
68
+ mol = Chem.MolFromSmiles(self.smiles_list[i]) # 从smiles编码中获取结构信息
69
+ if mol is None:
70
+ print("无法创建Mol对象", self.smiles_list[i])
71
+ else:
72
+
73
+ mol3d = Chem.AddHs(
74
+ mol) # 在rdkit中,分子在默认情况下是不显示氢的,但氢原子对于真实的几何构象计算有很大的影响,所以在计算3D构象前,需要使用Chem.AddHs()方法加上氢原子
75
+ if mol3d is None:
76
+ print("无法创建mol3d对象", self.smiles_list[i])
77
+ else:
78
+ AllChem.EmbedMolecule(mol3d, randomSeed=1) # 生成3D构象
79
+
80
+ N = mol3d.GetNumAtoms()
81
+ # 获取原子坐标信息
82
+ if mol3d.GetNumConformers() > 0:
83
+ conformer = mol3d.GetConformer()
84
+ pos = conformer.GetPositions()
85
+ pos = torch.tensor(pos, dtype=torch.float)
86
+
87
+ type_idx = []
88
+ # atomic_number = []
89
+ # aromatic = []
90
+ # sp = []
91
+ # sp2 = []
92
+ # sp3 = []
93
+ for atom in mol3d.GetAtoms():
94
+ type_idx.append(types[atom.GetSymbol()])
95
+ # atomic_number.append(atom.GetAtomicNum())
96
+ # aromatic.append(1 if atom.GetIsAromatic() else 0)
97
+ # hybridization = atom.GetHybridization()
98
+ # sp.append(1 if hybridization == HybridizationType.SP else 0)
99
+ # sp2.append(1 if hybridization == HybridizationType.SP2 else 0)
100
+ # sp3.append(1 if hybridization == HybridizationType.SP3 else 0)
101
+
102
+ # z = torch.tensor(atomic_number, dtype=torch.long)
103
+
104
+ row, col, edge_type = [], [], []
105
+ for bond in mol3d.GetBonds():
106
+ start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
107
+ row += [start, end]
108
+ col += [end, start]
109
+ # edge_type += 2 * [bonds[bond.GetBondType()]]
110
+
111
+ edge_index = torch.tensor([row, col], dtype=torch.long)
112
+ # edge_type = torch.tensor(edge_type, dtype=torch.long)
113
+ # edge_attr = F.one_hot(edge_type, num_classes=len(bonds)).to(torch.float)
114
+
115
+ perm = (edge_index[0] * N + edge_index[1]).argsort()
116
+ edge_index = edge_index[:, perm]
117
+ # edge_type = edge_type[perm]
118
+ # edge_attr = edge_attr[perm]
119
+ #
120
+ # row, col = edge_index
121
+ # hs = (z == 1).to(torch.float)
122
+
123
+ x = torch.tensor(type_idx).to(torch.float)
124
+
125
+ # y = self.y_list[i]
126
+
127
+ data = Data(x=x, pos=pos, edge_index=edge_index, smiles=self.smiles_list[i])
128
+
129
+ self.data_list.append(data)
130
+ else:
131
+ print("无法创建comfor", self.smiles_list[i])
132
+ return self.data_list
133
+
134
+
135
+ # --------------------------------------------------------
136
+ # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
137
+ # Github source: https://github.com/microsoft/unilm/tree/master/wavlm
138
+ # Copyright (c) 2021 Microsoft
139
+ # Licensed under The MIT License [see LICENSE for details]
140
+ # Based on fairseq code bases
141
+ # https://github.com/pytorch/fairseq
142
+ # --------------------------------------------------------
143
+ import math
144
+ import logging
145
+ from typing import List, Optional, Tuple
146
+
147
+ import numpy as np
148
+ from torch.nn import LayerNorm
149
+ import copy
150
+ from typing import Optional
151
+
152
+ import torch
153
+ import torch.nn.functional as F
154
+ from torch import nn, Tensor
155
+
156
+
157
+ class PositionEmbeddingSine(nn.Module):
158
+ """
159
+ This is a more standard version of the position embedding, very similar to the one
160
+ used by the Attention is all you need paper, generalized to work on images. (To 1D sequences)
161
+ """
162
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
163
+ super().__init__()
164
+ self.num_pos_feats = num_pos_feats
165
+ self.temperature = temperature
166
+ self.normalize = normalize
167
+ if scale is not None and normalize is False:
168
+ raise ValueError("normalize should be True if scale is passed")
169
+ if scale is None:
170
+ scale = 2 * math.pi
171
+ self.scale = scale
172
+
173
+ def forward(self, x, mask):
174
+ """
175
+ Args:
176
+ x: torch.tensor, (batch_size, L, d)
177
+ mask: torch.tensor, (batch_size, L), with 1 as valid
178
+
179
+ Returns:
180
+
181
+ """
182
+ assert mask is not None
183
+ x_embed = mask.cumsum(1, dtype=torch.float32) # (bsz, L)
184
+ if self.normalize:
185
+ eps = 1e-6
186
+ x_embed = x_embed / (x_embed[:, -1:] + eps) * self.scale
187
+
188
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
189
+ # dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
190
+ dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='trunc') / self.num_pos_feats)
191
+ pos_x = x_embed[:, :, None] / dim_t # (bsz, L, num_pos_feats)
192
+ pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) # (bsz, L, num_pos_feats*2)
193
+ # import ipdb; ipdb.set_trace()
194
+ return pos_x # .permute(0, 2, 1) # (bsz, num_pos_feats*2, L)
195
+
196
+ def build_position_encoding(x):
197
+ N_steps = x
198
+ pos_embed = PositionEmbeddingSine(N_steps, normalize=True)
199
+
200
+ return pos_embed
201
+
202
+
203
+ class Transformer(nn.Module):
204
+
205
+ def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
206
+ num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
207
+ activation="relu", normalize_before=False):
208
+ super().__init__()
209
+
210
+ # TransformerEncoderLayer
211
+ encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
212
+ dropout, activation, normalize_before)
213
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
214
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
215
+
216
+ self._reset_parameters()
217
+
218
+ self.d_model = d_model
219
+ self.nhead = nhead
220
+
221
+ def _reset_parameters(self):
222
+ for p in self.parameters():
223
+ if p.dim() > 1:
224
+ nn.init.xavier_uniform_(p)
225
+
226
+ def forward(self, src, mask, att_mask, pos_embed):
227
+ """
228
+ Args:
229
+ src: (batch_size, L, d)
230
+ mask: (batch_size, L)
231
+ query_embed: (#queries, d)
232
+ pos_embed: (batch_size, L, d) the same as src
233
+
234
+ Returns:
235
+
236
+ """
237
+ src = src.permute(1, 0, 2) # (L, batch_size, d)
238
+ pos_embed = pos_embed.permute(1, 0, 2) # (L, batch_size, d)
239
+
240
+ memory = self.encoder(
241
+ src,
242
+ mask=att_mask,
243
+ src_key_padding_mask=mask,
244
+ pos=pos_embed
245
+ )
246
+
247
+ memory = memory.transpose(0, 1)
248
+ return memory
249
+
250
+
251
+ class TransformerEncoder(nn.Module):
252
+
253
+ def __init__(self, encoder_layer, num_layers, norm=None, return_intermediate=False):
254
+ super().__init__()
255
+ self.layers = _get_clones(encoder_layer, num_layers)
256
+ self.num_layers = num_layers
257
+ self.norm = norm
258
+ self.return_intermediate = return_intermediate
259
+
260
+ def forward(self, src,
261
+ mask: Optional[Tensor] = None,
262
+ src_key_padding_mask: Optional[Tensor] = None,
263
+ pos: Optional[Tensor] = None):
264
+ output = src
265
+
266
+ intermediate = []
267
+
268
+ for layer in self.layers:
269
+ output = layer(output, src_mask=mask,
270
+ src_key_padding_mask=src_key_padding_mask, pos=pos)
271
+ if self.return_intermediate:
272
+ intermediate.append(output)
273
+
274
+ if self.norm is not None:
275
+ output = self.norm(output)
276
+
277
+ if self.return_intermediate:
278
+ return torch.stack(intermediate)
279
+
280
+ return output
281
+
282
+
283
+ class TransformerEncoderLayer(nn.Module):
284
+
285
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
286
+ activation="relu", normalize_before=False):
287
+ super().__init__()
288
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
289
+ # Implementation of Feedforward model
290
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
291
+ self.dropout = nn.Dropout(dropout)
292
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
293
+
294
+ self.norm1 = nn.LayerNorm(d_model)
295
+ self.norm2 = nn.LayerNorm(d_model)
296
+ self.dropout1 = nn.Dropout(dropout)
297
+ self.dropout2 = nn.Dropout(dropout)
298
+
299
+ self.activation = _get_activation_fn(activation)
300
+ self.normalize_before = normalize_before
301
+
302
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
303
+ return tensor if pos is None else tensor + pos
304
+
305
+ def forward_post(self,
306
+ src,
307
+ src_mask: Optional[Tensor] = None,
308
+ src_key_padding_mask: Optional[Tensor] = None,
309
+ pos: Optional[Tensor] = None):
310
+ q = k = self.with_pos_embed(src, pos)
311
+ src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
312
+ key_padding_mask=src_key_padding_mask)[0]
313
+ src = src + self.dropout1(src2)
314
+ src = self.norm1(src)
315
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
316
+ src = src + self.dropout2(src2)
317
+ src = self.norm2(src)
318
+ return src
319
+
320
+ def forward_pre(self, src,
321
+ src_mask: Optional[Tensor] = None,
322
+ src_key_padding_mask: Optional[Tensor] = None,
323
+ pos: Optional[Tensor] = None):
324
+ src2 = self.norm1(src)
325
+ q = k = self.with_pos_embed(src2, pos)
326
+ src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
327
+ key_padding_mask=src_key_padding_mask)[0]
328
+ src = src + self.dropout1(src2)
329
+ src2 = self.norm2(src)
330
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
331
+ src = src + self.dropout2(src2)
332
+ return src
333
+
334
+ def forward(self, src,
335
+ src_mask: Optional[Tensor] = None,
336
+ src_key_padding_mask: Optional[Tensor] = None,
337
+ pos: Optional[Tensor] = None):
338
+ if self.normalize_before:
339
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
340
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
341
+
342
+
343
+ def _get_clones(module, N):
344
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
345
+
346
+
347
+ def build_transformer(x):
348
+ return Transformer(
349
+ d_model=x,
350
+ dropout=0.5,
351
+ nhead=8,
352
+ dim_feedforward=1024,
353
+ num_encoder_layers=2,
354
+ normalize_before=True,
355
+ )
356
+
357
+
358
+ def _get_activation_fn(activation):
359
+ """Return an activation function given a string"""
360
+ if activation == "relu":
361
+ return F.relu
362
+ if activation == "gelu":
363
+ return F.gelu
364
+ if activation == "glu":
365
+ return F.glu
366
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
367
+
368
+
369
+
370
+ class EMA:
371
+ def __init__(self, model, decay):
372
+ self.decay = decay
373
+ self.shadow = {}
374
+ self.original = {}
375
+
376
+ # Register model parameters
377
+ for name, param in model.named_parameters():
378
+ if param.requires_grad:
379
+ self.shadow[name] = param.data.clone()
380
+
381
+ def __call__(self, model, num_updates=99999):
382
+ decay = min(self.decay, (1.0 + num_updates) / (10.0 + num_updates))
383
+ for name, param in model.named_parameters():
384
+ if param.requires_grad:
385
+ assert name in self.shadow
386
+ new_average = \
387
+ (1.0 - decay) * param.data + decay * self.shadow[name]
388
+ self.shadow[name] = new_average.clone()
389
+
390
+ def assign(self, model):
391
+ for name, param in model.named_parameters():
392
+ if param.requires_grad:
393
+ assert name in self.shadow
394
+ self.original[name] = param.data.clone()
395
+ param.data = self.shadow[name]
396
+
397
+ def resume(self, model):
398
+ for name, param in model.named_parameters():
399
+ if param.requires_grad:
400
+ assert name in self.shadow
401
+ param.data = self.original[name]
402
+
403
+
404
+ def MLP(channels):
405
+ return Sequential(*[
406
+ Sequential(Linear(channels[i - 1], channels[i]), SiLU())
407
+ for i in range(1, len(channels))])
408
+
409
+
410
+ class Res(nn.Module):
411
+ def __init__(self, dim):
412
+ super(Res, self).__init__()
413
+
414
+ self.mlp = MLP([dim, dim, dim])
415
+
416
+ def forward(self, m):
417
+ m1 = self.mlp(m)
418
+ m_out = m1 + m
419
+ return m_out
420
+
421
+
422
+ def compute_idx(pos, edge_index):
423
+
424
+ pos_i = pos[edge_index[0]]
425
+ pos_j = pos[edge_index[1]]
426
+
427
+ d_ij = torch.norm(abs(pos_j - pos_i), dim=-1, keepdim=False).unsqueeze(-1) + 1e-5
428
+ v_ji = (pos_i - pos_j) / d_ij
429
+
430
+ unique, counts = torch.unique(edge_index[0], sorted=True, return_counts=True) #Get central values
431
+ full_index = torch.arange(0, edge_index[0].size()[0]).cuda().int() #init full index
432
+ #print('full_index', full_index)
433
+
434
+ #Compute 1
435
+ repeat = torch.repeat_interleave(counts, counts)
436
+ counts_repeat1 = torch.repeat_interleave(full_index, repeat) #0,...,0,1,...,1,...
437
+
438
+ #Compute 2
439
+ split = torch.split(full_index, counts.tolist()) #split full index
440
+ index2 = list(edge_index[0].data.cpu().numpy()) #get repeat index
441
+ counts_repeat2 = torch.cat(itemgetter(*index2)(split), dim=0) #0,1,2,...,0,1,2,..
442
+
443
+ #Compute angle embeddings
444
+ v1 = v_ji[counts_repeat1.long()]
445
+ v2 = v_ji[counts_repeat2.long()]
446
+
447
+ angle = (v1*v2).sum(-1).unsqueeze(-1)
448
+ angle = torch.clamp(angle, min=-1.0, max=1.0) + 1e-6 + 1.0
449
+
450
+ return counts_repeat1.long(), counts_repeat2.long(), angle
451
+
452
+
453
+ def Jn(r, n):
454
+ return np.sqrt(np.pi / (2 * r)) * sp.jv(n + 0.5, r)
455
+
456
+
457
+ def Jn_zeros(n, k):
458
+ zerosj = np.zeros((n, k), dtype='float32')
459
+ zerosj[0] = np.arange(1, k + 1) * np.pi
460
+ points = np.arange(1, k + n) * np.pi
461
+ racines = np.zeros(k + n - 1, dtype='float32')
462
+ for i in range(1, n):
463
+ for j in range(k + n - 1 - i):
464
+ foo = brentq(Jn, points[j], points[j + 1], (i, ))
465
+ racines[j] = foo
466
+ points = racines
467
+ zerosj[i][:k] = racines[:k]
468
+
469
+ return zerosj
470
+
471
+
472
+ def spherical_bessel_formulas(n):
473
+ x = sym.symbols('x')
474
+
475
+ f = [sym.sin(x) / x]
476
+ a = sym.sin(x) / x
477
+ for i in range(1, n):
478
+ b = sym.diff(a, x) / x
479
+ f += [sym.simplify(b * (-x)**i)]
480
+ a = sym.simplify(b)
481
+ return f
482
+
483
+
484
+ def bessel_basis(n, k):
485
+ zeros = Jn_zeros(n, k)
486
+ normalizer = []
487
+ for order in range(n):
488
+ normalizer_tmp = []
489
+ for i in range(k):
490
+ normalizer_tmp += [0.5 * Jn(zeros[order, i], order + 1)**2]
491
+ normalizer_tmp = 1 / np.array(normalizer_tmp)**0.5
492
+ normalizer += [normalizer_tmp]
493
+
494
+ f = spherical_bessel_formulas(n)
495
+ x = sym.symbols('x')
496
+ bess_basis = []
497
+ for order in range(n):
498
+ bess_basis_tmp = []
499
+ for i in range(k):
500
+ bess_basis_tmp += [
501
+ sym.simplify(normalizer[order][i] *
502
+ f[order].subs(x, zeros[order, i] * x))
503
+ ]
504
+ bess_basis += [bess_basis_tmp]
505
+ return bess_basis
506
+
507
+
508
+ def sph_harm_prefactor(k, m):
509
+ return ((2 * k + 1) * np.math.factorial(k - abs(m)) /
510
+ (4 * np.pi * np.math.factorial(k + abs(m))))**0.5
511
+
512
+
513
+ def associated_legendre_polynomials(k, zero_m_only=True):
514
+ z = sym.symbols('z')
515
+ P_l_m = [[0] * (j + 1) for j in range(k)]
516
+
517
+ P_l_m[0][0] = 1
518
+ if k > 0:
519
+ P_l_m[1][0] = z
520
+
521
+ for j in range(2, k):
522
+ P_l_m[j][0] = sym.simplify(((2 * j - 1) * z * P_l_m[j - 1][0] -
523
+ (j - 1) * P_l_m[j - 2][0]) / j)
524
+ if not zero_m_only:
525
+ for i in range(1, k):
526
+ P_l_m[i][i] = sym.simplify((1 - 2 * i) * P_l_m[i - 1][i - 1])
527
+ if i + 1 < k:
528
+ P_l_m[i + 1][i] = sym.simplify(
529
+ (2 * i + 1) * z * P_l_m[i][i])
530
+ for j in range(i + 2, k):
531
+ P_l_m[j][i] = sym.simplify(
532
+ ((2 * j - 1) * z * P_l_m[j - 1][i] -
533
+ (i + j - 1) * P_l_m[j - 2][i]) / (j - i))
534
+
535
+ return P_l_m
536
+
537
+
538
+ def real_sph_harm(k, zero_m_only=True, spherical_coordinates=True):
539
+ if not zero_m_only:
540
+ S_m = [0]
541
+ C_m = [1]
542
+ for i in range(1, k):
543
+ x = sym.symbols('x')
544
+ y = sym.symbols('y')
545
+ S_m += [x * S_m[i - 1] + y * C_m[i - 1]]
546
+ C_m += [x * C_m[i - 1] - y * S_m[i - 1]]
547
+
548
+ P_l_m = associated_legendre_polynomials(k, zero_m_only)
549
+ if spherical_coordinates:
550
+ theta = sym.symbols('theta')
551
+ z = sym.symbols('z')
552
+ for i in range(len(P_l_m)):
553
+ for j in range(len(P_l_m[i])):
554
+ if type(P_l_m[i][j]) != int:
555
+ P_l_m[i][j] = P_l_m[i][j].subs(z, sym.cos(theta))
556
+ if not zero_m_only:
557
+ phi = sym.symbols('phi')
558
+ for i in range(len(S_m)):
559
+ S_m[i] = S_m[i].subs(x,
560
+ sym.sin(theta) * sym.cos(phi)).subs(
561
+ y,
562
+ sym.sin(theta) * sym.sin(phi))
563
+ for i in range(len(C_m)):
564
+ C_m[i] = C_m[i].subs(x,
565
+ sym.sin(theta) * sym.cos(phi)).subs(
566
+ y,
567
+ sym.sin(theta) * sym.sin(phi))
568
+
569
+ Y_func_l_m = [['0'] * (2 * j + 1) for j in range(k)]
570
+ for i in range(k):
571
+ Y_func_l_m[i][0] = sym.simplify(sph_harm_prefactor(i, 0) * P_l_m[i][0])
572
+
573
+ if not zero_m_only:
574
+ for i in range(1, k):
575
+ for j in range(1, i + 1):
576
+ Y_func_l_m[i][j] = sym.simplify(
577
+ 2**0.5 * sph_harm_prefactor(i, j) * C_m[j] * P_l_m[i][j])
578
+ for i in range(1, k):
579
+ for j in range(1, i + 1):
580
+ Y_func_l_m[i][-j] = sym.simplify(
581
+ 2**0.5 * sph_harm_prefactor(i, -j) * S_m[j] * P_l_m[i][j])
582
+
583
+ return Y_func_l_m
584
+
585
+
586
+ class BesselBasisLayer(torch.nn.Module):
587
+ def __init__(self, num_radial, cutoff, envelope_exponent=6):
588
+ super(BesselBasisLayer, self).__init__()
589
+ self.cutoff = cutoff
590
+ self.envelope = Envelope(envelope_exponent)
591
+
592
+ self.freq = torch.nn.Parameter(torch.Tensor(num_radial))
593
+
594
+ self.reset_parameters()
595
+
596
+ def reset_parameters(self):
597
+ # 代替in-place操作
598
+ # torch.arange(1, self.freq.numel() + 1, out=self.freq).mul_(PI)
599
+ # self.freq = torch.arange(1, self.freq.numel() + 1, out=self.freq).mul_(PI)
600
+
601
+ # 计算临时张量并存储到 tmp_tensor 变量中
602
+ tmp_tensor = torch.arange(1, self.freq.numel() + 1, dtype=self.freq.dtype, device=self.freq.device)
603
+
604
+ # 使用乘法函数实现数乘并将结果保存到 self.freq 张量上
605
+ self.freq.data = torch.mul(tmp_tensor, PI)
606
+
607
+ def forward(self, dist):
608
+ dist = dist.unsqueeze(-1) / self.cutoff
609
+ return self.envelope(dist) * (self.freq * dist).sin()
610
+
611
+
612
+ class SiLU(nn.Module):
613
+ def __init__(self):
614
+ super().__init__()
615
+
616
+ def forward(self, input):
617
+ return silu(input)
618
+
619
+
620
+ def silu(input):
621
+ return input * torch.sigmoid(input)
622
+
623
+
624
+ class Envelope(torch.nn.Module):
625
+ def __init__(self, exponent):
626
+ super(Envelope, self).__init__()
627
+ self.p = exponent
628
+ self.a = -(self.p + 1) * (self.p + 2) / 2
629
+ self.b = self.p * (self.p + 2)
630
+ self.c = -self.p * (self.p + 1) / 2
631
+
632
+ def forward(self, x):
633
+ p, a, b, c = self.p, self.a, self.b, self.c
634
+ x_pow_p0 = x.pow(p)
635
+ x_pow_p1 = x_pow_p0 * x
636
+ env_val = 1. / x + a * x_pow_p0 + b * x_pow_p1 + c * x_pow_p1 * x
637
+
638
+ zero = torch.zeros_like(x)
639
+ return torch.where(x < 1, env_val, zero)
640
+
641
+
642
+ class SphericalBasisLayer(torch.nn.Module):
643
+ def __init__(self, num_spherical, num_radial, cutoff=5.0,
644
+ envelope_exponent=5):
645
+ super(SphericalBasisLayer, self).__init__()
646
+ assert num_radial <= 64
647
+ self.num_spherical = num_spherical
648
+ self.num_radial = num_radial
649
+ self.cutoff = cutoff
650
+ self.envelope = Envelope(envelope_exponent)
651
+
652
+ bessel_forms = bessel_basis(num_spherical, num_radial)
653
+ sph_harm_forms = real_sph_harm(num_spherical)
654
+ self.sph_funcs = []
655
+ self.bessel_funcs = []
656
+
657
+ x, theta = sym.symbols('x theta')
658
+ modules = {'sin': torch.sin, 'cos': torch.cos}
659
+ for i in range(num_spherical):
660
+ if i == 0:
661
+ sph1 = sym.lambdify([theta], sph_harm_forms[i][0], modules)(0)
662
+ self.sph_funcs.append(lambda x: torch.zeros_like(x) + sph1)
663
+ else:
664
+ sph = sym.lambdify([theta], sph_harm_forms[i][0], modules)
665
+ self.sph_funcs.append(sph)
666
+ for j in range(num_radial):
667
+ bessel = sym.lambdify([x], bessel_forms[i][j], modules)
668
+ self.bessel_funcs.append(bessel)
669
+
670
+ def forward(self, dist, angle, idx_kj):
671
+ dist = dist / self.cutoff
672
+ rbf = torch.stack([f(dist) for f in self.bessel_funcs], dim=1)
673
+ rbf = self.envelope(dist).unsqueeze(-1) * rbf
674
+
675
+ cbf = torch.stack([f(angle) for f in self.sph_funcs], dim=1)
676
+
677
+ n, k = self.num_spherical, self.num_radial
678
+ out = (rbf[idx_kj].view(-1, n, k) * cbf.view(-1, n, 1)).view(-1, n * k)
679
+ return out
680
+
681
+
682
+
683
+ msg_special_args = set([
684
+ 'edge_index',
685
+ 'edge_index_i',
686
+ 'edge_index_j',
687
+ 'size',
688
+ 'size_i',
689
+ 'size_j',
690
+ ])
691
+
692
+ aggr_special_args = set([
693
+ 'index',
694
+ 'dim_size',
695
+ ])
696
+
697
+ update_special_args = set([])
698
+
699
+
700
+ class MessagePassing(torch.nn.Module):
701
+ r"""Base class for creating message passing layers
702
+
703
+ .. math::
704
+ \mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i,
705
+ \square_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}}
706
+ \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{i,j}\right) \right),
707
+
708
+ where :math:`\square` denotes a differentiable, permutation invariant
709
+ function, *e.g.*, sum, mean or max, and :math:`\gamma_{\mathbf{\Theta}}`
710
+ and :math:`\phi_{\mathbf{\Theta}}` denote differentiable functions such as
711
+ MLPs.
712
+ See `here <https://pytorch-geometric.readthedocs.io/en/latest/notes/
713
+ create_gnn.html>`__ for the accompanying tutorial.
714
+
715
+ Args:
716
+ aggr (string, optional): The aggregation scheme to use
717
+ (:obj:`"add"`, :obj:`"mean"` or :obj:`"max"`).
718
+ (default: :obj:`"add"`)
719
+ flow (string, optional): The flow direction of message passing
720
+ (:obj:`"source_to_target"` or :obj:`"target_to_source"`).
721
+ (default: :obj:`"source_to_target"`)
722
+ node_dim (int, optional): The axis along which to propagate.
723
+ (default: :obj:`0`)
724
+ """
725
+ def __init__(self, aggr='add', flow='target_to_source', node_dim=0):
726
+ super(MessagePassing, self).__init__()
727
+
728
+ self.aggr = aggr
729
+ assert self.aggr in ['add', 'mean', 'max']
730
+
731
+ self.flow = flow
732
+ assert self.flow in ['source_to_target', 'target_to_source']
733
+
734
+ self.node_dim = node_dim
735
+ assert self.node_dim >= 0
736
+
737
+ self.__msg_params__ = inspect.signature(self.message).parameters
738
+ self.__msg_params__ = OrderedDict(self.__msg_params__)
739
+
740
+ self.__aggr_params__ = inspect.signature(self.aggregate).parameters
741
+ self.__aggr_params__ = OrderedDict(self.__aggr_params__)
742
+ self.__aggr_params__.popitem(last=False)
743
+
744
+ self.__update_params__ = inspect.signature(self.update).parameters
745
+ self.__update_params__ = OrderedDict(self.__update_params__)
746
+ self.__update_params__.popitem(last=False)
747
+
748
+ msg_args = set(self.__msg_params__.keys()) - msg_special_args
749
+ aggr_args = set(self.__aggr_params__.keys()) - aggr_special_args
750
+ update_args = set(self.__update_params__.keys()) - update_special_args
751
+
752
+ self.__args__ = set().union(msg_args, aggr_args, update_args)
753
+
754
+ def __set_size__(self, size, index, tensor):
755
+ if not torch.is_tensor(tensor):
756
+ pass
757
+ elif size[index] is None:
758
+ size[index] = tensor.size(self.node_dim)
759
+ elif size[index] != tensor.size(self.node_dim):
760
+ raise ValueError(
761
+ (f'Encountered node tensor with size '
762
+ f'{tensor.size(self.node_dim)} in dimension {self.node_dim}, '
763
+ f'but expected size {size[index]}.'))
764
+
765
+ def __collect__(self, edge_index, size, kwargs):
766
+ i, j = (0, 1) if self.flow == "target_to_source" else (1, 0)
767
+ ij = {"_i": i, "_j": j}
768
+
769
+ out = {}
770
+ for arg in self.__args__:
771
+ if arg[-2:] not in ij.keys():
772
+ out[arg] = kwargs.get(arg, inspect.Parameter.empty)
773
+ else:
774
+ idx = ij[arg[-2:]]
775
+ data = kwargs.get(arg[:-2], inspect.Parameter.empty)
776
+
777
+ if data is inspect.Parameter.empty:
778
+ out[arg] = data
779
+ continue
780
+
781
+ if isinstance(data, tuple) or isinstance(data, list):
782
+ assert len(data) == 2
783
+ self.__set_size__(size, 1 - idx, data[1 - idx])
784
+ data = data[idx]
785
+
786
+ if not torch.is_tensor(data):
787
+ out[arg] = data
788
+ continue
789
+
790
+ self.__set_size__(size, idx, data)
791
+ out[arg] = data.index_select(self.node_dim, edge_index[idx])
792
+
793
+ size[0] = size[1] if size[0] is None else size[0]
794
+ size[1] = size[0] if size[1] is None else size[1]
795
+
796
+ # Add special message arguments.
797
+ out['edge_index'] = edge_index
798
+ out['edge_index_i'] = edge_index[i]
799
+ out['edge_index_j'] = edge_index[j]
800
+ out['size'] = size
801
+ out['size_i'] = size[i]
802
+ out['size_j'] = size[j]
803
+
804
+ # Add special aggregate arguments.
805
+ out['index'] = out['edge_index_i']
806
+ out['dim_size'] = out['size_i']
807
+
808
+ return out
809
+
810
+ def __distribute__(self, params, kwargs):
811
+ out = {}
812
+ for key, param in params.items():
813
+ data = kwargs[key]
814
+ if data is inspect.Parameter.empty:
815
+ if param.default is inspect.Parameter.empty:
816
+ raise TypeError(f'Required parameter {key} is empty.')
817
+ data = param.default
818
+ out[key] = data
819
+ return out
820
+
821
+ def propagate(self, edge_index, size=None, **kwargs):
822
+ r"""The initial call to start propagating messages.
823
+
824
+ Args:
825
+ edge_index (Tensor): The indices of a general (sparse) assignment
826
+ matrix with shape :obj:`[N, M]` (can be directed or
827
+ undirected).
828
+ size (list or tuple, optional): The size :obj:`[N, M]` of the
829
+ assignment matrix. If set to :obj:`None`, the size will be
830
+ automatically inferred and assumed to be quadratic.
831
+ (default: :obj:`None`)
832
+ **kwargs: Any additional data which is needed to construct and
833
+ aggregate messages, and to update node embeddings.
834
+ """
835
+
836
+ size = [None, None] if size is None else size
837
+ size = [size, size] if isinstance(size, int) else size
838
+ size = size.tolist() if torch.is_tensor(size) else size
839
+ size = list(size) if isinstance(size, tuple) else size
840
+ assert isinstance(size, list)
841
+ assert len(size) == 2
842
+
843
+ kwargs = self.__collect__(edge_index, size, kwargs)
844
+
845
+ msg_kwargs = self.__distribute__(self.__msg_params__, kwargs)
846
+
847
+ m = self.message(**msg_kwargs)
848
+ aggr_kwargs = self.__distribute__(self.__aggr_params__, kwargs)
849
+ m = self.aggregate(m, **aggr_kwargs)
850
+
851
+ update_kwargs = self.__distribute__(self.__update_params__, kwargs)
852
+ m = self.update(m, **update_kwargs)
853
+
854
+ return m
855
+
856
+ def message(self, x_j): # pragma: no cover
857
+ r"""Constructs messages to node :math:`i` in analogy to
858
+ :math:`\phi_{\mathbf{\Theta}}` for each edge in
859
+ :math:`(j,i) \in \mathcal{E}` if :obj:`flow="source_to_target"` and
860
+ :math:`(i,j) \in \mathcal{E}` if :obj:`flow="target_to_source"`.
861
+ Can take any argument which was initially passed to :meth:`propagate`.
862
+ In addition, tensors passed to :meth:`propagate` can be mapped to the
863
+ respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or
864
+ :obj:`_j` to the variable name, *.e.g.* :obj:`x_i` and :obj:`x_j`.
865
+ """
866
+
867
+ return x_j
868
+
869
+ def aggregate(self, inputs, index, dim_size): # pragma: no cover
870
+ r"""Aggregates messages from neighbors as
871
+ :math:`\square_{j \in \mathcal{N}(i)}`.
872
+
873
+ By default, delegates call to scatter functions that support
874
+ "add", "mean" and "max" operations specified in :meth:`__init__` by
875
+ the :obj:`aggr` argument.
876
+ """
877
+
878
+ return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)
879
+
880
+ def update(self, inputs): # pragma: no cover
881
+ r"""Updates node embeddings in analogy to
882
+ :math:`\gamma_{\mathbf{\Theta}}` for each node
883
+ :math:`i \in \mathcal{V}`.
884
+ Takes in the output of aggregation as first argument and any argument
885
+ which was initially passed to :meth:`propagate`.
886
+ """
887
+
888
+ return inputs
889
+
890
+ class TransMXMNet(nn.Module):
891
+ def __init__(self, dim=128, n_layer=6, cutoff=5.0, num_spherical=7, num_radial=6, envelope_exponent=5):
892
+ super(TransMXMNet, self).__init__()
893
+
894
+ self.dim = dim
895
+ self.n_layer = n_layer
896
+ self.cutoff = cutoff
897
+
898
+ self.embeddings = nn.Parameter(torch.ones((5, self.dim)))
899
+
900
+ self.rbf_l = BesselBasisLayer(16, 5, envelope_exponent)
901
+ self.rbf_g = BesselBasisLayer(16, self.cutoff, envelope_exponent)
902
+ self.sbf = SphericalBasisLayer(num_spherical, num_radial, 5, envelope_exponent)
903
+
904
+ self.rbf_g_mlp = MLP([16, self.dim])
905
+ self.rbf_l_mlp = MLP([16, self.dim])
906
+
907
+ self.sbf_1_mlp = MLP([num_spherical * num_radial, self.dim])
908
+ self.sbf_2_mlp = MLP([num_spherical * num_radial, self.dim])
909
+
910
+ self.global_layers = torch.nn.ModuleList()
911
+ for layer in range(self.n_layer):
912
+ self.global_layers.append(Global_MP(self.dim))
913
+
914
+ self.local_layers = torch.nn.ModuleList()
915
+ for layer in range(self.n_layer):
916
+ self.local_layers.append(Local_MP(self.dim))
917
+
918
+ self.pos_embed = build_position_encoding(self.dim)
919
+ self.transformer = build_transformer(self.dim)
920
+
921
+ self.init()
922
+
923
+ def init(self):
924
+ stdv = math.sqrt(3)
925
+ self.embeddings.data.uniform_(-stdv, stdv)
926
+
927
+ def indices(self, edge_index, num_nodes):
928
+ row, col = edge_index
929
+
930
+ value = torch.arange(row.size(0), device=row.device)
931
+ adj_t = SparseTensor(row=col, col=row, value=value,
932
+ sparse_sizes=(num_nodes, num_nodes))
933
+
934
+ #Compute the node indices for two-hop angles
935
+ adj_t_row = adj_t[row]
936
+ num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)
937
+
938
+ idx_i = col.repeat_interleave(num_triplets)
939
+ idx_j = row.repeat_interleave(num_triplets)
940
+ idx_k = adj_t_row.storage.col()
941
+ mask = idx_i != idx_k
942
+ idx_i_1, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask]
943
+
944
+ idx_kj = adj_t_row.storage.value()[mask]
945
+ idx_ji_1 = adj_t_row.storage.row()[mask]
946
+
947
+ #Compute the node indices for one-hop angles
948
+ adj_t_col = adj_t[col]
949
+
950
+ num_pairs = adj_t_col.set_value(None).sum(dim=1).to(torch.long)
951
+ idx_i_2 = row.repeat_interleave(num_pairs)
952
+ idx_j1 = col.repeat_interleave(num_pairs)
953
+ idx_j2 = adj_t_col.storage.col()
954
+
955
+ idx_ji_2 = adj_t_col.storage.row()
956
+ idx_jj = adj_t_col.storage.value()
957
+
958
+ return idx_i_1, idx_j, idx_k, idx_kj, idx_ji_1, idx_i_2, idx_j1, idx_j2, idx_jj, idx_ji_2
959
+
960
+
961
+ def forward_features(self, data):
962
+ x = data.x
963
+ edge_index = data.edge_index
964
+ pos = data.pos
965
+ batch = data.batch
966
+ # Initialize node embeddings
967
+ h = torch.index_select(self.embeddings, 0, x.long()).unsqueeze(0)
968
+ data_len = torch.bincount(batch)
969
+ # 计算相邻元素差异
970
+ diff_tensor = torch.diff(data_len)
971
+ indices = torch.nonzero(diff_tensor) + 1
972
+ indices[0] = 0
973
+
974
+ att_mask = torch.zeros(len(batch), len(batch)).cuda()
975
+
976
+ att_mask[indices[0]:, indices[0]:] = 1
977
+ i = 0
978
+ for i in range(0, h.size(0) - 1):
979
+ att_mask[indices[i]:indices[i + 1], indices[i]:indices[i + 1]] = 1
980
+ att_mask[indices[i]:indices[-1], indices[i]:indices[-1]] = 1
981
+
982
+ mask = torch.ones(1, len(batch)).bool().cuda()
983
+
984
+ pos_h = self.pos_embed(h, mask).cuda()
985
+ memory = self.transformer(h, ~mask, att_mask, pos_h)
986
+ h = memory.squeeze(0)
987
+
988
+ '''局部层--------------------------------------------------------------------------
989
+ '''
990
+ # Get the edges and pairwise distances in the local layer
991
+ edge_index_l, _ = remove_self_loops(edge_index) # 移除自环后的边索引
992
+ j_l, i_l = edge_index_l
993
+ dist_l = (pos[i_l] - pos[j_l]).pow(2).sum(dim=-1).sqrt() # 两个节点之间的距离
994
+
995
+ '''全局层--------------------------------------------------------------------------
996
+ '''
997
+ # Get the edges pairwise distances in the global layer
998
+ # radius函数返回两个节点之间的距离小于cutoff的边索引
999
+ row, col = radius(pos, pos, self.cutoff, batch, batch, max_num_neighbors=500)
1000
+ edge_index_g = torch.stack([row, col], dim=0)
1001
+ edge_index_g, _ = remove_self_loops(edge_index_g)
1002
+ j_g, i_g = edge_index_g
1003
+ dist_g = (pos[i_g] - pos[j_g]).pow(2).sum(dim=-1).sqrt()
1004
+
1005
+ # Compute the node indices for defining the angles
1006
+ idx_i_1, idx_j, idx_k, idx_kj, idx_ji, idx_i_2, idx_j1, idx_j2, idx_jj, idx_ji_2 = self.indices(edge_index_l, num_nodes=h.size(0))
1007
+
1008
+ # Compute the two-hop angles
1009
+ pos_ji_1, pos_kj = pos[idx_j] - pos[idx_i_1], pos[idx_k] - pos[idx_j]
1010
+ a = (pos_ji_1 * pos_kj).sum(dim=-1)
1011
+ b = torch.cross(pos_ji_1, pos_kj).norm(dim=-1)
1012
+ angle_1 = torch.atan2(b, a)
1013
+
1014
+ # Compute the one-hop angles
1015
+ pos_ji_2, pos_jj = pos[idx_j1] - pos[idx_i_2], pos[idx_j2] - pos[idx_j1]
1016
+ a = (pos_ji_2 * pos_jj).sum(dim=-1)
1017
+ b = torch.cross(pos_ji_2, pos_jj).norm(dim=-1)
1018
+ angle_2 = torch.atan2(b, a)
1019
+
1020
+ # Get the RBF and SBF embeddings
1021
+ rbf_g = self.rbf_g(dist_g)
1022
+ rbf_l = self.rbf_l(dist_l)
1023
+ sbf_1 = self.sbf(dist_l, angle_1, idx_kj)
1024
+ sbf_2 = self.sbf(dist_l, angle_2, idx_jj)
1025
+
1026
+ rbf_g = self.rbf_g_mlp(rbf_g)
1027
+ rbf_l = self.rbf_l_mlp(rbf_l)
1028
+ sbf_1 = self.sbf_1_mlp(sbf_1)
1029
+ sbf_2 = self.sbf_2_mlp(sbf_2)
1030
+
1031
+ # Perform the message passing schemes
1032
+ node_sum = 0
1033
+
1034
+ for layer in range(self.n_layer):
1035
+ h = self.global_layers[layer](h, rbf_g, edge_index_g)
1036
+ h, t = self.local_layers[layer](h, rbf_l, sbf_1, sbf_2, idx_kj, idx_ji, idx_jj, idx_ji_2, edge_index_l)
1037
+ node_sum += t
1038
+
1039
+ # Readout
1040
+ output = global_add_pool(node_sum, batch)
1041
+ return output.view(-1)
1042
+
1043
+ def loss(self, pred, label):
1044
+ pred, label = pred.reshape(-1), label.reshape(-1)
1045
+ return F.mse_loss(pred, label)
1046
+
1047
+
1048
+ class Global_MP(MessagePassing):
1049
+
1050
+ def __init__(self, dim):
1051
+ super(Global_MP, self).__init__()
1052
+ self.dim = dim
1053
+
1054
+ self.h_mlp = MLP([self.dim, self.dim])
1055
+
1056
+ self.res1 = Res(self.dim)
1057
+ self.res2 = Res(self.dim)
1058
+ self.res3 = Res(self.dim)
1059
+ self.mlp = MLP([self.dim, self.dim])
1060
+
1061
+ self.x_edge_mlp = MLP([self.dim * 3, self.dim])
1062
+ self.linear = nn.Linear(self.dim, self.dim, bias=False)
1063
+
1064
+ def forward(self, h, edge_attr, edge_index):
1065
+ edge_index, _ = add_self_loops(edge_index, num_nodes=h.size(0))
1066
+
1067
+ res_h = h
1068
+
1069
+ # Integrate the Cross Layer Mapping inside the Global Message Passing
1070
+ h = self.h_mlp(h)
1071
+
1072
+ # Message Passing operation
1073
+ h = self.propagate(edge_index, x=h, num_nodes=h.size(0), edge_attr=edge_attr)
1074
+
1075
+ # Update function f_u
1076
+ h = self.res1(h)
1077
+ h = self.mlp(h) + res_h
1078
+ h = self.res2(h)
1079
+ h = self.res3(h)
1080
+
1081
+ # Message Passing operation
1082
+ h = self.propagate(edge_index, x=h, num_nodes=h.size(0), edge_attr=edge_attr)
1083
+
1084
+ return h
1085
+
1086
+ def message(self, x_i, x_j, edge_attr, edge_index, num_nodes):
1087
+ num_edge = edge_attr.size()[0]
1088
+
1089
+ x_edge = torch.cat((x_i[:num_edge], x_j[:num_edge], edge_attr), -1)
1090
+ x_edge = self.x_edge_mlp(x_edge)
1091
+
1092
+ x_j = torch.cat((self.linear(edge_attr) * x_edge, x_j[num_edge:]), dim=0)
1093
+
1094
+ return x_j
1095
+
1096
+ def update(self, aggr_out):
1097
+ return aggr_out
1098
+
1099
+
1100
+ class Local_MP(torch.nn.Module):
1101
+ def __init__(self, dim):
1102
+ super(Local_MP, self).__init__()
1103
+ self.dim = dim
1104
+
1105
+ self.h_mlp = MLP([self.dim, self.dim])
1106
+
1107
+ self.mlp_kj = MLP([3 * self.dim, self.dim])
1108
+ self.mlp_ji_1 = MLP([3 * self.dim, self.dim])
1109
+ self.mlp_ji_2 = MLP([self.dim, self.dim])
1110
+ self.mlp_jj = MLP([self.dim, self.dim])
1111
+
1112
+ self.mlp_sbf1 = MLP([self.dim, self.dim, self.dim])
1113
+ self.mlp_sbf2 = MLP([self.dim, self.dim, self.dim])
1114
+ self.lin_rbf1 = nn.Linear(self.dim, self.dim, bias=False)
1115
+ self.lin_rbf2 = nn.Linear(self.dim, self.dim, bias=False)
1116
+
1117
+ self.res1 = Res(self.dim)
1118
+ self.res2 = Res(self.dim)
1119
+ self.res3 = Res(self.dim)
1120
+
1121
+ self.lin_rbf_out = nn.Linear(self.dim, self.dim, bias=False)
1122
+
1123
+ self.h_mlp = MLP([self.dim, self.dim])
1124
+
1125
+ self.y_mlp = MLP([self.dim, self.dim, self.dim, self.dim])
1126
+ self.y_W = nn.Linear(self.dim, 1)
1127
+
1128
+ def forward(self, h, rbf, sbf1, sbf2, idx_kj, idx_ji_1, idx_jj, idx_ji_2, edge_index, num_nodes=None):
1129
+ res_h = h
1130
+
1131
+ # Integrate the Cross Layer Mapping inside the Local Message Passing
1132
+ h = self.h_mlp(h)
1133
+
1134
+ # Message Passing 1
1135
+ j, i = edge_index
1136
+ m = torch.cat([h[i], h[j], rbf], dim=-1)
1137
+
1138
+ m_kj = self.mlp_kj(m)
1139
+ m_kj = m_kj * self.lin_rbf1(rbf)
1140
+ m_kj = m_kj[idx_kj] * self.mlp_sbf1(sbf1)
1141
+ m_kj = scatter(m_kj, idx_ji_1, dim=0, dim_size=m.size(0), reduce='add')
1142
+
1143
+ m_ji_1 = self.mlp_ji_1(m)
1144
+
1145
+ m = m_ji_1 + m_kj
1146
+
1147
+ # Message Passing 2 (index jj denotes j'i in the main paper)
1148
+ m_jj = self.mlp_jj(m)
1149
+ m_jj = m_jj * self.lin_rbf2(rbf)
1150
+ m_jj = m_jj[idx_jj] * self.mlp_sbf2(sbf2)
1151
+ m_jj = scatter(m_jj, idx_ji_2, dim=0, dim_size=m.size(0), reduce='add')
1152
+
1153
+ m_ji_2 = self.mlp_ji_2(m)
1154
+
1155
+ m = m_ji_2 + m_jj
1156
+
1157
+ # Aggregation
1158
+ m = self.lin_rbf_out(rbf) * m
1159
+ h = scatter(m, i, dim=0, dim_size=h.size(0), reduce='add')
1160
+
1161
+ # Update function f_u
1162
+ h = self.res1(h)
1163
+ h = self.h_mlp(h) + res_h
1164
+ h = self.res2(h)
1165
+ h = self.res3(h)
1166
+
1167
+ # Output Module
1168
+ y = self.y_mlp(h)
1169
+ y = self.y_W(y)
1170
+
1171
+ return h, y
1172
+
1173
+
1174
+ # class MXMConfig(PretrainedConfig):
1175
+ # model_type = "gcn"
1176
+ #
1177
+ # def __init__(
1178
+ # self,
1179
+ # dim: int=128,
1180
+ # n_layer: int=6,
1181
+ # cutoff: float=5.0,
1182
+ # num_spherical: int=7,
1183
+ # num_radial: int=6,
1184
+ # envelope_exponent: int=5,
1185
+ #
1186
+ # smiles: List[str] = None,
1187
+ # processor_class: str = "SmilesProcessor",
1188
+ # **kwargs,
1189
+ # ):
1190
+ #
1191
+ # self.dim = dim # the dimension of input feature
1192
+ # self.n_layer = n_layer # the number of GCN layers
1193
+ # self.cutoff = cutoff # the cutoff distance for neighbor searching
1194
+ # self.num_spherical = num_spherical # the number of spherical harmonics
1195
+ # self.num_radial = num_radial # the number of radial basis
1196
+ # self.envelope_exponent = envelope_exponent # the envelope exponent
1197
+ #
1198
+ # self.smiles = smiles # process smiles
1199
+ # self.processor_class = processor_class
1200
+ #
1201
+ #
1202
+ # super().__init__(**kwargs)
1203
+
1204
+
1205
+
1206
+ class TransmxmModel(PreTrainedModel):
1207
+ config_class = TransmxmConfig
1208
+
1209
+ def __init__(self, config):
1210
+ super().__init__(config)
1211
+
1212
+ self.model = TransMXMNet(
1213
+ dim=config.dim,
1214
+ n_layer=config.n_layer,
1215
+ cutoff=config.cutoff,
1216
+ num_spherical=config.num_spherical,
1217
+ num_radial=config.num_radial,
1218
+ envelope_exponent=config.envelope_exponent,
1219
+ )
1220
+ self.process = SmilesDataset(
1221
+ smiles=config.smiles,
1222
+ )
1223
+
1224
+ self.mxm_model = None
1225
+ self.dataset = None
1226
+ self.output = None
1227
+ self.data_loader = None
1228
+ self.pred_data = None
1229
+
1230
+ def forward(self, tensor):
1231
+ return self.model.forward_features(tensor)
1232
+
1233
+ def SmilesProcessor(self, smiles):
1234
+ return self.process.get_data(smiles)
1235
+
1236
+
1237
+ def predict_smiles(self, smiles, device: str='cpu', result_dir: str='./', **kwargs):
1238
+
1239
+
1240
+ batch_size = kwargs.pop('batch_size', 1)
1241
+ shuffle = kwargs.pop('shuffle', False)
1242
+ drop_last = kwargs.pop('drop_last', False)
1243
+ num_workers = kwargs.pop('num_workers', 0)
1244
+
1245
+ self.mxm_model = AutoModel.from_pretrained("Huhujingjing/custom-transmxm", trust_remote_code=True).to(device)
1246
+ self.mxm_model.eval()
1247
+
1248
+ self.dataset = self.process.get_data(smiles)
1249
+ self.output = ""
1250
+ self.output += ("predicted samples num: {}\n".format(len(self.dataset)))
1251
+ self.output +=("predicted samples:{}\n".format(self.dataset[0]))
1252
+ self.data_loader = DataLoader(self.dataset,
1253
+ batch_size=batch_size,
1254
+ shuffle=shuffle,
1255
+ drop_last=drop_last,
1256
+ num_workers=num_workers
1257
+ )
1258
+ self.pred_data = {
1259
+ 'smiles': [],
1260
+ 'pred': []
1261
+ }
1262
+
1263
+ for batch in tqdm(self.data_loader):
1264
+ batch = batch.to(device)
1265
+ with torch.no_grad():
1266
+ self.pred_data['smiles'] += batch['smiles']
1267
+ self.pred_data['pred'] += self.gcn_model(batch).cpu().tolist()
1268
+
1269
+ pred = torch.tensor(self.pred_data['pred']).reshape(-1)
1270
+ if device == 'cuda':
1271
+ pred = pred.cpu().tolist()
1272
+ self.pred_data['pred'] = pred
1273
+ pred_df = pd.DataFrame(self.pred_data)
1274
+ pred_df['pred'] = pred_df['pred'].apply(lambda x: round(x, 2))
1275
+ self.output +=('-' * 40 + '\n'+'predicted result: \n'+'{}\n'.format(pred_df))
1276
+ self.output +=('-' * 40)
1277
+
1278
+ pred_df.to_csv(os.path.join(result_dir, 'prediction.csv'), index=False)
1279
+ self.output +=('\nsave predicted result to {}\n'.format(os.path.join(result_dir, 'prediction.csv')))
1280
+
1281
+ return self.output
1282
+
1283
+
1284
+ if __name__ == "__main__":
1285
+
1286
+ transmxm_config = TransmxmConfig.from_pretrained("custom-transmxm")
1287
+
1288
+ transmxmd = TransmxmModel(transmxm_config)
1289
+ transmxmd.model.load_state_dict(torch.load(r'G:\Trans_MXM\runs\model.pt'))
1290
+ transmxmd.save_pretrained("custom-transmxm")
1291
+