Commit
·
8b9fe11
1
Parent(s):
c6866a7
add more detail time calcuated
Browse files
utils.py
CHANGED
@@ -132,33 +132,32 @@ class DataCollator(object):
|
|
132 |
return self.sme.augment([molecule])[0]
|
133 |
|
134 |
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
135 |
-
|
136 |
-
|
137 |
-
targets = []
|
138 |
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
|
163 |
return data_dict
|
164 |
|
@@ -259,13 +258,14 @@ class MolecularPropertyPredictionModel():
|
|
259 |
def predict(self, valid_df, task_type):
|
260 |
|
261 |
with calculateDuration("predicting"):
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
|
|
269 |
|
270 |
# predict
|
271 |
y_pred = []
|
@@ -302,17 +302,19 @@ class MolecularPropertyPredictionModel():
|
|
302 |
with calculateDuration("predicting a file"):
|
303 |
# we should add the index first
|
304 |
df = df.reset_index()
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
|
|
|
|
316 |
|
317 |
input_df = pd.DataFrame(valid_df_smiles, columns=['smiles'])
|
318 |
results = self.predict(input_df, task_type)
|
|
|
132 |
return self.sme.augment([molecule])[0]
|
133 |
|
134 |
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
135 |
+
with calculateDuration("DataCollator"):
|
136 |
+
sources = []
|
|
|
137 |
|
138 |
+
for example in instances:
|
139 |
+
smiles = example['smiles'].strip()
|
140 |
+
smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles))
|
141 |
+
|
142 |
+
# get the properties except the smiles and mol_id cols
|
143 |
+
#props = [example[col] if example[col] is not None else np.nan for col in sorted(example.keys()) if col not in ['smiles', 'is_aug']]
|
144 |
+
source = f"{self.molecule_start_str}{smiles}{self.end_str}"
|
145 |
+
sources.append(source)
|
146 |
|
147 |
+
# Tokenize
|
148 |
+
tokenized_sources_with_prompt = self.tokenizer(
|
149 |
+
sources,
|
150 |
+
max_length=self.source_max_len,
|
151 |
+
truncation=True,
|
152 |
+
add_special_tokens=False,
|
153 |
+
)
|
154 |
+
input_ids = [torch.tensor(tokenized_source) for tokenized_source in tokenized_sources_with_prompt['input_ids']]
|
155 |
+
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
|
156 |
|
157 |
+
data_dict = {
|
158 |
+
'input_ids': input_ids,
|
159 |
+
'attention_mask': input_ids.ne(self.tokenizer.pad_token_id),
|
160 |
+
}
|
161 |
|
162 |
return data_dict
|
163 |
|
|
|
258 |
def predict(self, valid_df, task_type):
|
259 |
|
260 |
with calculateDuration("predicting"):
|
261 |
+
with calculateDuration("construct dataloader"):
|
262 |
+
test_dataset = Dataset.from_pandas(valid_df)
|
263 |
+
# construct the dataloader
|
264 |
+
test_loader = torch.utils.data.DataLoader(
|
265 |
+
test_dataset,
|
266 |
+
batch_size=16,
|
267 |
+
collate_fn=self.data_collator,
|
268 |
+
)
|
269 |
|
270 |
# predict
|
271 |
y_pred = []
|
|
|
302 |
with calculateDuration("predicting a file"):
|
303 |
# we should add the index first
|
304 |
df = df.reset_index()
|
305 |
+
|
306 |
+
with calculateDuration("pre-checking SMILES"):
|
307 |
+
# we need to check the SMILES strings are valid, the invalid ones will be moved to the last
|
308 |
+
valid_idx = []
|
309 |
+
invalid_idx = []
|
310 |
+
for idx, smiles in enumerate(df['smiles']):
|
311 |
+
if Chem.MolFromSmiles(smiles):
|
312 |
+
valid_idx.append(idx)
|
313 |
+
else:
|
314 |
+
invalid_idx.append(idx)
|
315 |
+
valid_df = df.loc[valid_idx]
|
316 |
+
# get the smiles list
|
317 |
+
valid_df_smiles = valid_df['smiles'].tolist()
|
318 |
|
319 |
input_df = pd.DataFrame(valid_df_smiles, columns=['smiles'])
|
320 |
results = self.predict(input_df, task_type)
|