# P-DFD / trainer/exp_tester.py
# mrneuralnet -- Initial commit (982865f)
import os
import sys
import yaml
import torch
import random
from tqdm import tqdm
from pprint import pprint
from torch.utils import data
from dataset import load_dataset
from loss import get_loss
from model import load_model
from model.common import freeze_weights
from trainer import AbstractTrainer
from trainer.utils import AccMeter, AUCMeter, AverageMeter, Logger, center_print
class ExpTester(AbstractTrainer):
    """Evaluation-only trainer: restores a saved checkpoint and scores a test set.

    The training-side hooks (``_train_settings``, ``_save_ckpt``, ``train`` and
    ``validate``) are intentionally unsupported and raise
    ``NotImplementedError``.
    """

    def __init__(self, config, stage="Test"):
        """Build the tester and place the model on the selected device.

        Args:
            config: Configuration object forwarded to ``AbstractTrainer``.
            stage: Stage name forwarded to ``AbstractTrainer``.
        """
        super(ExpTester, self).__init__(config, stage)
        # NOTE(review): this relies on the base constructor having populated
        # ``self.device`` and ``self.model`` (presumably by calling
        # ``_initiated_settings`` and ``_test_settings``) -- confirm in
        # ``AbstractTrainer``.
        if torch.cuda.is_available() and self.device is not None:
            print(f"Using cuda device: {self.device}.")
            self.gpu = True
            self.model = self.model.to(self.device)
        else:
            print("Using cpu device.")
            self.device = torch.device("cpu")

    def _initiated_settings(self, model_cfg=None, data_cfg=None, config_cfg=None):
        """Record the requested device and start with GPU usage disabled.

        Args:
            model_cfg: Unused here.
            data_cfg: Unused here.
            config_cfg: Mapping read for the optional ``"device"`` entry; it
                must not be ``None`` despite the default.
        """
        self.gpu = False
        self.device = config_cfg.get("device", None)

    def _train_settings(self, model_cfg=None, data_cfg=None, config_cfg=None):
        """Unsupported in the tester; always raises ``NotImplementedError``."""
        # Not used.
        raise NotImplementedError("The function is not intended to be used here.")

    def _test_settings(self, model_cfg=None, data_cfg=None, config_cfg=None):
        """Prepare dataset, loader, model, loss and logging, then load weights.

        Args:
            model_cfg: Mapping with ``"num_classes"``; the whole mapping is
                also expanded as keyword arguments to the model constructor.
            data_cfg: Mapping with ``"file"`` (path of a YAML options file),
                ``"test_branch"`` (key inside that file), ``"name"`` (dataset
                name) and ``"test_batch_size"``.
            config_cfg: Mapping with ``"id"`` and the optional entries
                ``"ckpt_fold"`` (default ``"runs"``), ``"loss"`` and ``"ckpt"``
                (default ``"best_model"``).
        """
        # load test dataset
        test_dataset = data_cfg["file"]
        branch = data_cfg["test_branch"]
        name = data_cfg["name"]
        with open(test_dataset, "r") as f:
            # NOTE(security): FullLoader can construct more than plain
            # scalars/containers; keep this file trusted, or switch to
            # ``yaml.safe_load`` if the options never need custom tags.
            options = yaml.load(f, Loader=yaml.FullLoader)
        test_options = options[branch]
        self.test_set = load_dataset(name)(test_options)

        # wrapped with data loader
        self.test_batch_size = data_cfg["test_batch_size"]
        self.test_loader = data.DataLoader(self.test_set, shuffle=False,
                                           batch_size=self.test_batch_size)

        # NOTE(review): ``self.model_name`` and ``self.config`` are read below
        # but not assigned in this class -- presumably set by AbstractTrainer.
        self.run_id = config_cfg["id"]
        self.ckpt_fold = config_cfg.get("ckpt_fold", "runs")
        self.dir = os.path.join(self.ckpt_fold, self.model_name, self.run_id)

        # load model
        self.num_classes = model_cfg["num_classes"]
        self.model = load_model(self.model_name)(**model_cfg)

        # load loss
        self.loss_criterion = get_loss(config_cfg.get("loss", None))

        # redirect the std out stream
        sys.stdout = Logger(os.path.join(self.dir, "test_result.txt"))
        print('Run dir: {}'.format(self.dir))

        center_print('Test configurations begins')
        pprint(self.config)
        pprint(test_options)
        center_print('Test configurations ends')

        self.ckpt = config_cfg.get("ckpt", "best_model")
        self._load_ckpt(best=True, train=False)

    def _save_ckpt(self, step, best=False):
        """Unsupported in the tester; always raises ``NotImplementedError``."""
        # Not used.
        raise NotImplementedError("The function is not intended to be used here.")

    def _load_ckpt(self, best=False, train=False):
        """Restore model weights and bookkeeping values from a checkpoint file.

        Args:
            best: When true, load ``<self.ckpt>.bin``; otherwise load
                ``latest_model.bin``. Both are looked up inside ``self.dir``.
            train: Unused here.
        """
        ckpt_name = (self.ckpt + ".bin") if best else "latest_model.bin"
        load_dir = os.path.join(self.dir, ckpt_name)

        # Only map onto the configured device when it can actually be used;
        # otherwise deserialize on the CPU so a checkpoint written on a GPU
        # machine still loads on a CPU-only host. ``load_state_dict`` copies
        # the values into the model's existing parameters either way.
        cuda_usable = torch.cuda.is_available() and self.device is not None
        map_location = self.device if cuda_usable else "cpu"
        load_dict = torch.load(load_dir, map_location=map_location)

        self.start_step = load_dict["step"]
        self.best_step = load_dict["best_step"]

        self.best_metric = load_dict.get("best_metric", None)
        if self.best_metric is None:
            self.best_metric = load_dict.get("best_acc")

        self.eval_metric = load_dict.get("eval_metric", None)
        if self.eval_metric is None:
            self.eval_metric = load_dict.get("Acc")
        if self.eval_metric is None:
            # Neither entry is present: use the literal label that pairs with
            # the ``best_acc`` fallback above so the log never says "best None".
            self.eval_metric = "Acc"

        self.model.load_state_dict(load_dict["model"])

        # The stored metric may be a tensor, a plain number, or missing.
        metric_value = self.best_metric
        if hasattr(metric_value, "item"):
            metric_value = metric_value.item()
        if metric_value is not None:
            metric_value = round(metric_value, 4)
        print(f"Loading checkpoint from {load_dir}, best step: {self.best_step}, "
              f"best {self.eval_metric}: {metric_value}.")

    def train(self):
        """Unsupported in the tester; always raises ``NotImplementedError``."""
        # Not used.
        raise NotImplementedError("The function is not intended to be used here.")

    def validate(self, epoch, step, timer, writer):
        """Unsupported in the tester; always raises ``NotImplementedError``."""
        # Not used.
        raise NotImplementedError("The function is not intended to be used here.")

    def test(self, display_images=False):
        """Run the model over the test loader and print loss, accuracy and AUC.

        Args:
            display_images: When true, the first four samples of one randomly
                chosen batch are passed to ``self.plot_figure``.
        """
        freeze_weights(self.model)
        num_batches = len(self.test_loader)
        # ``random.randint`` includes both endpoints and batches are numbered
        # 1..num_batches, so the upper bound must be ``num_batches`` itself;
        # ``max`` keeps the call valid for an empty loader. The draw happens
        # before the seeds are fixed, exactly as before.
        t_idx = random.randint(1, max(num_batches, 1))
        self.fixed_randomness()  # for reproduction

        acc = AccMeter()
        auc = AUCMeter()
        logloss = AverageMeter()
        # ``enumerate`` hides the loader length, so hand it to tqdm explicitly.
        test_generator = tqdm(enumerate(self.test_loader, 1), total=num_batches)
        categories = self.test_loader.dataset.categories
        use_bce = self.num_classes == 1

        self.model.eval()
        # Evaluation only: no gradients are needed for any of the metrics.
        with torch.no_grad():
            for idx, test_data in test_generator:
                I, Y = test_data
                I = self.test_loader.dataset.load_item(I)
                if self.gpu:
                    in_I, Y = self.to_device((I, Y))
                else:
                    in_I, Y = (I, Y)

                Y_pre = self.model(in_I)

                # for BCE Setting:
                if use_bce:
                    Y_pre = Y_pre.squeeze()
                    if Y_pre.dim() == 0:
                        # A single-sample batch is squeezed down to a scalar;
                        # restore the batch axis so it matches Y (assumed to
                        # hold one label per sample -- TODO confirm) and so the
                        # ``[:4]`` slicing below still works.
                        Y_pre = Y_pre.unsqueeze(0)
                    loss = self.loss_criterion(Y_pre, Y.float())
                    Y_pre = torch.sigmoid(Y_pre)
                else:
                    loss = self.loss_criterion(Y_pre, Y)

                acc.update(Y_pre, Y, use_bce=use_bce)
                auc.update(Y_pre, Y, use_bce=use_bce)
                logloss.update(loss.item())

                test_generator.set_description("Test %d/%d" % (idx, num_batches))
                if display_images and idx == t_idx:
                    # show images
                    images = I[:4]
                    pred = Y_pre[:4]
                    gt = Y[:4]
                    self.plot_figure(images, pred, gt, 2, categories)

        print("Test, FINAL LOSS %.4f, FINAL ACC %.4f, FINAL AUC %.4f" %
              (logloss.avg, acc.mean_acc(), auc.mean_auc()))
        auc.curve(self.dir)