import os
import random
from collections import OrderedDict

import torch

from . import model_utils


class BaseModel:
    """Shared scaffolding for the concrete models.

    Provides train/eval switching, loss construction, checkpoint save/load/clean
    and a gradient-penalty helper. Subclasses are expected to:
      * append network names to ``self.models_name`` and store each network as
        an attribute named ``net_<name>``;
      * store scalar losses as ``loss_<name>`` attributes (read by
        ``get_latest_losses``);
      * fill ``self.optims`` / ``self.schedulers`` after ``setup`` in training mode;
      * override ``feed_batch``, ``forward`` and ``optimize_paras``.
    """

    def __init__(self):
        super().__init__()
        self.name = "Base"

    def initialize(self, opt):
        """Store the options and derive the compute device and train/test flag.

        Args:
            opt: option namespace; this method reads ``opt.gpu_ids`` and ``opt.mode``.
        """
        self.opt = opt
        self.gpu_ids = self.opt.gpu_ids
        # The first listed GPU is the primary device; an empty list means CPU.
        self.device = torch.device('cuda:%d' % self.gpu_ids[0] if self.gpu_ids else 'cpu')
        self.is_train = self.opt.mode == "train"
        # Filled in by subclasses with the names of their ``net_<name>`` attributes.
        self.models_name = []

    def setup(self):
        """Switch networks to train/eval mode and, when training, build the losses.

        In training mode this creates ``criterionGAN`` (configured by
        ``opt.gan_type``), ``criterionL1``, ``criterionMSE`` and ``criterionTV``
        on ``self.device``. It also resets ``losses_name``, ``optims`` and
        ``schedulers`` for the subclass to populate.
        """
        if self.is_train:
            self.set_train()
            # The losses are used directly on self.device; no DataParallel
            # wrapper is needed (and an unused wrapper is just overhead).
            self.criterionGAN = model_utils.GANLoss(gan_type=self.opt.gan_type).to(self.device)
            self.criterionL1 = torch.nn.L1Loss().to(self.device)
            self.criterionMSE = torch.nn.MSELoss().to(self.device)
            self.criterionTV = model_utils.TVLoss().to(self.device)
            # Populated by subclasses.
            self.losses_name = []
            self.optims = []
            self.schedulers = []
        else:
            self.set_eval()

    def set_eval(self):
        """Put every registered ``net_<name>`` into eval mode and clear ``is_train``."""
        print("Set model to Test state.")
        for name in self.models_name:
            if not isinstance(name, str):
                continue
            getattr(self, 'net_' + name).eval()
            print("Set net_%s to EVAL." % name)
        self.is_train = False

    def set_train(self):
        """Put every registered ``net_<name>`` into train mode and set ``is_train``."""
        print("Set model to Train state.")
        for name in self.models_name:
            if not isinstance(name, str):
                continue
            getattr(self, 'net_' + name).train()
            print("Set net_%s to TRAIN." % name)
        self.is_train = True

    def set_requires_grad(self, parameters, requires_grad=False):
        """Enable or disable gradient tracking.

        Args:
            parameters: a tensor/parameter, an ``nn.Module``, or a list/tuple of
                them; ``None`` entries are skipped. For a module, all of its
                parameters are updated.
            requires_grad: the flag to set.
        """
        if not isinstance(parameters, (list, tuple)):
            parameters = [parameters]
        for param in parameters:
            if param is None:
                continue
            if isinstance(param, torch.nn.Module):
                # Setting ``.requires_grad`` on a Module object would only create
                # an unused attribute; propagate the flag to its parameters.
                param.requires_grad_(requires_grad)
            else:
                param.requires_grad = requires_grad

    def get_latest_visuals(self, visuals_name):
        """Return an OrderedDict of the named attributes that exist on this model.

        Names that are not strings or not present as attributes are skipped.
        """
        visual_ret = OrderedDict()
        for name in visuals_name:
            if isinstance(name, str) and hasattr(self, name):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def get_latest_losses(self, losses_name):
        """Return an OrderedDict mapping each loss name to ``float(self.loss_<name>)``.

        Non-string names are skipped.
        """
        errors_ret = OrderedDict()
        for name in losses_name:
            if isinstance(name, str):
                errors_ret[name] = float(getattr(self, 'loss_' + name))
        return errors_ret

    def feed_batch(self, batch):
        """Hook for subclasses: consume one input batch. No-op here."""
        pass

    def forward(self):
        """Hook for subclasses: run the forward pass. No-op here."""
        pass

    def optimize_paras(self):
        """Hook for subclasses: perform one optimization step. No-op here."""
        pass

    def update_learning_rate(self):
        """Step every scheduler and return the first optimizer's current learning rate."""
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optims[0].param_groups[0]['lr']
        return lr

    @staticmethod
    def _unwrap(net):
        """Return the inner module of a ``DataParallel`` wrapper, else ``net`` itself."""
        return net.module if isinstance(net, torch.nn.DataParallel) else net

    def save_ckpt(self, epoch, models_name):
        """Save each ``net_<name>`` to ``<opt.ckpt_dir>/<epoch>_net_<name>.pth``.

        The saved state dict is taken from the unwrapped network (no
        ``module.`` prefix) and all tensors are copied to CPU, so the file can be
        loaded under any GPU configuration. The live network is not moved.
        """
        for name in models_name:
            if not isinstance(name, str):
                continue
            save_filename = '%s_net_%s.pth' % (epoch, name)
            save_path = os.path.join(self.opt.ckpt_dir, save_filename)
            net = self._unwrap(getattr(self, 'net_' + name))
            state_dict = net.state_dict()
            cpu_state_dict = OrderedDict((k, v.cpu()) for k, v in state_dict.items())
            # Keep the per-module version metadata that load_state_dict consults.
            metadata = getattr(state_dict, '_metadata', None)
            if metadata is not None:
                cpu_state_dict._metadata = metadata
            torch.save(cpu_state_dict, save_path)

    def load_ckpt(self, epoch, models_name):
        """Load ``<opt.ckpt_dir>/<epoch>_net_<name>.pth`` into each ``net_<name>``.

        Tensors are mapped onto ``self.device``. Only checkpoint entries whose
        keys exist in the network are used; network entries absent from the
        checkpoint keep their current values and are reported with a warning.

        Raises:
            FileNotFoundError: if a checkpoint file does not exist.
        """
        for name in models_name:
            if not isinstance(name, str):
                continue
            load_filename = '%s_net_%s.pth' % (epoch, name)
            load_path = os.path.join(self.opt.ckpt_dir, load_filename)
            if not os.path.isfile(load_path):
                raise FileNotFoundError("File '%s' does not exist." % load_path)
            pretrained_state_dict = torch.load(load_path, map_location=str(self.device))
            net = self._unwrap(getattr(self, 'net_' + name))
            model_state_dict = net.state_dict()
            # Load only keys the network knows about; fill the rest from the
            # network's current values so the strict load does not fail.
            matched = {k: v for k, v in pretrained_state_dict.items() if k in model_state_dict}
            missing = [k for k in model_state_dict if k not in matched]
            if missing:
                print("[Warning] net_%s: %d entries not found in %s; keeping current values."
                      % (name, len(missing), load_path))
            model_state_dict.update(matched)
            net.load_state_dict(model_state_dict)
            print("[Info] Successfully load trained weights for net_%s." % name)

    def clean_ckpt(self, epoch, models_name):
        """Delete ``<opt.ckpt_dir>/<epoch>_net_<name>.pth`` for each name, if present."""
        for name in models_name:
            if not isinstance(name, str):
                continue
            load_filename = '%s_net_%s.pth' % (epoch, name)
            load_path = os.path.join(self.opt.ckpt_dir, load_filename)
            if os.path.isfile(load_path):
                os.remove(load_path)

    def gradient_penalty(self, input_img, generate_img):
        """Return mean((||d D(x_hat) / d x_hat||_2 - 1) ** 2) over the batch.

        ``x_hat`` is a per-sample random interpolation between ``input_img`` and
        ``generate_img`` (the WGAN-GP gradient penalty). ``self.net_dis`` must
        return a tuple whose first element is the score differentiated here.
        The graph is kept (``create_graph=True``) so the penalty can be
        back-propagated.
        """
        batch_size = input_img.size(0)
        # One mixing coefficient per sample, broadcast over the remaining dims
        # ((N, 1, 1, 1) for 4-D batches, matching any input rank in general).
        alpha_shape = (batch_size,) + (1,) * (input_img.dim() - 1)
        alpha = torch.rand(alpha_shape).to(self.device)
        inter_img = (alpha * input_img.detach()
                     + (1 - alpha) * generate_img.detach()).requires_grad_(True)
        inter_img_prob, _ = self.net_dis(inter_img)
        dydx = torch.autograd.grad(outputs=inter_img_prob,
                                   inputs=inter_img,
                                   grad_outputs=torch.ones_like(inter_img_prob),
                                   retain_graph=True,
                                   create_graph=True)[0]
        dydx = dydx.reshape(batch_size, -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx ** 2, dim=1))
        return torch.mean((dydx_l2norm - 1) ** 2)