import os |
import logging |
import numpy as np |
import torch |
from torch.nn import init |
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights in place.

    Parameters:
        net (network)     -- network to be initialized; ``net.apply`` visits ``net``
                             itself and every submodule
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal (unused by
                             kaiming); also the std of BatchNorm2d weights around 1.0

    Layers whose class name contains 'Conv' or 'Linear' get their weight initialized
    with the chosen method and their bias (if any) set to 0. Layers whose class name
    contains 'BatchNorm2d' get weight ~ N(1.0, init_gain) and bias 0; affine-free
    BatchNorm layers (weight/bias are None) are skipped instead of crashing.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.

    Raises:
        NotImplementedError -- if ``init_type`` is unknown and ``net`` contains at least
                               one Conv/Linear layer with a weight.
    """
    def init_func(m):
        classname = m.__class__.__name__
        # Parameters may be registered but set to None (e.g. BatchNorm2d(affine=False)).
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        if weight is not None and ('Conv' in classname or 'Linear' in classname):
            if init_type == 'normal':
                init.normal_(weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if bias is not None:
                init.constant_(bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            if weight is not None:
                init.normal_(weight.data, 1.0, init_gain)
            if bias is not None:
                init.constant_(bias.data, 0.0)
    net.apply(init_func)
def create_logger(name, log_file, level=logging.INFO):
    """Return the named logger, writing to ``log_file`` and to the console.

    Both handlers use the format ``[%(asctime)s] %(message)s``. Calling this again
    with the same ``name`` does not stack duplicate handlers: a file handler is
    only added if none already targets the same path, and a console handler only
    if none is attached yet. Without this, every record would be emitted once per
    previous call.

    Parameters:
        name (str)      -- logger name passed to ``logging.getLogger``
        log_file (str)  -- path of the file that receives log records
        level (int)     -- level set on the logger (default ``logging.INFO``)

    Returns:
        logging.Logger -- the configured logger.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)
    formatter = logging.Formatter('[%(asctime)s] %(message)s')

    # FileHandler stores the absolute path in ``baseFilename``; compare against that.
    target = os.path.abspath(os.fspath(log_file))
    has_file = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == target
        for h in logger.handlers
    )
    # FileHandler is a StreamHandler subclass, so exclude it when looking for the console handler.
    has_console = any(
        isinstance(h, logging.StreamHandler) and not isinstance(h, logging.FileHandler)
        for h in logger.handlers
    )

    if not has_file:
        fh = logging.FileHandler(log_file)
        fh.setFormatter(formatter)
        logger.addHandler(fh)
    if not has_console:
        sh = logging.StreamHandler()
        sh.setFormatter(formatter)
        logger.addHandler(sh)
    return logger
class AverageMeter(object):
    """Computes and stores the average and current value.

    The mode is chosen by ``length``:

    * ``length > 0``: windowed mode. Only the last ``length`` values are kept in
      ``self.history``, and ``avg`` is their mean (``np.mean``).
    * ``length <= 0`` (default): cumulative mode. ``sum`` and ``count`` accumulate
      every value passed to ``update``, and ``avg = sum / count``.

    In both modes ``val`` is the most recent value. ``val`` and ``avg`` read 0.0
    until the first ``update`` after construction or ``reset``.
    """

    def __init__(self, length=0):
        self.length = length
        self.reset()

    def reset(self):
        """Forget all recorded values."""
        if self.length > 0:
            self.history = []
        else:
            self.count = 0
            self.sum = 0.0
        # Set in both modes so val/avg exist (and are not stale) before the next update.
        self.val = 0.0
        self.avg = 0.0

    def update(self, val):
        """Record ``val`` and refresh ``val`` and ``avg``."""
        if self.length > 0:
            self.history.append(val)
            if len(self.history) > self.length:
                # Drop the oldest value so the window holds at most ``length`` items.
                del self.history[0]
            self.val = self.history[-1]
            self.avg = np.mean(self.history)
        else:
            self.val = val
            self.sum += val
            self.count += 1
            self.avg = self.sum / self.count
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Parameters:
        output (Tensor)     -- scores indexed as (batch, classes); top-k is taken along dim 1
        target (Tensor)     -- ground-truth class index per sample (``batch`` elements)
        topk (tuple of int) -- the k values to evaluate

    Returns:
        list of 1-element tensors, one per k in ``topk``, each holding the percentage
        (0-100) of samples whose target is among their top-k predictions.
    """
    maxk = max(topk)
    batch_size = target.size(0)
    with torch.no_grad():
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # (maxk, batch): row i holds each sample's i-th best class
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            # ``correct`` can inherit the non-contiguous layout of the transposed
            # ``pred``, so ``reshape`` (not ``view``) is needed to flatten it.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
    return res
def load_state(path, model, optimizer=None):
    """Load model (and optionally optimizer) state from a checkpoint file.

    The checkpoint must be a dict with keys ``'state_dict'`` and ``'step'``, plus
    ``'optimizer'`` when ``optimizer`` is given. The model is loaded with
    ``strict=False``, and any model keys absent from the checkpoint are printed as
    a caution. Tensors are moved to the current CUDA device when CUDA is
    available and kept on the CPU otherwise.

    Parameters:
        path (str)  -- checkpoint file path
        model       -- module whose ``load_state_dict`` receives ``checkpoint['state_dict']``
        optimizer   -- optional optimizer restored from ``checkpoint['optimizer']``

    Returns:
        The checkpoint's ``'step'`` value, or ``None`` if ``path`` is not a file.
    """
    def map_func(storage, location):
        # Fall back to CPU so checkpoints can be loaded on hosts without CUDA.
        if torch.cuda.is_available():
            return storage.cuda()
        return storage

    if not os.path.isfile(path):
        print("=> no checkpoint found at '{}'".format(path))
        return None

    print("=> loading checkpoint '{}'".format(path))
    checkpoint = torch.load(path, map_location=map_func)
    state_dict = checkpoint['state_dict']
    model.load_state_dict(state_dict, strict=False)

    missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
    for k in missing_keys:
        print('caution: missing keys from checkpoint {}: {}'.format(path, k))

    last_iter = checkpoint['step']
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> also loaded optimizer from checkpoint '{}' (iter {})"
              .format(path, last_iter))
    return last_iter