Spaces:
Build error
Build error
File size: 6,803 Bytes
0b5f327 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
import torch
import os
from collections import OrderedDict
import random
from . import model_utils
class BaseModel:
    """Common scaffolding shared by concrete models.

    Subclasses register each network as an attribute named ``net_<name>`` and
    list ``<name>`` in ``self.models_name``; the train/eval toggles and the
    checkpoint helpers look networks up through that naming convention.
    ``initialize(opt)`` must be called before ``setup()``.
    """

    def __init__(self):
        super().__init__()
        self.name = "Base"

    def initialize(self, opt):
        """Store the options object and derive device and mode flags.

        Args:
            opt: options object; this method reads ``opt.gpu_ids`` (GPU
                indices, empty for CPU) and ``opt.mode``.
        """
        self.opt = opt
        self.gpu_ids = self.opt.gpu_ids
        # The first listed GPU is the primary device; fall back to CPU.
        self.device = torch.device('cuda:%d' % self.gpu_ids[0] if self.gpu_ids else 'cpu')
        self.is_train = self.opt.mode == "train"
        # Subclasses append the <name> of every net_<name> attribute here.
        self.models_name = []

    def setup(self):
        """Switch networks to train or eval mode; in train mode also build losses.

        In train mode this creates the shared criteria (GAN, L1, MSE, TV) on
        ``self.device`` plus empty ``losses_name`` / ``optims`` /
        ``schedulers`` lists for subclasses to fill.
        """
        if self.is_train:
            self.set_train()
            self.criterionGAN = model_utils.GANLoss(gan_type=self.opt.gan_type).to(self.device)
            self.criterionL1 = torch.nn.L1Loss().to(self.device)
            self.criterionMSE = torch.nn.MSELoss().to(self.device)
            self.criterionTV = model_utils.TVLoss().to(self.device)
            # Criteria are used directly on self.device. Building a
            # DataParallel wrapper and discarding it has no effect, and
            # DataParallel(module, []) can raise on CUDA machines.
            self.losses_name = []
            self.optims = []
            self.schedulers = []
        else:
            self.set_eval()

    def set_eval(self):
        """Put every registered ``net_<name>`` into eval mode."""
        print("Set model to Test state.")
        for name in self.models_name:
            if isinstance(name, str):
                net = getattr(self, 'net_' + name)
                net.eval()
                print("Set net_%s to EVAL." % name)
        self.is_train = False

    def set_train(self):
        """Put every registered ``net_<name>`` into train mode."""
        print("Set model to Train state.")
        for name in self.models_name:
            if isinstance(name, str):
                net = getattr(self, 'net_' + name)
                net.train()
                print("Set net_%s to TRAIN." % name)
        self.is_train = True

    def set_requires_grad(self, parameters, requires_grad=False):
        """Set ``requires_grad`` on tensors/parameters and/or whole modules.

        Args:
            parameters: a tensor, a ``torch.nn.Module``, ``None``, or a
                list/tuple of those. ``None`` entries are skipped.
            requires_grad: the value to assign.
        """
        if not isinstance(parameters, (list, tuple)):
            parameters = [parameters]
        for param in parameters:
            if param is None:
                continue
            if isinstance(param, torch.nn.Module):
                # Assigning .requires_grad on a Module only creates a plain
                # attribute; requires_grad_() toggles all of its parameters.
                param.requires_grad_(requires_grad)
            else:
                param.requires_grad = requires_grad

    def get_latest_visuals(self, visuals_name):
        """Return an ordered mapping of the named attributes that exist on self.

        Non-string names and names without a matching attribute are skipped.
        """
        visual_ret = OrderedDict()
        for name in visuals_name:
            if isinstance(name, str) and hasattr(self, name):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def get_latest_losses(self, losses_name):
        """Return an ordered mapping ``name -> float(self.loss_<name>)``.

        Non-string names are skipped; a missing ``loss_<name>`` raises
        ``AttributeError``.
        """
        errors_ret = OrderedDict()
        for name in losses_name:
            if isinstance(name, str):
                errors_ret[name] = float(getattr(self, 'loss_' + name))
        return errors_ret

    def feed_batch(self, batch):
        """Hook for loading a batch; no-op here (presumably overridden)."""
        pass

    def forward(self):
        """Hook for the forward pass; no-op here (presumably overridden)."""
        pass

    def optimize_paras(self):
        """Hook for an optimization step; no-op here (presumably overridden)."""
        pass

    def update_learning_rate(self):
        """Step every scheduler and return the resulting learning rate.

        Returns:
            The ``lr`` of the first param group of ``self.optims[0]``
            (``IndexError`` if no optimizer was registered).
        """
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optims[0].param_groups[0]['lr']
        return lr

    def save_ckpt(self, epoch, models_name):
        """Save each ``net_<name>`` state dict to ``<ckpt_dir>/<epoch>_net_<name>.pth``.

        DataParallel wrappers are unwrapped so keys carry no ``module.``
        prefix, matching what ``load_ckpt`` expects. Parameters are saved
        from CPU so the checkpoint can be used in other GPU settings.
        """
        for name in models_name:
            if not isinstance(name, str):
                continue
            save_filename = '%s_net_%s.pth' % (epoch, name)
            save_path = os.path.join(self.opt.ckpt_dir, save_filename)
            net = getattr(self, 'net_' + name)
            module = net.module if isinstance(net, torch.nn.DataParallel) else net
            torch.save(module.cpu().state_dict(), save_path)
            if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                # .cpu() moved the live network; put it back on its device.
                module.to(self.device)

    def load_ckpt(self, epoch, models_name):
        """Load ``<ckpt_dir>/<epoch>_net_<name>.pth`` into each ``net_<name>``.

        Checkpoint keys that the network does not have are dropped before a
        (strict) ``load_state_dict``.

        Raises:
            FileNotFoundError: if a checkpoint file is missing.
        """
        for name in models_name:
            if not isinstance(name, str):
                continue
            load_filename = '%s_net_%s.pth' % (epoch, name)
            load_path = os.path.join(self.opt.ckpt_dir, load_filename)
            if not os.path.isfile(load_path):
                raise FileNotFoundError("File '%s' does not exist." % load_path)
            pretrained_state_dict = torch.load(load_path, map_location=str(self.device))
            if hasattr(pretrained_state_dict, '_metadata'):
                del pretrained_state_dict._metadata
            net = getattr(self, 'net_' + name)
            if isinstance(net, torch.nn.DataParallel):
                net = net.module
            # Build the key set once instead of calling state_dict() per key.
            model_keys = net.state_dict().keys()
            pretrained_dict = {k: v for k, v in pretrained_state_dict.items() if k in model_keys}
            net.load_state_dict(pretrained_dict)
            print("[Info] Successfully load trained weights for net_%s." % name)

    def clean_ckpt(self, epoch, models_name):
        """Delete ``<ckpt_dir>/<epoch>_net_<name>.pth`` files that exist."""
        for name in models_name:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (epoch, name)
                load_path = os.path.join(self.opt.ckpt_dir, load_filename)
                if os.path.isfile(load_path):
                    os.remove(load_path)

    def gradient_penalty(self, input_img, generate_img):
        """Return mean((||d net_dis(x) / dx||_2 - 1)**2) over random interpolates.

        ``x`` mixes ``input_img`` and ``generate_img`` with one random
        coefficient per sample (WGAN-GP style penalty). Inputs may be of any
        rank >= 2 with the batch on dim 0.

        NOTE(review): assumes ``self.net_dis(x)`` returns a tuple whose first
        element is the discriminator output — confirm against subclasses.
        """
        # One mixing coefficient per sample, broadcastable over other dims.
        alpha_shape = (input_img.size(0),) + (1,) * (input_img.dim() - 1)
        alpha = torch.rand(alpha_shape, device=self.device)
        inter_img = (alpha * input_img.detach() + (1 - alpha) * generate_img.detach()).requires_grad_(True)
        inter_img_prob, _ = self.net_dis(inter_img)
        # create_graph=True so the penalty itself can be backpropagated.
        dydx = torch.autograd.grad(outputs=inter_img_prob,
                                   inputs=inter_img,
                                   grad_outputs=torch.ones_like(inter_img_prob),
                                   retain_graph=True,
                                   create_graph=True)[0]
        # reshape (not view): autograd may return a non-contiguous gradient.
        dydx = dydx.reshape(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx ** 2, dim=1))
        return torch.mean((dydx_l2norm - 1) ** 2)
|