osbm committed on
Commit
6fec8c3
·
1 Parent(s): df223a5

Update new_dataloader.py

Browse files
Files changed (1) hide show
  1. new_dataloader.py +63 -117
new_dataloader.py CHANGED
@@ -1,14 +1,14 @@
1
  import pickle
2
- import os.path as osp
3
- import re
4
-
5
- import torch
6
  import numpy as np
7
- from tqdm import tqdm
8
  from rdkit import Chem
9
- from rdkit import RDLogger
10
  from torch_geometric.data import (Data, InMemoryDataset)
11
-
 
 
 
 
 
12
  RDLogger.DisableLog('rdApp.*')
13
  class DruggenDataset(InMemoryDataset):
14
 
@@ -18,64 +18,46 @@ class DruggenDataset(InMemoryDataset):
18
  self.raw_files = raw_files
19
  self.max_atom = max_atom
20
  self.features = features
21
-
22
  super().__init__(root, transform, pre_transform, pre_filter)
23
- self.data, self.slices = torch.load(osp.join(root, dataset_file))
24
-
 
25
 
 
 
 
 
 
 
26
  @property
27
  def raw_file_names(self):
28
  return self.raw_files
29
 
30
  @property
31
  def processed_file_names(self):
32
- '''
33
- Return the processed file names. If these names are not present, they will be automatically processed using process function of this class.
34
- '''
35
  return self.dataset_file
36
 
37
  def _generate_encoders_decoders(self, data):
38
- """
39
- Generates the encoders and decoders for the atoms and bonds.
40
- """
41
  self.data = data
42
  print('Creating atoms encoder and decoder..')
43
-
44
- atom_labels = set()
45
- # bond_labels = set()
46
- self.max_atom_size_in_data = 0
47
-
48
- for smile in data:
49
- mol = Chem.MolFromSmiles(smile)
50
- atom_labels.update([atom.GetAtomicNum() for atom in mol.GetAtoms()])
51
- # bond_labels.update([bond.GetBondType() for bond in mol.GetBonds()])
52
- self.max_atom_size_in_data = max(self.max_atom_size_in_data, mol.GetNumAtoms())
53
- atom_labels.update([0]) # add PAD symbol (for unknown atoms)
54
- atom_labels = sorted(atom_labels) # turn set into list and sort it
55
-
56
- # atom_labels = sorted(set([atom.GetAtomicNum() for mol in self.data for atom in mol.GetAtoms()] + [0]))
57
  self.atom_encoder_m = {l: i for i, l in enumerate(atom_labels)}
58
  self.atom_decoder_m = {i: l for i, l in enumerate(atom_labels)}
59
  self.atom_num_types = len(atom_labels)
60
- print(f'Created atoms encoder and decoder with {self.atom_num_types - 1} atom types and 1 PAD symbol!')
 
61
  print("atom_labels", atom_labels)
62
  print('Creating bonds encoder and decoder..')
63
- # bond_labels = [Chem.rdchem.BondType.ZERO] + list(sorted(set(bond.GetBondType()
64
- # for mol in self.data
65
- # for bond in mol.GetBonds())))
66
- bond_labels = [
67
- Chem.rdchem.BondType.ZERO,
68
- Chem.rdchem.BondType.SINGLE,
69
- Chem.rdchem.BondType.DOUBLE,
70
- Chem.rdchem.BondType.TRIPLE,
71
- Chem.rdchem.BondType.AROMATIC,
72
- ]
73
-
74
  print("bond labels", bond_labels)
75
  self.bond_encoder_m = {l: i for i, l in enumerate(bond_labels)}
76
  self.bond_decoder_m = {i: l for i, l in enumerate(bond_labels)}
77
  self.bond_num_types = len(bond_labels)
78
- print(f'Created bonds encoder and decoder with {self.bond_num_types - 1} bond types and 1 PAD symbol!')
 
79
  #dataset_names = str(self.dataset_name)
80
  with open("DrugGEN/data/encoders/" +"atom_" + self.dataset_name + ".pkl","wb") as atom_encoders:
