# Source: P-DFD/trainer/abstract_trainer.py
# Commit: "Initial commit" (982865f), mrneuralnet
import os
import torch
import random
from collections import OrderedDict
from torchvision.utils import make_grid
# Metric names treated as legal. Not referenced in the visible part of this
# file -- presumably consumed by concrete trainers (verify against callers).
LEGAL_METRIC = [
    'Acc',
    'AUC',
    'LogLoss',
]
class AbstractTrainer:
    """Skeleton shared by concrete trainers.

    The constructor validates ``stage``, splits ``config`` into its
    ``model`` / ``data`` / ``config`` sections, initialises bookkeeping
    attributes, and then delegates to hooks that subclasses must implement:
    ``_initiated_settings`` first, followed by ``_train_settings`` or
    ``_test_settings`` depending on ``stage``.

    Every hook and every ``train`` / ``validate`` / ``test`` entry point in
    this class raises ``NotImplementedError``; only ``to_device``,
    ``fixed_randomness`` and ``plot_figure`` carry real behaviour.
    """

    def __init__(self, config, stage="Train"):
        """Validate inputs, set default attributes and run the setup hooks.

        Args:
            config: Mapping with a ``"model"`` section (which must contain a
                ``"name"`` entry) and optional ``"data"`` and ``"config"``
                sections. Missing optional sections are passed to the hooks
                as ``None``.
            stage: Either ``"Train"`` or ``"Test"``.

        Raises:
            ValueError: If ``stage`` is not one of the feasible stages, or
                if ``config`` has no ``"model"`` section.
            KeyError: If the ``"model"`` section has no ``"name"`` entry.
        """
        feasible_stage = ["Train", "Test"]
        if stage not in feasible_stage:
            raise ValueError(f"stage should be in {feasible_stage}, but found '{stage}'")

        self.config = config
        model_cfg = config.get("model", None)
        data_cfg = config.get("data", None)
        config_cfg = config.get("config", None)
        if model_cfg is None:
            # Fail with an explicit message instead of an AttributeError on
            # ``None.pop`` a line further down.
            raise ValueError("config must contain a 'model' section with a 'name' entry.")

        # NOTE: pop() removes "name" from the model section in place, so the
        # mapping stored in ``self.config`` and the ``model_cfg`` handed to
        # the hooks no longer contain it.
        self.model_name = model_cfg.pop("name")

        # Placeholders; nothing in this class assigns them a real value.
        # Presumably the subclass hooks fill them in (verify in subclasses).
        self.gpu = None
        self.dir = None
        self.debug = None
        self.device = None
        self.resume = None
        self.local_rank = None
        self.num_classes = None

        # Progress bookkeeping defaults.
        self.best_metric = 0.0
        self.best_step = 1
        self.start_step = 1

        self._initiated_settings(model_cfg, data_cfg, config_cfg)

        # ``stage`` was validated above, so it is exactly one of the two.
        if stage == 'Train':
            self._train_settings(model_cfg, data_cfg, config_cfg)
        else:
            self._test_settings(model_cfg, data_cfg, config_cfg)

    def _initiated_settings(self, model_cfg, data_cfg, config_cfg):
        """Hook called first by ``__init__`` for both stages; must be overridden."""
        raise NotImplementedError("Not implemented in abstract class.")

    def _train_settings(self, model_cfg, data_cfg, config_cfg):
        """Hook called by ``__init__`` when ``stage == "Train"``; must be overridden."""
        raise NotImplementedError("Not implemented in abstract class.")

    def _test_settings(self, model_cfg, data_cfg, config_cfg):
        """Hook called by ``__init__`` when ``stage == "Test"``; must be overridden."""
        raise NotImplementedError("Not implemented in abstract class.")

    def _save_ckpt(self, step, best=False):
        """Checkpoint-saving hook; must be overridden."""
        raise NotImplementedError("Not implemented in abstract class.")

    def _load_ckpt(self, best=False, train=False):
        """Checkpoint-loading hook; must be overridden."""
        raise NotImplementedError("Not implemented in abstract class.")

    def to_device(self, items):
        """Return a list with ``obj.to(self.device)`` applied to each item.

        Args:
            items: Iterable of objects exposing a ``.to(device)`` method.

        Returns:
            A new list holding the results, in the same order.
        """
        return [obj.to(self.device) for obj in items]

    @staticmethod
    def fixed_randomness(seed=0):
        """Seed Python's ``random`` module and torch (CPU and CUDA).

        Other random number generators (for example NumPy's) are not
        touched by this method.

        Args:
            seed: Integer seed; defaults to ``0``, the value that was
                previously hard-coded.
        """
        random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    def train(self):
        """Training entry point; must be overridden."""
        raise NotImplementedError("Not implemented in abstract class.")

    def validate(self, epoch, step, timer, writer):
        """Validation entry point; must be overridden."""
        raise NotImplementedError("Not implemented in abstract class.")

    def test(self):
        """Testing entry point; must be overridden."""
        raise NotImplementedError("Not implemented in abstract class.")

    def plot_figure(self, images, pred, gt, nrow, categories=None, show=True):
        """Render a grid of images titled with predictions and ground truth.

        Args:
            images: Batch of images passed straight to ``make_grid``.
            pred: Prediction tensor. When ``self.num_classes == 1`` it is
                thresholded at ``0.5``; otherwise ``argmax`` over dim 1 is
                taken.
            gt: Ground-truth tensor.
            nrow: Forwarded to ``make_grid`` as its ``nrow`` argument.
            categories: Optional lookup (indexable by ``int``) used to map
                each prediction / label to a display name.
            show: If true, save the figure as ``test_image.png`` inside
                ``self.dir`` and call ``plt.show()``.

        Returns:
            The matplotlib figure, which has already been closed in pyplot.
        """
        # Imported lazily so matplotlib is only required when plotting.
        import matplotlib.pyplot as plt

        plot = make_grid(
            images, nrow, padding=4, normalize=True, scale_each=True, pad_value=1)

        if self.num_classes == 1:
            pred = (pred >= 0.5).cpu().numpy()
        else:
            pred = pred.argmax(1).cpu().numpy()
        gt = gt.cpu().numpy()

        if categories is not None:
            # Flatten and cast to int: boolean predictions from the binary
            # branch, float labels, or (N, 1)-shaped arrays cannot be used
            # directly as indices into ``categories``.
            pred = [categories[int(i)] for i in pred.reshape(-1)]
            gt = [categories[int(i)] for i in gt.reshape(-1)]

        # Channel-first grid -> channel-last array for imshow.
        plot = plot.permute([1, 2, 0])
        plot = plot.cpu().numpy()

        ret = plt.figure()
        plt.imshow(plot)
        plt.title("pred: %s\ngt: %s" % (pred, gt))
        plt.axis("off")
        if show:
            plt.savefig(os.path.join(self.dir, "test_image.png"), dpi=300)
            plt.show()
        # Close this specific figure on both paths so pyplot does not keep it.
        plt.close(ret)
        return ret