Spaces:
Sleeping
Sleeping
File size: 5,569 Bytes
982865f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import os
import sys
import yaml
import torch
import random
from tqdm import tqdm
from pprint import pprint
from torch.utils import data
from dataset import load_dataset
from loss import get_loss
from model import load_model
from model.common import freeze_weights
from trainer import AbstractTrainer
from trainer.utils import AccMeter, AUCMeter, AverageMeter, Logger, center_print
class ExpTester(AbstractTrainer):
    """Evaluate a trained model checkpoint on a test split.

    Only the test-time hooks of ``AbstractTrainer`` are implemented here; the
    training, validation and checkpoint-saving hooks raise
    ``NotImplementedError``.
    """

    def __init__(self, config, stage="Test"):
        super(ExpTester, self).__init__(config, stage)
        # `self.device` comes from the config (see `_initiated_settings`);
        # fall back to the CPU when CUDA is unavailable or no device was set.
        if torch.cuda.is_available() and self.device is not None:
            print(f"Using cuda device: {self.device}.")
            self.gpu = True
            self.model = self.model.to(self.device)
        else:
            print("Using cpu device.")
            self.device = torch.device("cpu")

    def _initiated_settings(self, model_cfg=None, data_cfg=None, config_cfg=None):
        """Initialise device settings from the ``config`` section.

        ``self.gpu`` starts as False; ``__init__`` switches it on only when
        CUDA is available and a device was configured.
        """
        self.gpu = False
        self.device = config_cfg.get("device", None)

    def _train_settings(self, model_cfg=None, data_cfg=None, config_cfg=None):
        # Not used.
        raise NotImplementedError("The function is not intended to be used here.")

    def _test_settings(self, model_cfg=None, data_cfg=None, config_cfg=None):
        """Build the test loader, model and loss, then load the checkpoint.

        Args:
            model_cfg: model section; must contain ``num_classes`` and is
                forwarded as keyword arguments to the model constructor.
            data_cfg: data section with ``file`` (YAML path), ``test_branch``,
                ``name`` and ``test_batch_size``.
            config_cfg: run section with ``id`` and optional ``ckpt_fold``
                (default ``"runs"``), ``loss`` and ``ckpt``
                (default ``"best_model"``).
        """
        # load test dataset
        test_dataset = data_cfg["file"]
        branch = data_cfg["test_branch"]
        name = data_cfg["name"]
        # NOTE(review): FullLoader can build some Python objects; prefer
        # yaml.safe_load if this options file is not fully trusted.
        with open(test_dataset, "r") as f:
            options = yaml.load(f, Loader=yaml.FullLoader)
        test_options = options[branch]
        self.test_set = load_dataset(name)(test_options)
        # wrapped with data loader (no shuffling, so batch order is fixed)
        self.test_batch_size = data_cfg["test_batch_size"]
        self.test_loader = data.DataLoader(self.test_set, shuffle=False,
                                           batch_size=self.test_batch_size)
        self.run_id = config_cfg["id"]
        self.ckpt_fold = config_cfg.get("ckpt_fold", "runs")
        self.dir = os.path.join(self.ckpt_fold, self.model_name, self.run_id)
        # load model
        self.num_classes = model_cfg["num_classes"]
        self.model = load_model(self.model_name)(**model_cfg)
        # load loss
        self.loss_criterion = get_loss(config_cfg.get("loss", None))
        # Redirect stdout to a Logger writing <run dir>/test_result.txt
        # (presumably it also echoes to the console -- see trainer.utils).
        sys.stdout = Logger(os.path.join(self.dir, "test_result.txt"))
        print('Run dir: {}'.format(self.dir))
        center_print('Test configurations begins')
        pprint(self.config)
        pprint(test_options)
        center_print('Test configurations ends')
        self.ckpt = config_cfg.get("ckpt", "best_model")
        self._load_ckpt(best=True, train=False)

    def _save_ckpt(self, step, best=False):
        # Not used.
        raise NotImplementedError("The function is not intended to be used here.")

    def _load_ckpt(self, best=False, train=False):
        """Load model weights and bookkeeping values from ``self.dir``.

        Args:
            best: if True load ``<self.ckpt>.bin``, otherwise
                ``latest_model.bin``.
            train: unused here; kept for interface compatibility.
        """
        ckpt_file = self.ckpt + ".bin" if best else "latest_model.bin"
        load_dir = os.path.join(self.dir, ckpt_file)
        # Load onto the CPU: at this point `self.device` may still be the raw
        # config value (a CUDA device, or None) even when CUDA is unavailable.
        # load_state_dict copies the tensors into the model's own parameters,
        # and __init__ moves the model to the chosen device afterwards.
        load_dict = torch.load(load_dir, map_location="cpu")
        self.start_step = load_dict["step"]
        self.best_step = load_dict["best_step"]
        # Fall back to the "best_acc" / "Acc" keys when the metric entries
        # are absent from the checkpoint.
        self.best_metric = load_dict.get("best_metric", None)
        if self.best_metric is None:
            self.best_metric = load_dict.get("best_acc")
        self.eval_metric = load_dict.get("eval_metric", None)
        if self.eval_metric is None:
            self.eval_metric = load_dict.get("Acc")
        self.model.load_state_dict(load_dict["model"])
        # The stored metric may be a tensor, a plain number, or missing.
        if self.best_metric is None:
            best_repr = "N/A"
        elif torch.is_tensor(self.best_metric):
            best_repr = round(self.best_metric.item(), 4)
        else:
            best_repr = round(self.best_metric, 4)
        print(f"Loading checkpoint from {load_dir}, best step: {self.best_step}, "
              f"best {self.eval_metric}: {best_repr}.")

    def train(self):
        # Not used.
        raise NotImplementedError("The function is not intended to be used here.")

    def validate(self, epoch, step, timer, writer):
        # Not used.
        raise NotImplementedError("The function is not intended to be used here.")

    def test(self, display_images=False):
        """Run the model over the whole test set and report metrics.

        Prints the mean loss, accuracy and AUC, then calls ``auc.curve`` with
        the run directory. When ``display_images`` is True, the first four
        samples of one randomly chosen batch are passed to ``plot_figure``.
        """
        freeze_weights(self.model)
        self.model.eval()  # set once; the loop only runs inference
        num_batches = len(self.test_loader)
        # random.randint is inclusive at both ends and batches are numbered
        # 1..num_batches, so the upper bound is num_batches itself.
        t_idx = random.randint(1, num_batches) if num_batches > 0 else 0
        self.fixed_randomness()  # for reproduction
        acc = AccMeter()
        auc = AUCMeter()
        logloss = AverageMeter()
        use_bce = self.num_classes == 1
        test_generator = tqdm(enumerate(self.test_loader, 1), total=num_batches)
        categories = self.test_loader.dataset.categories
        with torch.no_grad():
            for idx, test_data in test_generator:
                I, Y = test_data
                I = self.test_loader.dataset.load_item(I)
                if self.gpu:
                    in_I, Y = self.to_device((I, Y))
                else:
                    in_I = I
                Y_pre = self.model(in_I)
                if use_bce:
                    # BCE setting: flatten the logits (presumably (batch, 1))
                    # to (batch,). Unlike squeeze(), this keeps the batch
                    # dimension when a batch holds a single sample.
                    Y_pre = Y_pre.reshape(-1)
                    loss = self.loss_criterion(Y_pre, Y.float())
                    Y_pre = torch.sigmoid(Y_pre)
                else:
                    loss = self.loss_criterion(Y_pre, Y)
                acc.update(Y_pre, Y, use_bce=use_bce)
                auc.update(Y_pre, Y, use_bce=use_bce)
                logloss.update(loss.item())
                test_generator.set_description("Test %d/%d" % (idx, num_batches))
                if display_images and idx == t_idx:
                    # show images
                    images = I[:4]
                    pred = Y_pre[:4]
                    gt = Y[:4]
                    self.plot_figure(images, pred, gt, 2, categories)
        print("Test, FINAL LOSS %.4f, FINAL ACC %.4f, FINAL AUC %.4f" %
              (logloss.avg, acc.mean_acc(), auc.mean_auc()))
        auc.curve(self.dir)
|