81
  pickle.dump(self.atom_encoder_m,atom_encoders)
@@ -94,19 +76,8 @@ class DruggenDataset(InMemoryDataset):
94
 
95
 
96
 
97
- def generate_adjacency_matrix(self, mol, connected=True, max_length=None):
98
- """
99
- Generates the adjacency matrix for a molecule.
100
-
101
- Args:
102
- mol (Molecule): The molecule object.
103
- connected (bool): Whether to check for connectivity in the molecule. Defaults to True.
104
- max_length (int): The maximum length of the adjacency matrix. Defaults to the number of atoms in the molecule.
105
 
106
- Returns:
107
- numpy.ndarray or None: The adjacency matrix if connected and all atoms have a degree greater than 0,
108
- otherwise None.
109
- """
110
  max_length = max_length if max_length is not None else mol.GetNumAtoms()
111
 
112
  A = np.zeros(shape=(max_length, max_length))
@@ -121,33 +92,15 @@ class DruggenDataset(InMemoryDataset):
121
 
122
  return A if connected and (degree > 0).all() else None
123
 
124
- def generate_node_features(self, mol, max_length=None):
125
- """
126
- Generates the node features for a molecule.
127
-
128
- Args:
129
- mol (Molecule): The molecule object.
130
- max_length (int): The maximum length of the node features. Defaults to the number of atoms in the molecule.
131
 
132
- Returns:
133
- numpy.ndarray: The node features matrix.
134
- """
135
  max_length = max_length if max_length is not None else mol.GetNumAtoms()
136
 
137
  return np.array([self.atom_encoder_m[atom.GetAtomicNum()] for atom in mol.GetAtoms()] + [0] * (
138
  max_length - mol.GetNumAtoms()))
139
 
140
- def generate_additional_features(self, mol, max_length=None):
141
- """
142
- Generates additional features for a molecule.
143
 
144
- Args:
145
- mol (Molecule): The molecule object.
146
- max_length (int): The maximum length of the additional features. Defaults to the number of atoms in the molecule.
147
-
148
- Returns:
149
- numpy.ndarray: The additional features matrix.
150
- """
151
  max_length = max_length if max_length is not None else mol.GetNumAtoms()
152
 
