mgyigit committed on
Commit
72764c1
·
verified ·
1 Parent(s): 3a3578c

Update src/data/dataset.py

Browse files
Files changed (1) hide show
  1. src/data/dataset.py +11 -8
src/data/dataset.py CHANGED
@@ -89,11 +89,11 @@ class DruggenDataset(InMemoryDataset):
89
  smiles_list (list): List of SMILES strings.
90
 
91
  Returns:
92
- max_length (int): Maximum number of atoms found in the filtered molecules.
93
  filtered_smiles (list): List of valid SMILES strings.
94
  """
95
- max_length = 0
96
  filtered_smiles = []
 
97
  for smiles in tqdm(smiles_list, desc="Filtering SMILES"):
98
  mol = Chem.MolFromSmiles(smiles)
99
  if mol is None:
@@ -113,8 +113,9 @@ class DruggenDataset(InMemoryDataset):
113
  continue
114
 
115
  filtered_smiles.append(smiles)
116
- max_length = max(max_length, molecule_size)
117
- return max_length, filtered_smiles
 
118
 
119
  def _genA(self, mol, connected=True, max_length=None):
120
  """
@@ -290,20 +291,22 @@ class DruggenDataset(InMemoryDataset):
290
  """
291
  # Read raw SMILES from file (assuming CSV with no header)
292
  smiles_list = pd.read_csv(self.raw_files, header=None)[0].tolist()
293
- max_length, filtered_smiles = self._filter_smiles(smiles_list)
 
 
294
  data_list = []
295
  self.m_dim = len(self.atom_decoder_m)
296
  for smiles in tqdm(filtered_smiles, desc='Processing dataset', total=len(filtered_smiles)):
297
  mol = Chem.MolFromSmiles(smiles)
298
- A = self._genA(mol, connected=True, max_length=max_length)
299
  if A is not None:
300
- x_array = self._genX(mol, max_length=max_length)
301
  if x_array is None:
302
  continue
303
  x = torch.from_numpy(x_array).to(torch.long).view(1, -1)
304
  x = label2onehot(x, self.m_dim).squeeze()
305
  if self.features:
306
- f = torch.from_numpy(self._genF(mol, max_length=max_length)).to(torch.long).view(x.shape[0], -1)
307
  x = torch.concat((x, f), dim=-1)
308
  adjacency = torch.from_numpy(A)
309
  edge_index = adjacency.nonzero(as_tuple=False).t().contiguous()
 
89
  smiles_list (list): List of SMILES strings.
90
 
91
  Returns:
92
+ num_smiles (int): Number of filtered smiles
93
  filtered_smiles (list): List of valid SMILES strings.
94
  """
 
95
  filtered_smiles = []
96
+ num_smiles = 0
97
  for smiles in tqdm(smiles_list, desc="Filtering SMILES"):
98
  mol = Chem.MolFromSmiles(smiles)
99
  if mol is None:
 
113
  continue
114
 
115
  filtered_smiles.append(smiles)
116
+ num_smiles += 1
117
+
118
+ return num_smiles, filtered_smiles
119
 
120
  def _genA(self, mol, connected=True, max_length=None):
121
  """
 
291
  """
292
  # Read raw SMILES from file (assuming CSV with no header)
293
  smiles_list = pd.read_csv(self.raw_files, header=None)[0].tolist()
294
+ num_smiles, filtered_smiles = self._filter_smiles(smiles_list)
295
+ self.num_smiles = num_smiles
296
+
297
  data_list = []
298
  self.m_dim = len(self.atom_decoder_m)
299
  for smiles in tqdm(filtered_smiles, desc='Processing dataset', total=len(filtered_smiles)):
300
  mol = Chem.MolFromSmiles(smiles)
301
+ A = self._genA(mol, connected=True, max_length=self.max_atom)
302
  if A is not None:
303
+ x_array = self._genX(mol, max_length=self.max_atom)
304
  if x_array is None:
305
  continue
306
  x = torch.from_numpy(x_array).to(torch.long).view(1, -1)
307
  x = label2onehot(x, self.m_dim).squeeze()
308
  if self.features:
309
+ f = torch.from_numpy(self._genF(mol, max_length=self.max_atom)).to(torch.long).view(x.shape[0], -1)
310
  x = torch.concat((x, f), dim=-1)
311
  adjacency = torch.from_numpy(A)
312
  edge_index = adjacency.nonzero(as_tuple=False).t().contiguous()