vshirasuna committed
Commit 60b6403
Parent: 999fcb8

Added evaluate method and an option to save a checkpoint at each epoch in finetune

smi-ted/finetune/args.py CHANGED
@@ -305,6 +305,7 @@ def get_parser(parser=None):
     parser.add_argument("--model_path", type=str, default="./smi_ted/")
     parser.add_argument("--ckpt_filename", type=str, default="smi_ted_Light_40.pt")
     # parser.add_argument('--n_output', type=int, default=1)
+    parser.add_argument("--save_every_epoch", type=int, default=0)
     parser.add_argument("--save_ckpt", type=int, default=1)
     parser.add_argument("--start_seed", type=int, default=0)
     parser.add_argument("--smi_ted_version", type=str, default="v1")
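
The new flag follows the file's existing int-as-bool convention (compare --save_ckpt). A minimal standalone sketch — not the repo's parser — showing how the two flags map to booleans downstream:

    import argparse

    # Mirrors only the two flags above; everything else in get_parser is omitted.
    parser = argparse.ArgumentParser()
    parser.add_argument("--save_every_epoch", type=int, default=0)
    parser.add_argument("--save_ckpt", type=int, default=1)

    config = parser.parse_args(["--save_every_epoch", "1"])
    assert bool(config.save_every_epoch) is True  # keep a checkpoint per epoch
    assert bool(config.save_ckpt) is True         # checkpointing enabled at all
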
smi-ted/finetune/finetune_classification.py CHANGED
@@ -48,6 +48,7 @@ def main(config):
         seed=config.start_seed,
         checkpoints_folder=config.checkpoints_folder,
         device=device,
+        save_every_epoch=bool(config.save_every_epoch),
         save_ckpt=bool(config.save_ckpt)
     )
     trainer.compile(
@@ -56,6 +57,7 @@ def main(config):
         loss_fn=loss_function
     )
     trainer.fit(max_epochs=config.max_epochs)
+    trainer.evaluate()
 
 
 if __name__ == '__main__':
smi-ted/finetune/finetune_classification_multitask.py CHANGED
@@ -80,6 +80,7 @@ def main(config):
         seed=config.start_seed,
         checkpoints_folder=config.checkpoints_folder,
         device=device,
+        save_every_epoch=bool(config.save_every_epoch),
         save_ckpt=bool(config.save_ckpt)
     )
     trainer.compile(
@@ -88,6 +89,7 @@ def main(config):
         loss_fn=loss_function
     )
     trainer.fit(max_epochs=config.max_epochs)
+    trainer.evaluate()
 
 
 if __name__ == '__main__':
smi-ted/finetune/finetune_regression.py CHANGED
@@ -50,6 +50,7 @@ def main(config):
         seed=config.start_seed,
         checkpoints_folder=config.checkpoints_folder,
         device=device,
+        save_every_epoch=bool(config.save_every_epoch),
         save_ckpt=bool(config.save_ckpt)
     )
     trainer.compile(
@@ -58,6 +59,7 @@ def main(config):
         loss_fn=loss_function
     )
     trainer.fit(max_epochs=config.max_epochs)
+    trainer.evaluate()
 
 
 if __name__ == '__main__':
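
All three entry points receive the identical two-line change: pass the new flag through to the trainer and call evaluate() after fit(). The contract this relies on is that fit() records the name of the last saved checkpoint on the trainer and evaluate() reloads it, so evaluate() only works after a fit() that saved at least one checkpoint. A runnable stand-in (FakeTrainer is hypothetical; the real classes live in trainers.py below):

    # Hypothetical stand-in for the fit()/evaluate() contract shared by the
    # three scripts; the real implementation is in trainers.py below.
    class FakeTrainer:
        def __init__(self, save_every_epoch=False, save_ckpt=True):
            self.save_every_epoch = save_every_epoch
            self.save_ckpt = save_ckpt

        def fit(self, max_epochs=3):
            for epoch in range(1, max_epochs + 1):
                # fit() leaves the name of the last saved checkpoint behind
                self.last_filename = f"model_epoch={epoch}.pt"

        def evaluate(self):
            # evaluate() assumes fit() ran first and saved a checkpoint
            print(f"evaluating checkpoint {self.last_filename}")

    trainer = FakeTrainer(save_every_epoch=True)
    trainer.fit(max_epochs=3)
    trainer.evaluate()  # -> evaluating checkpoint model_epoch=3.pt
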
smi-ted/finetune/trainers.py CHANGED
@@ -25,7 +25,7 @@ from utils import RMSE, sensitivity, specificity
 class Trainer:
 
     def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
-                 target_metric='rmse', seed=0, checkpoints_folder='./checkpoints', save_ckpt=True, device='cpu'):
+                 target_metric='rmse', seed=0, checkpoints_folder='./checkpoints', save_every_epoch=False, save_ckpt=True, device='cpu'):
         # data
         self.df_train = raw_data[0]
         self.df_valid = raw_data[1]
@@ -40,6 +40,7 @@ class Trainer:
         self.target_metric = target_metric
         self.seed = seed
         self.checkpoints_folder = checkpoints_folder
+        self.save_every_epoch = save_every_epoch
         self.save_ckpt = save_ckpt
         self.device = device
         self._set_seed(seed)
@@ -81,8 +82,7 @@ class Trainer:
         self._print_configuration()
 
     def fit(self, max_epochs=500):
-        best_vloss = 1000
-        best_vmetric = -1
+        best_vloss = float('inf')
 
         for epoch in range(1, max_epochs+1):
             print(f'\n=====Epoch [{epoch}/{max_epochs}]=====')
@@ -91,47 +91,47 @@ class Trainer:
             self.model.to(self.device)
             self.model.train()
             train_loss = self._train_one_epoch()
-            print(f'Training loss: {round(train_loss, 6)}')
 
-            # Evaluate the model
+            # validation
             self.model.eval()
             val_preds, val_loss, val_metrics = self._validate_one_epoch(self.valid_loader)
-            tst_preds, tst_loss, tst_metrics = self._validate_one_epoch(self.test_loader)
-
-            print(f"Valid loss: {round(val_loss, 6)}")
             for m in val_metrics.keys():
                 print(f"[VALID] Evaluation {m.upper()}: {round(val_metrics[m], 4)}")
-            print("-"*32)
-            print(f"Test loss: {round(tst_loss, 6)}")
-            for m in tst_metrics.keys():
-                print(f"[TEST] Evaluation {m.upper()}: {round(tst_metrics[m], 4)}")
 
             ############################### Save Finetune checkpoint #######################################
-            if (val_loss < best_vloss) and self.save_ckpt:
+            if ((val_loss < best_vloss) or self.save_every_epoch) and self.save_ckpt:
                 # remove old checkpoint
-                if best_vmetric != -1:
-                    os.remove(os.path.join(self.checkpoints_folder, filename))
+                if best_vloss != float('inf') and not self.save_every_epoch:
+                    os.remove(os.path.join(self.checkpoints_folder, self.last_filename))
 
                 # filename
                 model_name = f'{str(self.model)}-Finetune'
-                metric = round(tst_metrics[self.target_metric], 4)
-                filename = f"{model_name}_epoch={epoch}_{self.dataset_name}_seed{self.seed}_{self.target_metric}={metric}.pt"
+                self.last_filename = f"{model_name}_epoch={epoch}_{self.dataset_name}_seed{self.seed}_valloss={round(val_loss, 4)}.pt"
 
                 # save checkpoint
                 print('Saving checkpoint...')
-                self._save_checkpoint(epoch, filename)
-
-                # save predictions
-                pd.DataFrame(tst_preds).to_csv(
-                    os.path.join(
-                        self.checkpoints_folder,
-                        f'{self.dataset_name}_{self.target if isinstance(self.target, str) else self.target[0]}_predict_test_seed{self.seed}.csv'),
-                    index=False
-                )
+                self._save_checkpoint(epoch, self.last_filename)
 
                 # update best loss
                 best_vloss = val_loss
-                best_vmetric = metric
+
+    def evaluate(self):
+        print("\n=====Test Evaluation=====")
+        self._load_checkpoint(self.last_filename)
+        self.model.eval()
+        tst_preds, tst_loss, tst_metrics = self._validate_one_epoch(self.test_loader)
+
+        # show metrics
+        for m in tst_metrics.keys():
+            print(f"[TEST] Evaluation {m.upper()}: {round(tst_metrics[m], 4)}")
+
+        # save predictions
+        pd.DataFrame(tst_preds).to_csv(
+            os.path.join(
+                self.checkpoints_folder,
+                f'{self.dataset_name}_{self.target if isinstance(self.target, str) else self.target[0]}_predict_test_seed{self.seed}.csv'),
+            index=False
+        )
 
     def _train_one_epoch(self):
         raise NotImplementedError
@@ -153,6 +153,11 @@
         print('Valid size:\t', self.df_valid.shape[0])
         print('Test size:\t', self.df_test.shape[0])
 
+    def _load_checkpoint(self, filename):
+        ckpt_path = os.path.join(self.checkpoints_folder, filename)
+        ckpt_dict = torch.load(ckpt_path, map_location='cpu')
+        self.model.load_state_dict(ckpt_dict['MODEL_STATE'])
+
     def _save_checkpoint(self, current_epoch, filename):
         if not os.path.exists(self.checkpoints_folder):
             os.makedirs(self.checkpoints_folder)
@@ -198,14 +203,14 @@
 class TrainerRegressor(Trainer):
 
     def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
-                 target_metric='rmse', seed=0, checkpoints_folder='./checkpoints', save_ckpt=True, device='cpu'):
+                 target_metric='rmse', seed=0, checkpoints_folder='./checkpoints', save_every_epoch=False, save_ckpt=True, device='cpu'):
         super().__init__(raw_data, dataset_name, target, batch_size, hparams,
-                         target_metric, seed, checkpoints_folder, save_ckpt, device)
+                         target_metric, seed, checkpoints_folder, save_every_epoch, save_ckpt, device)
 
     def _train_one_epoch(self):
         running_loss = 0.0
 
-        for data in tqdm(self.train_loader):
+        for idx, data in enumerate(pbar := tqdm(self.train_loader)):
             # Every data instance is an input + label pair
             smiles, targets = data
             targets = targets.clone().detach().to(self.device)
@@ -227,6 +232,11 @@ class TrainerRegressor(Trainer):
             # print statistics
             running_loss += loss.item()
 
+            # progress bar
+            pbar.set_description('[TRAINING]')
+            pbar.set_postfix(loss=running_loss/(idx+1))
+            pbar.refresh()
+
         return running_loss / len(self.train_loader)
 
     def _validate_one_epoch(self, data_loader):
@@ -235,7 +245,7 @@ class TrainerRegressor(Trainer):
         running_loss = 0.0
 
         with torch.no_grad():
-            for data in tqdm(data_loader):
+            for idx, data in enumerate(pbar := tqdm(data_loader)):
                 # Every data instance is an input + label pair
                 smiles, targets = data
                 targets = targets.clone().detach().to(self.device)
@@ -253,6 +263,11 @@ class TrainerRegressor(Trainer):
                 # print statistics
                 running_loss += loss.item()
 
+                # progress bar
+                pbar.set_description('[EVALUATION]')
+                pbar.set_postfix(loss=running_loss/(idx+1))
+                pbar.refresh()
+
         # Put together predictions and labels from batches
         preds = torch.cat(data_preds, dim=0).cpu().numpy()
         tgts = torch.cat(data_targets, dim=0).cpu().numpy()
@@ -271,20 +286,20 @@ class TrainerRegressor(Trainer):
             'spearman': spearman,
         }
 
-        return preds, running_loss / len(self.train_loader), metrics
+        return preds, running_loss / len(data_loader), metrics
 
 
 class TrainerClassifier(Trainer):
 
     def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
-                 target_metric='roc-auc', seed=0, checkpoints_folder='./checkpoints', save_ckpt=True, device='cpu'):
+                 target_metric='roc-auc', seed=0, checkpoints_folder='./checkpoints', save_every_epoch=False, save_ckpt=True, device='cpu'):
         super().__init__(raw_data, dataset_name, target, batch_size, hparams,
-                         target_metric, seed, checkpoints_folder, save_ckpt, device)
+                         target_metric, seed, checkpoints_folder, save_every_epoch, save_ckpt, device)
 
     def _train_one_epoch(self):
         running_loss = 0.0
 
-        for data in tqdm(self.train_loader):
+        for idx, data in enumerate(pbar := tqdm(self.train_loader)):
             # Every data instance is an input + label pair
             smiles, targets = data
             targets = targets.clone().detach().to(self.device)
@@ -306,6 +321,11 @@ class TrainerClassifier(Trainer):
             # print statistics
             running_loss += loss.item()
 
+            # progress bar
+            pbar.set_description('[TRAINING]')
+            pbar.set_postfix(loss=running_loss/(idx+1))
+            pbar.refresh()
+
         return running_loss / len(self.train_loader)
 
     def _validate_one_epoch(self, data_loader):
@@ -314,7 +334,7 @@ class TrainerClassifier(Trainer):
         running_loss = 0.0
 
         with torch.no_grad():
-            for data in tqdm(data_loader):
+            for idx, data in enumerate(pbar := tqdm(data_loader)):
                 # Every data instance is an input + label pair
                 smiles, targets = data
                 targets = targets.clone().detach().to(self.device)
@@ -332,6 +352,11 @@ class TrainerClassifier(Trainer):
                 # print statistics
                 running_loss += loss.item()
 
+                # progress bar
+                pbar.set_description('[EVALUATION]')
+                pbar.set_postfix(loss=running_loss/(idx+1))
+                pbar.refresh()
+
         # Put together predictions and labels from batches
         preds = torch.cat(data_preds, dim=0).cpu().numpy()
         tgts = torch.cat(data_targets, dim=0).cpu().numpy()
@@ -366,15 +391,15 @@ class TrainerClassifier(Trainer):
             'specificity': sp,
         }
 
-        return preds, running_loss / len(self.train_loader), metrics
+        return preds, running_loss / len(data_loader), metrics
 
 
 class TrainerClassifierMultitask(Trainer):
 
     def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
-                 target_metric='roc-auc', seed=0, checkpoints_folder='./checkpoints', save_ckpt=True, device='cpu'):
+                 target_metric='roc-auc', seed=0, checkpoints_folder='./checkpoints', save_every_epoch=False, save_ckpt=True, device='cpu'):
         super().__init__(raw_data, dataset_name, target, batch_size, hparams,
-                         target_metric, seed, checkpoints_folder, save_ckpt, device)
+                         target_metric, seed, checkpoints_folder, save_every_epoch, save_ckpt, device)
 
     def _prepare_data(self):
         # normalize dataset
@@ -409,7 +434,7 @@ class TrainerClassifierMultitask(Trainer):
     def _train_one_epoch(self):
         running_loss = 0.0
 
-        for data in tqdm(self.train_loader):
+        for idx, data in enumerate(pbar := tqdm(self.train_loader)):
             # Every data instance is an input + label pair + mask
             smiles, targets, target_masks = data
             targets = targets.clone().detach().to(self.device)
@@ -432,6 +457,11 @@ class TrainerClassifierMultitask(Trainer):
             # print statistics
            running_loss += loss.item()
 
+            # progress bar
+            pbar.set_description('[TRAINING]')
+            pbar.set_postfix(loss=running_loss/(idx+1))
+            pbar.refresh()
+
         return running_loss / len(self.train_loader)
 
     def _validate_one_epoch(self, data_loader):
@@ -441,7 +471,7 @@ class TrainerClassifierMultitask(Trainer):
         running_loss = 0.0
 
         with torch.no_grad():
-            for data in tqdm(data_loader):
+            for idx, data in enumerate(pbar := tqdm(data_loader)):
                 # Every data instance is an input + label pair + mask
                 smiles, targets, target_masks = data
                 targets = targets.clone().detach().to(self.device)
@@ -461,6 +491,11 @@ class TrainerClassifierMultitask(Trainer):
                 # print statistics
                 running_loss += loss.item()
 
+                # progress bar
+                pbar.set_description('[EVALUATION]')
+                pbar.set_postfix(loss=running_loss/(idx+1))
+                pbar.refresh()
+
         # Put together predictions and labels from batches
         preds = torch.cat(data_preds, dim=0)
         tgts = torch.cat(data_targets, dim=0)
@@ -513,4 +548,4 @@ class TrainerClassifierMultitask(Trainer):
             'specificity': average_sp.item(),
         }
 
-        return preds, running_loss / len(self.train_loader), metrics
+        return preds, running_loss / len(data_loader), metrics
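
The retention rule introduced in fit() is easiest to see in isolation: in best-only mode (save_every_epoch=False) a new best validation loss replaces the previous checkpoint file, while save_every_epoch=True saves every epoch and never deletes. A self-contained sketch of just that rule (retention_demo, the losses, and the file names are made up for illustration):

    import os
    import tempfile

    def retention_demo(val_losses, save_every_epoch):
        """Mirror of the save/remove decision in Trainer.fit(); returns surviving files."""
        folder = tempfile.mkdtemp()
        best_vloss, last_filename = float('inf'), None
        for epoch, val_loss in enumerate(val_losses, start=1):
            if (val_loss < best_vloss) or save_every_epoch:
                # remove the previous checkpoint only in best-only mode
                if best_vloss != float('inf') and not save_every_epoch:
                    os.remove(os.path.join(folder, last_filename))
                last_filename = f"epoch={epoch}_valloss={round(val_loss, 4)}.pt"
                open(os.path.join(folder, last_filename), 'w').close()  # stand-in for _save_checkpoint
                best_vloss = val_loss
        return sorted(os.listdir(folder))

    print(retention_demo([0.9, 0.7, 0.8], save_every_epoch=False))
    # ['epoch=2_valloss=0.7.pt']  -> only the best-so-far checkpoint survives
    print(retention_demo([0.9, 0.7, 0.8], save_every_epoch=True))
    # all three epoch files survive; evaluate() then reloads the last one saved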