153
  features = np.array([[*[a.GetDegree() == i for i in range(5)],
@@ -164,19 +117,19 @@ class DruggenDataset(InMemoryDataset):
164
 
165
  return np.vstack((features, np.zeros((max_length - features.shape[0], features.shape[1]))))
166
 
167
- def decoder_load(self, dictionary_name):
168
- with open("DrugGEN/data/decoders/" + dictionary_name + "_" + self.dataset_name + '.pkl', 'rb') as f:
169
  return pickle.load(f)
170
 
171
  def drugs_decoder_load(self, dictionary_name):
172
  with open("DrugGEN/data/decoders/" + dictionary_name +'.pkl', 'rb') as f:
173
  return pickle.load(f)
174
 
175
- def matrices2mol(self, node_labels, edge_labels, strict=True):
176
  mol = Chem.RWMol()
177
  RDLogger.DisableLog('rdApp.*')
178
- atom_decoders = self.decoder_load("atom")
179
- bond_decoders = self.decoder_load("bond")
180
 
181
  for node_label in node_labels:
182
  mol.AddAtom(Chem.Atom(atom_decoders[node_label]))
@@ -184,7 +137,7 @@ class DruggenDataset(InMemoryDataset):
184
  for start, end in zip(*np.nonzero(edge_labels)):
185
  if start > end:
186
  mol.AddBond(int(start), int(end), bond_decoders[edge_labels[start, end]])
187
- mol = self.correct_mol(mol)
188
  if strict:
189
  try:
190
 
@@ -194,18 +147,18 @@ class DruggenDataset(InMemoryDataset):
194
 
195
  return mol
196
 
197
- def drug_decoder_load(self, dictionary_name):
198
 
199
  ''' Loading the atom and bond decoders '''
200
 
201
- with open("DrugGEN/data/decoders/" + dictionary_name +"_" + "akt_train" +'.pkl', 'rb') as f:
202
 
203
  return pickle.load(f)
204
- def matrices2mol_drugs(self, node_labels, edge_labels, strict=True):
205
  mol = Chem.RWMol()
206
  RDLogger.DisableLog('rdApp.*')
207
- atom_decoders = self.drug_decoder_load("atom")
208
- bond_decoders = self.drug_decoder_load("bond")
209
 
210
  for node_label in node_labels:
211
 
@@ -214,7 +167,7 @@ class DruggenDataset(InMemoryDataset):
214
  for start, end in zip(*np.nonzero(edge_labels)):
215
  if start > end:
216
  mol.AddBond(int(start), int(end), bond_decoders[edge_labels[start, end]])
217
- mol = self.correct_mol(mol)
218
  if strict:
219
  try:
220
  Chem.SanitizeMol(mol)
@@ -240,7 +193,7 @@ class DruggenDataset(InMemoryDataset):
240
 
241
 
242
  def correct_mol(self,x):
243
- # xsm = Chem.MolToSmiles(x, isomericSmiles=True)
244
  mol = x
245
  while True:
246
  flag, atomid_valence = self.check_valency(mol)
@@ -284,41 +237,34 @@ class DruggenDataset(InMemoryDataset):
284
  return out.float()
285
 
286
  def process(self, size= None):
287
- '''
288
- Process the dataset. This function will be only run if processed_file_names does not exist in the data folder already.
289
- '''
290
- # mols = [Chem.MolFromSmiles(line) for line in open(self.raw_files, 'r').readlines()]
291
- # mols = list(filter(lambda x: x.GetNumAtoms() <= self.max_atom, mols))
292
- # mols = mols[:size] # i
293
- # indices = range(len(mols))
294
-
295
- smiles = pd.read_csv(self.raw_files, header=None)[0].tolist()
296
- self._generate_encoders_decoders(smiles)
297
 
298
- # pbar.set_description(f'Processing chembl dataset')
299
- # max_length = max(mol.GetNumAtoms() for mol in mols)
 
300
  data_list = []
301
- max_length = min(self.max_atom_size_in_data, self.max_atom)
302
  self.m_dim = len(self.atom_decoder_m)
303
- # for idx in indices:
304
- for smiles in tqdm(smiles, desc='Processing chembl dataset', total=len(smiles)):
305
- # mol = mols[idx]
306
-
307
- mol = Chem.MolFromSmiles(smile)
308
-
309
- # filter by max atom size
310
- if mol.GetNumAtoms() > max_length:
311
- continue
312
-
313
- A = self.generate_adjacency_matrix(mol, connected=True, max_length=max_length)
314
  if A is not None:
315
 
316
 
317
- x = torch.from_numpy(self.generate_node_features(mol, max_length=max_length)).to(torch.long).view(1, -1)
318
 
319
  x = self.label2onehot(x,self.m_dim).squeeze()
320
  if self.features:
321
- f = torch.from_numpy(self.generate_additional_features(mol, max_length=max_length)).to(torch.long).view(x.shape[0], -1)
322
  x = torch.concat((x,f), dim=-1)
323
 
324
  adjacency = torch.from_numpy(A)
@@ -335,9 +281,9 @@ class DruggenDataset(InMemoryDataset):
335
  data = self.pre_transform(data)
336
 
337
  data_list.append(data)
338
- # pbar.update(1)
339
 
340
- # pbar.close()
341
 
342
  torch.save(self.collate(data_list), osp.join(self.processed_dir, self.dataset_file))
343
 
@@ -346,4 +292,4 @@ class DruggenDataset(InMemoryDataset):
346
 
347
  if __name__ == '__main__':
348
  data = DruggenDataset("DrugGEN/data")
349
-
 
1
  import pickle
 
 
 
 
2
  import numpy as np
3
+ import torch
4
  from rdkit import Chem
 
5
  from torch_geometric.data import (Data, InMemoryDataset)
6
+ import os.path as osp
7
+ import pickle
8
+ import torch
9
+ from tqdm import tqdm
10
+ import re
11
+ from rdkit import RDLogger
12
  RDLogger.DisableLog('rdApp.*')
13
  class DruggenDataset(InMemoryDataset):
14
 
 
18
  self.raw_files = raw_files
19
  self.max_atom = max_atom
20
  self.features = features
 
21
  super().__init__(root, transform, pre_transform, pre_filter)
22
+ path = osp.join(self.processed_dir, dataset_file)
23
+ self.data, self.slices = torch.load(path)
24
+ self.root = root
25
 
26
+
27
+ @property
28
+ def processed_dir(self):
29
+
30
+ return self.root
31
+
32
  @property
33
  def raw_file_names(self):
34
  return self.raw_files
35
 
36
  @property
37
  def processed_file_names(self):
 
 
 
38
  return self.dataset_file
39
 
40
  def _generate_encoders_decoders(self, data):
41
+
 
 
42
  self.data = data
43
  print('Creating atoms encoder and decoder..')
44
+ atom_labels = sorted(set([atom.GetAtomicNum() for mol in self.data for atom in mol.GetAtoms()] + [0]))
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  self.atom_encoder_m = {l: i for i, l in enumerate(atom_labels)}
46
  self.atom_decoder_m = {i: l for i, l in enumerate(atom_labels)}
47
  self.atom_num_types = len(atom_labels)
48
+ print('Created atoms encoder and decoder with {} atom types and 1 PAD symbol!'.format(
49
+ self.atom_num_types - 1))
50
  print("atom_labels", atom_labels)
51
  print('Creating bonds encoder and decoder..')
52
+ bond_labels = [Chem.rdchem.BondType.ZERO] + list(sorted(set(bond.GetBondType()
53
+ for mol in self.data
54
+ for bond in mol.GetBonds())))
 
 
 
 
 
 
 
 
55
  print("bond labels", bond_labels)
56
  self.bond_encoder_m = {l: i for i, l in enumerate(bond_labels)}
57
  self.bond_decoder_m = {i: l for i, l in enumerate(bond_labels)}
58
  self.bond_num_types = len(bond_labels)
59
+ print('Created bonds encoder and decoder with {} bond types and 1 PAD symbol!'.format(
60
+ self.bond_num_types - 1))
61
  #dataset_names = str(self.dataset_name)
62
  with open("DrugGEN/data/encoders/" +"atom_" + self.dataset_name + ".pkl","wb") as atom_encoders:
63
  pickle.dump(self.atom_encoder_m,atom_encoders)
 
76
 
77
 
78
 
79
+ def _genA(self, mol, connected=True, max_length=None):
 
 
 
 
 
 
 
80
 
 
 
 
 
81
  max_length = max_length if max_length is not None else mol.GetNumAtoms()
82
 
83
  A = np.zeros(shape=(max_length, max_length))
 
92
 
93
  return A if connected and (degree > 0).all() else None
94
 
95
+ def _genX(self, mol, max_length=None):
 
 
 
 
 
 
96
 
 
 
 
97
  max_length = max_length if max_length is not None else mol.GetNumAtoms()
98
 
99
  return np.array([self.atom_encoder_m[atom.GetAtomicNum()] for atom in mol.GetAtoms()] + [0] * (
100
  max_length - mol.GetNumAtoms()))
101
 
102
+ def _genF(self, mol, max_length=None):
 
 
103
 
 
 
 
 
 
 
 
104
  max_length = max_length if max_length is not None else mol.GetNumAtoms()
105
 
106
  features = np.array([[*[a.GetDegree() == i for i in range(5)],
 
117
 
118
  return np.vstack((features, np.zeros((max_length - features.shape[0], features.shape[1]))))
119
 
120
+ def decoder_load(self, dictionary_name, file):
121
+ with open("DrugGEN/data/decoders/" + dictionary_name + "_" + file + '.pkl', 'rb') as f:
122
  return pickle.load(f)
123
 
124
  def drugs_decoder_load(self, dictionary_name):
125
  with open("DrugGEN/data/decoders/" + dictionary_name +'.pkl', 'rb') as f:
126
  return pickle.load(f)
127
 
128
+ def matrices2mol(self, node_labels, edge_labels, strict=True, file_name=None):
129
  mol = Chem.RWMol()
130
  RDLogger.DisableLog('rdApp.*')
131
+ atom_decoders = self.decoder_load("atom", file_name)
132
+ bond_decoders = self.decoder_load("bond", file_name)
133
 
134
  for node_label in node_labels:
135
  mol.AddAtom(Chem.Atom(atom_decoders[node_label]))
 
137
  for start, end in zip(*np.nonzero(edge_labels)):
138
  if start > end:
139
  mol.AddBond(int(start), int(end), bond_decoders[edge_labels[start, end]])
140
+ #mol = self.correct_mol(mol)
141
  if strict:
142
  try:
143
 
 
147
 
148
  return mol
149
 
150
+ def drug_decoder_load(self, dictionary_name, file):
151
 
152
  ''' Loading the atom and bond decoders '''
153
 
154
+ with open("DrugGEN/data/decoders/" + dictionary_name +"_" + file +'.pkl', 'rb') as f:
155
 
156
  return pickle.load(f)
157
+ def matrices2mol_drugs(self, node_labels, edge_labels, strict=True, file_name=None):
158
  mol = Chem.RWMol()
159
  RDLogger.DisableLog('rdApp.*')
160
+ atom_decoders = self.drug_decoder_load("atom", file_name)
161
+ bond_decoders = self.drug_decoder_load("bond", file_name)
162
 
163
  for node_label in node_labels:
164
 
 
167
  for start, end in zip(*np.nonzero(edge_labels)):
168
  if start > end:
169
  mol.AddBond(int(start), int(end), bond_decoders[edge_labels[start, end]])
170
+ #mol = self.correct_mol(mol)
171
  if strict:
172
  try:
173
  Chem.SanitizeMol(mol)
 
193
 
194
 
195
  def correct_mol(self,x):
196
+ xsm = Chem.MolToSmiles(x, isomericSmiles=True)
197
  mol = x
198
  while True:
199
  flag, atomid_valence = self.check_valency(mol)
 
237
  return out.float()
238
 
239
  def process(self, size= None):
240
+
241
+ mols = [Chem.MolFromSmiles(line) for line in open(self.raw_files, 'r').readlines()]
242
+
243
+ mols = list(filter(lambda x: x.GetNumAtoms() <= self.max_atom, mols))
244
+ mols = mols[:size]
245
+ indices = range(len(mols))
246
+
247
+ self._generate_encoders_decoders(mols)
248
+
249
+
250
 
251
+ pbar = tqdm(total=len(indices))
252
+ pbar.set_description(f'Processing chembl dataset')
253
+ max_length = max(mol.GetNumAtoms() for mol in mols)
254
  data_list = []
255
+
256
  self.m_dim = len(self.atom_decoder_m)
257
+ for idx in indices:
258
+ mol = mols[idx]
259
+ A = self._genA(mol, connected=True, max_length=max_length)
 
 
 
 
 
 
 
 
260
  if A is not None:
261
 
262
 
263
+ x = torch.from_numpy(self._genX(mol, max_length=max_length)).to(torch.long).view(1, -1)
264
 
265
  x = self.label2onehot(x,self.m_dim).squeeze()
266
  if self.features:
267
+ f = torch.from_numpy(self._genF(mol, max_length=max_length)).to(torch.long).view(x.shape[0], -1)
268
  x = torch.concat((x,f), dim=-1)
269
 
270
  adjacency = torch.from_numpy(A)
 
281
  data = self.pre_transform(data)
282
 
283
  data_list.append(data)
284
+ pbar.update(1)
285
 
286
+ pbar.close()
287
 
288
  torch.save(self.collate(data_list), osp.join(self.processed_dir, self.dataset_file))
289
 
 
292
 
293
  if __name__ == '__main__':
294
  data = DruggenDataset("DrugGEN/data")
295
+