feiyang-cai committed on
Commit
8b9fe11
·
1 Parent(s): c6866a7

add more detailed timing calculations

Browse files
Files changed (1) hide show
  1. utils.py +44 -42
utils.py CHANGED
@@ -132,33 +132,32 @@ class DataCollator(object):
132
  return self.sme.augment([molecule])[0]
133
 
134
  def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
135
-
136
- sources = []
137
- targets = []
138
 
139
- for example in instances:
140
- smiles = example['smiles'].strip()
141
- smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles))
142
-
143
- # get the properties except the smiles and mol_id cols
144
- #props = [example[col] if example[col] is not None else np.nan for col in sorted(example.keys()) if col not in ['smiles', 'is_aug']]
145
- source = f"{self.molecule_start_str}{smiles}{self.end_str}"
146
- sources.append(source)
147
 
148
- # Tokenize
149
- tokenized_sources_with_prompt = self.tokenizer(
150
- sources,
151
- max_length=self.source_max_len,
152
- truncation=True,
153
- add_special_tokens=False,
154
- )
155
- input_ids = [torch.tensor(tokenized_source) for tokenized_source in tokenized_sources_with_prompt['input_ids']]
156
- input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
157
 
158
- data_dict = {
159
- 'input_ids': input_ids,
160
- 'attention_mask': input_ids.ne(self.tokenizer.pad_token_id),
161
- }
162
 
163
  return data_dict
164
 
@@ -259,13 +258,14 @@ class MolecularPropertyPredictionModel():
259
  def predict(self, valid_df, task_type):
260
 
261
  with calculateDuration("predicting"):
262
- test_dataset = Dataset.from_pandas(valid_df)
263
- # construct the dataloader
264
- test_loader = torch.utils.data.DataLoader(
265
- test_dataset,
266
- batch_size=16,
267
- collate_fn=self.data_collator,
268
- )
 
269
 
270
  # predict
271
  y_pred = []
@@ -302,17 +302,19 @@ class MolecularPropertyPredictionModel():
302
  with calculateDuration("predicting a file"):
303
  # we should add the index first
304
  df = df.reset_index()
305
- # we need to check the SMILES strings are valid, the invalid ones will be moved to the last
306
- valid_idx = []
307
- invalid_idx = []
308
- for idx, smiles in enumerate(df['smiles']):
309
- if Chem.MolFromSmiles(smiles):
310
- valid_idx.append(idx)
311
- else:
312
- invalid_idx.append(idx)
313
- valid_df = df.loc[valid_idx]
314
- # get the smiles list
315
- valid_df_smiles = valid_df['smiles'].tolist()
 
 
316
 
317
  input_df = pd.DataFrame(valid_df_smiles, columns=['smiles'])
318
  results = self.predict(input_df, task_type)
 
132
  return self.sme.augment([molecule])[0]
133
 
134
  def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
135
+ with calculateDuration("DataCollator"):
136
+ sources = []
 
137
 
138
+ for example in instances:
139
+ smiles = example['smiles'].strip()
140
+ smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles))
141
+
142
+ # get the properties except the smiles and mol_id cols
143
+ #props = [example[col] if example[col] is not None else np.nan for col in sorted(example.keys()) if col not in ['smiles', 'is_aug']]
144
+ source = f"{self.molecule_start_str}{smiles}{self.end_str}"
145
+ sources.append(source)
146
 
147
+ # Tokenize
148
+ tokenized_sources_with_prompt = self.tokenizer(
149
+ sources,
150
+ max_length=self.source_max_len,
151
+ truncation=True,
152
+ add_special_tokens=False,
153
+ )
154
+ input_ids = [torch.tensor(tokenized_source) for tokenized_source in tokenized_sources_with_prompt['input_ids']]
155
+ input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
156
 
157
+ data_dict = {
158
+ 'input_ids': input_ids,
159
+ 'attention_mask': input_ids.ne(self.tokenizer.pad_token_id),
160
+ }
161
 
162
  return data_dict
163
 
 
258
  def predict(self, valid_df, task_type):
259
 
260
  with calculateDuration("predicting"):
261
+ with calculateDuration("construct dataloader"):
262
+ test_dataset = Dataset.from_pandas(valid_df)
263
+ # construct the dataloader
264
+ test_loader = torch.utils.data.DataLoader(
265
+ test_dataset,
266
+ batch_size=16,
267
+ collate_fn=self.data_collator,
268
+ )
269
 
270
  # predict
271
  y_pred = []
 
302
  with calculateDuration("predicting a file"):
303
  # we should add the index first
304
  df = df.reset_index()
305
+
306
+ with calculateDuration("pre-checking SMILES"):
307
+ # we need to check the SMILES strings are valid, the invalid ones will be moved to the last
308
+ valid_idx = []
309
+ invalid_idx = []
310
+ for idx, smiles in enumerate(df['smiles']):
311
+ if Chem.MolFromSmiles(smiles):
312
+ valid_idx.append(idx)
313
+ else:
314
+ invalid_idx.append(idx)
315
+ valid_df = df.loc[valid_idx]
316
+ # get the smiles list
317
+ valid_df_smiles = valid_df['smiles'].tolist()
318
 
319
  input_df = pd.DataFrame(valid_df_smiles, columns=['smiles'])
320
  results = self.predict(input_df, task_type)