Spaces:
Sleeping
Sleeping
import os | |
import torch | |
import random | |
from collections import OrderedDict | |
from torchvision.utils import make_grid | |
# Metric names treated as legal/supported. Not referenced in this view —
# presumably checked against a configured metric elsewhere; verify at callers.
LEGAL_METRIC = ["Acc", "AUC", "LogLoss"]
class AbstractTrainer(object):
    """Base class for trainers; subclasses implement the settings hooks and loops.

    Construction reads the ``"model"``, ``"data"`` and ``"config"`` sections of
    ``config``. It removes ``"name"`` from (a copy of) the model section into
    ``self.model_name``. It then always calls ``_initiated_settings`` followed by
    ``_train_settings`` (stage ``"Train"``) or ``_test_settings`` (stage
    ``"Test"``). Every hook and loop method here raises ``NotImplementedError``
    and must be overridden.
    """

    def __init__(self, config, stage="Train"):
        """Validate ``stage``, extract the model name and run the settings hooks.

        Args:
            config: Mapping with a ``"model"`` section that contains ``"name"``,
                plus optional ``"data"`` and ``"config"`` sections (``None``
                is passed to the hooks when a section is absent).
            stage: ``"Train"`` or ``"Test"``; selects the stage-specific hook.

        Raises:
            ValueError: If ``stage`` is not supported or ``config`` has no
                ``"model"`` section.
            KeyError: If the model section has no ``"name"`` entry.
        """
        feasible_stage = ["Train", "Test"]
        if stage not in feasible_stage:
            raise ValueError(f"stage should be in {feasible_stage}, but found '{stage}'")
        self.config = config
        model_cfg = config.get("model", None)
        data_cfg = config.get("data", None)
        config_cfg = config.get("config", None)
        if model_cfg is None:
            raise ValueError("config must contain a 'model' section with a 'name' entry")
        # Pop from a shallow copy: popping from the caller's dict would strip
        # "name" from ``config`` itself, so a second trainer built from the same
        # config (e.g. Train then Test) would fail with KeyError.
        model_cfg = dict(model_cfg)
        self.model_name = model_cfg.pop("name")

        # Placeholders; expected to be filled in by the subclass hooks.
        self.gpu = None
        self.dir = None
        self.debug = None
        self.device = None
        self.resume = None
        self.local_rank = None
        self.num_classes = None
        # Initial bookkeeping values; presumably updated by subclass training
        # / checkpoint code (not visible here).
        self.best_metric = 0.0
        self.best_step = 1
        self.start_step = 1

        self._initiated_settings(model_cfg, data_cfg, config_cfg)
        if stage == 'Train':
            self._train_settings(model_cfg, data_cfg, config_cfg)
        if stage == 'Test':
            self._test_settings(model_cfg, data_cfg, config_cfg)

    def _initiated_settings(self, model_cfg, data_cfg, config_cfg):
        """Hook run for every stage, before the stage-specific hook.

        ``model_cfg`` no longer contains ``"name"``.
        """
        raise NotImplementedError("Not implemented in abstract class.")

    def _train_settings(self, model_cfg, data_cfg, config_cfg):
        """Hook run only when the trainer is built with ``stage="Train"``."""
        raise NotImplementedError("Not implemented in abstract class.")

    def _test_settings(self, model_cfg, data_cfg, config_cfg):
        """Hook run only when the trainer is built with ``stage="Test"``."""
        raise NotImplementedError("Not implemented in abstract class.")

    def _save_ckpt(self, step, best=False):
        """Save a checkpoint for ``step``; must be implemented by subclasses."""
        raise NotImplementedError("Not implemented in abstract class.")

    def _load_ckpt(self, best=False, train=False):
        """Load a checkpoint; must be implemented by subclasses."""
        raise NotImplementedError("Not implemented in abstract class.")

    def to_device(self, items):
        """Return a list with each item moved via ``item.to(self.device)``."""
        return [obj.to(self.device) for obj in items]

    @staticmethod
    def fixed_randomness(seed=0):
        """Seed Python's ``random`` and torch's CPU/CUDA generators with ``seed``.

        Declared static so it works both as ``self.fixed_randomness()`` and as
        ``AbstractTrainer.fixed_randomness()``. The default seed of 0 matches
        the original hard-coded value.
        """
        random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    def train(self):
        """Run training; must be implemented by subclasses."""
        raise NotImplementedError("Not implemented in abstract class.")

    def validate(self, epoch, step, timer, writer):
        """Run validation; must be implemented by subclasses."""
        raise NotImplementedError("Not implemented in abstract class.")

    def test(self):
        """Run testing; must be implemented by subclasses."""
        raise NotImplementedError("Not implemented in abstract class.")

    def plot_figure(self, images, pred, gt, nrow, categories=None, show=True):
        """Render ``images`` as a grid titled with predicted and true labels.

        Args:
            images: Batch accepted by ``torchvision.utils.make_grid``.
            pred: Model outputs. When ``self.num_classes == 1`` they are
                thresholded at 0.5; otherwise ``argmax`` over dim 1 is taken.
                (Assumes a per-sample score vector in the binary case —
                TODO confirm against callers.)
            gt: Ground-truth label tensor.
            nrow: Images per grid row, forwarded to ``make_grid``.
            categories: Optional sequence mapping label indices to names.
            show: If true, save the figure to ``<self.dir>/test_image.png``
                (dpi 300) and call ``plt.show()`` before closing it.

        Returns:
            The (already closed) matplotlib figure.
        """
        import matplotlib.pyplot as plt

        plot = make_grid(
            images, nrow, padding=4, normalize=True, scale_each=True, pad_value=1)
        if self.num_classes == 1:
            pred = (pred >= 0.5).cpu().numpy()
        else:
            pred = pred.argmax(1).cpu().numpy()
        gt = gt.cpu().numpy()
        if categories is not None:
            # int() is required: numpy.bool_ (from the thresholded binary
            # branch) is rejected as a list index.
            pred = [categories[int(i)] for i in pred]
            gt = [categories[int(i)] for i in gt]
        # CHW tensor -> HWC array for imshow.
        plot = plot.permute([1, 2, 0]).cpu().numpy()
        ret = plt.figure()
        plt.imshow(plot)
        plt.title("pred: %s\ngt: %s" % (pred, gt))
        plt.axis("off")
        if show:
            plt.savefig(os.path.join(self.dir, "test_image.png"), dpi=300)
            plt.show()
        plt.close()
        return ret