# -*- coding: utf-8 -*-
import os
import random
import sys
from pprint import pprint

import torch
import yaml
from torch.utils import data
from tqdm import tqdm

from dataset import load_dataset
from loss import get_loss
from model import load_model
from model.common import freeze_weights
from trainer import AbstractTrainer
from trainer.utils import AccMeter, AUCMeter, AverageMeter, Logger, center_print
class ExpTester(AbstractTrainer):
    """Evaluate a trained model on a held-out test set.

    Reuses the ``AbstractTrainer`` scaffolding for configuration parsing and
    model/dataset construction, but disables every training-related entry
    point (``train``, ``validate``, ``_train_settings``, ``_save_ckpt`` all
    raise ``NotImplementedError``). Test output is tee'd to
    ``<run_dir>/test_result.txt`` via ``Logger``.
    """

    def __init__(self, config, stage="Test"):
        """Build the tester and move the model to GPU when available.

        Args:
            config: configuration mapping consumed by ``AbstractTrainer``.
            stage: pipeline stage label; defaults to ``"Test"``.
        """
        super(ExpTester, self).__init__(config, stage)
        # self.device comes from _initiated_settings (config value or None).
        if torch.cuda.is_available() and self.device is not None:
            print(f"Using cuda device: {self.device}.")
            self.gpu = True
            self.model = self.model.to(self.device)
        else:
            print("Using cpu device.")
            self.device = torch.device("cpu")

    def _initiated_settings(self, model_cfg=None, data_cfg=None, config_cfg=None):
        # Default to CPU; __init__ flips self.gpu to True when cuda is usable.
        self.gpu = False
        self.device = config_cfg.get("device", None)

    def _train_settings(self, model_cfg=None, data_cfg=None, config_cfg=None):
        # Not used in test stage.
        raise NotImplementedError("The function is not intended to be used here.")

    def _test_settings(self, model_cfg=None, data_cfg=None, config_cfg=None):
        """Load test dataset, model, loss, and the checkpoint to evaluate."""
        # load test dataset: the dataset yaml holds per-branch option sets,
        # and data_cfg["test_branch"] selects which branch to evaluate.
        test_dataset = data_cfg["file"]
        branch = data_cfg["test_branch"]
        name = data_cfg["name"]
        with open(test_dataset, "r") as f:
            options = yaml.load(f, Loader=yaml.FullLoader)
        test_options = options[branch]
        self.test_set = load_dataset(name)(test_options)
        # wrapped with data loader; shuffle=False keeps metric order stable.
        self.test_batch_size = data_cfg["test_batch_size"]
        self.test_loader = data.DataLoader(self.test_set, shuffle=False,
                                           batch_size=self.test_batch_size)

        self.run_id = config_cfg["id"]
        self.ckpt_fold = config_cfg.get("ckpt_fold", "runs")
        self.dir = os.path.join(self.ckpt_fold, self.model_name, self.run_id)

        # load model
        self.num_classes = model_cfg["num_classes"]
        self.model = load_model(self.model_name)(**model_cfg)

        # load loss
        self.loss_criterion = get_loss(config_cfg.get("loss", None))

        # redirect the std out stream so results are also written to disk
        sys.stdout = Logger(os.path.join(self.dir, "test_result.txt"))
        print('Run dir: {}'.format(self.dir))

        center_print('Test configurations begins')
        pprint(self.config)
        pprint(test_options)
        center_print('Test configurations ends')

        self.ckpt = config_cfg.get("ckpt", "best_model")
        self._load_ckpt(best=True, train=False)

    def _save_ckpt(self, step, best=False):
        # Not used in test stage.
        raise NotImplementedError("The function is not intended to be used here.")

    def _load_ckpt(self, best=False, train=False):
        """Restore model weights and bookkeeping fields from a checkpoint.

        Falls back to legacy checkpoint keys ("best_acc", "Acc") when the
        newer "best_metric"/"eval_metric" keys are absent.
        """
        load_dir = os.path.join(
            self.dir, (self.ckpt + ".bin") if best else "latest_model.bin")
        load_dict = torch.load(load_dir, map_location=self.device)
        self.start_step = load_dict["step"]
        self.best_step = load_dict["best_step"]
        self.best_metric = load_dict.get("best_metric", None)
        if self.best_metric is None:
            # legacy checkpoints stored the metric under "best_acc"
            self.best_metric = load_dict.get("best_acc")
        self.eval_metric = load_dict.get("eval_metric", None)
        if self.eval_metric is None:
            self.eval_metric = load_dict.get("Acc")
        self.model.load_state_dict(load_dict["model"])
        # float(...) handles both scalar tensors and plain floats; the
        # original .item() call crashed on legacy float-valued checkpoints.
        print(f"Loading checkpoint from {load_dir}, best step: {self.best_step}, "
              f"best {self.eval_metric}: {round(float(self.best_metric), 4)}.")

    def train(self):
        # Not used in test stage.
        raise NotImplementedError("The function is not intended to be used here.")

    def validate(self, epoch, step, timer, writer):
        # Not used in test stage.
        raise NotImplementedError("The function is not intended to be used here.")

    def test(self, display_images=False):
        """Run a full pass over the test loader and report LOSS/ACC/AUC.

        Args:
            display_images: when True, plot a random batch's first four
                images with predictions and ground truth.
        """
        freeze_weights(self.model)
        # enumerate(..., 1) yields indices 1..len(loader); randint is
        # inclusive on both ends, so the upper bound must be len(loader)
        # (the original `len(...) + 1` could pick an index that never
        # occurs, silently skipping the display).
        t_idx = random.randint(1, len(self.test_loader))
        self.fixed_randomness()  # for reproduction
        acc = AccMeter()
        auc = AUCMeter()
        logloss = AverageMeter()
        test_generator = tqdm(enumerate(self.test_loader, 1))
        categories = self.test_loader.dataset.categories
        for idx, test_data in test_generator:
            self.model.eval()
            I, Y = test_data
            I = self.test_loader.dataset.load_item(I)
            if self.gpu:
                in_I, Y = self.to_device((I, Y))
            else:
                in_I, Y = (I, Y)
            Y_pre = self.model(in_I)

            # for BCE Setting: single logit, sigmoid to probability
            if self.num_classes == 1:
                Y_pre = Y_pre.squeeze()
                loss = self.loss_criterion(Y_pre, Y.float())
                Y_pre = torch.sigmoid(Y_pre)
            else:
                loss = self.loss_criterion(Y_pre, Y)

            acc.update(Y_pre, Y, use_bce=self.num_classes == 1)
            auc.update(Y_pre, Y, use_bce=self.num_classes == 1)
            logloss.update(loss.item())

            test_generator.set_description("Test %d/%d" % (idx, len(self.test_loader)))
            if display_images and idx == t_idx:
                # show images from the randomly chosen batch
                images = I[:4]
                pred = Y_pre[:4]
                gt = Y[:4]
                self.plot_figure(images, pred, gt, 2, categories)
        print("Test, FINAL LOSS %.4f, FINAL ACC %.4f, FINAL AUC %.4f" %
              (logloss.avg, acc.mean_acc(), auc.mean_auc()))
        auc.curve(self.dir)