|
import os |
|
import logging |
|
import numpy as np |
|
|
|
import torch |
|
from torch.nn import init |
|
|
|
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights in place.

    Parameters:
        net (network)     -- network to be initialized; every submodule is visited via ``net.apply``
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    Modules whose class name contains 'Conv' or 'Linear' get their weight drawn
    according to ``init_type`` and their bias (if any) set to 0. Modules whose
    class name contains 'BatchNorm2d' get their weight drawn from
    N(1.0, init_gain) and their bias set to 0. Parameters that are ``None``
    (e.g. ``BatchNorm2d(affine=False)`` or ``bias=False`` layers) are skipped.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.

    Raises:
        NotImplementedError: when a Conv/Linear module is reached and
            ``init_type`` is not one of the names above.
    """
    # Name -> in-place initializer for a Conv/Linear weight tensor.
    # Note: 'kaiming' does not use init_gain (matches the docstring above).
    initializers = {
        'normal': lambda w: init.normal_(w, 0.0, init_gain),
        'xavier': lambda w: init.xavier_normal_(w, gain=init_gain),
        'kaiming': lambda w: init.kaiming_normal_(w, a=0, mode='fan_in'),
        'orthogonal': lambda w: init.orthogonal_(w, gain=init_gain),
    }

    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and ('Conv' in classname or 'Linear' in classname):
            initializer = initializers.get(init_type)
            if initializer is None:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if m.weight is not None:
                initializer(m.weight)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias, 0.0)
        elif 'BatchNorm2d' in classname:
            # affine=False batch norms have weight/bias set to None; nothing to initialize.
            if getattr(m, 'weight', None) is not None:
                init.normal_(m.weight, 1.0, init_gain)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias, 0.0)

    net.apply(init_func)
|
|
|
def create_logger(name, log_file, level=logging.INFO):
    """Return the logger ``name``, logging to ``log_file`` and to a console stream.

    Records are formatted as ``[<asctime>] <message>``. Calling this again for
    the same ``name`` reuses handlers attached by earlier calls instead of
    stacking duplicates. Stacked handlers would emit every message once per
    call and keep extra file handles open.

    Parameters:
        name (str)      -- name passed to ``logging.getLogger``
        log_file (str)  -- path for the file handler (opened in append mode)
        level (int)     -- level set on the logger (default ``logging.INFO``)

    Returns:
        the configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    formatter = logging.Formatter('[%(asctime)s] %(message)s')
    log_path = os.path.abspath(log_file)

    # FileHandler stores its target as an absolute path in ``baseFilename``.
    has_file_handler = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == log_path
        for h in logger.handlers
    )
    # FileHandler subclasses StreamHandler, so exclude it when looking for a console handler.
    has_stream_handler = any(
        isinstance(h, logging.StreamHandler) and not isinstance(h, logging.FileHandler)
        for h in logger.handlers
    )

    logger.setLevel(level)

    if not has_file_handler:
        fh = logging.FileHandler(log_file)
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    if not has_stream_handler:
        sh = logging.StreamHandler()
        sh.setFormatter(formatter)
        logger.addHandler(sh)

    return logger
|
|
|
class AverageMeter(object):
    """Computes and stores the average and current value.

    With ``length > 0`` the average is taken over a sliding window of the last
    ``length`` values, kept in ``history``. Otherwise it is the running mean of
    every value since the last ``reset``, tracked through ``sum`` and ``count``.
    """

    def __init__(self, length=0):
        self.length = length
        self.reset()

    def reset(self):
        """Forget all recorded values; ``val`` and ``avg`` read 0.0 until the next update."""
        if self.length > 0:
            self.history = []
        else:
            self.count = 0
            self.sum = 0.0
        # Set in both modes so reading val/avg before the first update() works
        # and a windowed reset does not leave stale values behind.
        self.val = 0.0
        self.avg = 0.0

    def update(self, val):
        """Record ``val`` as the current value and refresh ``avg``."""
        self.val = val
        if self.length > 0:
            self.history.append(val)
            if len(self.history) > self.length:
                del self.history[0]  # drop the oldest value to keep the window size
            self.avg = np.mean(self.history)
        else:
            self.sum += val
            self.count += 1
            self.avg = self.sum / self.count
|
|
|
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Parameters:
        output (Tensor)     -- scores; top-k is taken along dim 1 (assumed 2-D:
                               one row per sample, one column per class)
        target (Tensor)     -- ground-truth class index per sample; ``target.size(0)``
                               is used as the batch size
        topk (tuple of int) -- the k values to evaluate

    Returns:
        list with one 1-element float tensor per k in ``topk``, holding the
        percentage (0-100) of samples whose target is among the top-k scores.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)  # largest=True, sorted=True
    pred = pred.t()  # (maxk, batch): row i holds each sample's (i+1)-th best class
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` inherits the transposed (non-contiguous) layout of `pred`;
        # .view(-1) raises on it in recent PyTorch, .reshape(-1) copies if needed.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
|
|
|
def load_state(path, model, optimizer=None):
    """Load model (and optionally optimizer) state from the checkpoint at ``path``.

    The checkpoint is read as a dict with ``'state_dict'`` and ``'step'``
    entries, plus ``'optimizer'`` when ``optimizer`` is given. Model weights are
    loaded with ``strict=False``; model keys absent from the checkpoint are
    printed as a caution. Tensors go to the current CUDA device when CUDA is
    available and to the CPU otherwise, so CPU-only machines can resume too.

    Parameters:
        path (str)                -- checkpoint file path
        model (nn.Module)         -- model receiving ``checkpoint['state_dict']``
        optimizer (Optimizer|None) -- if given, receives ``checkpoint['optimizer']``

    Returns:
        ``checkpoint['step']`` when ``path`` is a file, otherwise ``None``
        (after printing a notice).
    """
    if not os.path.isfile(path):
        print("=> no checkpoint found at '{}'".format(path))
        return None

    if torch.cuda.is_available():
        def map_location(storage, location):
            return storage.cuda()
    else:
        # storage.cuda() would fail without a GPU; keep tensors on the CPU instead.
        map_location = 'cpu'

    print("=> loading checkpoint '{}'".format(path))
    checkpoint = torch.load(path, map_location=map_location)
    state_dict = checkpoint['state_dict']
    model.load_state_dict(state_dict, strict=False)

    # strict=False silently ignores gaps, so report model keys the checkpoint lacks.
    missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
    for k in sorted(missing_keys):
        print('caution: missing keys from checkpoint {}: {}'.format(path, k))

    last_iter = checkpoint['step']
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> also loaded optimizer from checkpoint '{}' (iter {})"
              .format(path, last_iter))
    return last_iter
|
|
|
|
|
|