Spaces:
Runtime error
Runtime error
File size: 4,423 Bytes
826d651 679d7d7 826d651 679d7d7 826d651 69de41b 826d651 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import os
import logging
import numpy as np
import spaces
import torch
from torch.nn import init
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights in place.

    Parameters:
        net (network)     -- network to be initialized; every submodule is visited via ``net.apply``
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal (kaiming ignores it).

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.

    Conv*/Linear* modules get their weight initialized with ``init_type`` and their bias set
    to 0. BatchNorm2d modules get weight ~ N(1.0, init_gain) and bias 0. Modules whose
    weight/bias is ``None`` (e.g. ``BatchNorm2d(affine=False)``) are skipped instead of crashing.

    Raises:
        NotImplementedError -- if ``init_type`` is not a supported method. This is checked
                               before any module is modified.
    """
    weight_initializers = {
        'normal': lambda w: init.normal_(w, 0.0, init_gain),
        'xavier': lambda w: init.xavier_normal_(w, gain=init_gain),
        'kaiming': lambda w: init.kaiming_normal_(w, a=0, mode='fan_in'),
        'orthogonal': lambda w: init.orthogonal_(w, gain=init_gain),
    }
    if init_type not in weight_initializers:
        # Fail before touching the network: net.apply visits children first, so a lazy check
        # could leave some layers already re-initialized.
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
    init_weight = weight_initializers[init_type]

    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        if hasattr(m, 'weight') and ('Conv' in classname or 'Linear' in classname):
            if weight is not None:
                init_weight(weight.data)
            if bias is not None:
                init.constant_(bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            # With affine=False these parameters are None, so guard each one.
            if weight is not None:
                init.normal_(weight.data, 1.0, init_gain)
            if bias is not None:
                init.constant_(bias.data, 0.0)

    net.apply(init_func)  # apply the initialization function <init_func>
def create_logger(name, log_file, level=logging.INFO):
    """Return the logger ``name`` configured to log to ``log_file`` and to the console.

    Records are formatted as ``[<asctime>] <message>``. The logger's level is set to
    ``level`` on every call.

    Calling this again for the same ``name`` reuses handlers attached by earlier calls
    (a FileHandler for the same path, a plain console StreamHandler) instead of adding
    new ones. Without this, every message would be emitted once per call.

    Args:
        name: logger name passed to ``logging.getLogger``.
        log_file: path of the file that receives the log records.
        level: logging level applied to the logger.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)
    formatter = logging.Formatter('[%(asctime)s] %(message)s')

    # FileHandler stores the absolute path in ``baseFilename``.
    log_path = os.path.abspath(log_file)
    has_file = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == log_path
        for h in logger.handlers
    )
    # FileHandler subclasses StreamHandler, so match the console handler by exact type.
    has_console = any(type(h) is logging.StreamHandler for h in logger.handlers)

    if not has_file:
        fh = logging.FileHandler(log_file)
        fh.setFormatter(formatter)
        logger.addHandler(fh)
    if not has_console:
        sh = logging.StreamHandler()
        sh.setFormatter(formatter)
        logger.addHandler(sh)
    return logger
class AverageMeter(object):
    """Computes and stores the average and current value.

    The mode is selected by ``length``:

    * ``length > 0``  -- moving average over the last ``length`` values,
      kept in ``self.history``.
    * ``length <= 0`` -- running average over every value since the last
      ``reset()``, tracked via ``self.sum`` and ``self.count``.

    In both modes ``self.val`` is the most recent value and ``self.avg`` is the
    current average. Both read as ``0.0`` until the first ``update()``.
    """

    def __init__(self, length=0):
        self.length = length
        self.reset()

    def reset(self):
        """Discard all recorded values."""
        if self.length > 0:
            self.history = []
        else:
            self.count = 0
            self.sum = 0.0
        # Reset in both modes so val/avg exist before the first update and
        # do not keep stale values after a reset.
        self.val = 0.0
        self.avg = 0.0

    def update(self, val):
        """Record ``val`` and refresh ``self.val`` and ``self.avg``."""
        if self.length > 0:
            self.history.append(val)
            if len(self.history) > self.length:
                del self.history[0]  # drop the oldest value to keep the window size
            self.val = self.history[-1]
            self.avg = np.mean(self.history)
        else:
            self.val = val
            self.sum += val
            self.count += 1
            self.avg = self.sum / self.count
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: score tensor with samples along dim 0 and classes along dim 1
            (``topk`` is taken over dim 1).
        target: tensor of ground-truth class indices, one per sample;
            ``target.size(0)`` is used as the batch size.
        topk: iterable of ``k`` values to evaluate.

    Returns:
        A list with one 1-element tensor per ``k`` (in ``topk`` order), holding the
        percentage of samples whose target is among the top-``k`` predictions.
    """
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # row i holds every sample's (i+1)-th highest-scoring class
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # `correct` derives from a transposed (non-contiguous) tensor, so `.view(-1)`
        # can raise for k > 1; `.reshape(-1)` copies only when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
def load_state(path, model, optimizer=None, map_location='cpu'):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    The checkpoint is read as a dict with keys ``'state_dict'`` (model weights) and
    ``'step'`` (returned), plus ``'optimizer'`` when ``optimizer`` is given. Model
    weights are loaded with ``strict=False``. Keys the model has but the checkpoint
    lacks, and checkpoint keys the model does not have, are reported via ``print``.

    Args:
        path: checkpoint file path.
        model: module whose ``load_state_dict`` receives ``checkpoint['state_dict']``.
        optimizer: optional optimizer restored from ``checkpoint['optimizer']``.
        map_location: forwarded to ``torch.load``. The default ``'cpu'`` lets
            GPU-saved checkpoints load on CPU-only hosts; ``load_state_dict`` then
            copies the values onto the model's/optimizer's own devices. Pass ``None``
            to restore tensors to the devices they were saved from.

    Returns:
        ``checkpoint['step']`` when ``path`` is a file, otherwise ``None``.

    Note: ``torch.load`` unpickles the file -- only load checkpoints from trusted sources.
    """
    if not os.path.isfile(path):
        print("=> no checkpoint found at '{}'".format(path))
        return None

    print("=> loading checkpoint '{}'".format(path))
    checkpoint = torch.load(path, map_location=map_location)
    state_dict = checkpoint['state_dict']
    model.load_state_dict(state_dict, strict=False)

    # strict=False silently skips mismatches, so surface them explicitly.
    ckpt_keys = set(state_dict.keys())
    own_keys = set(model.state_dict().keys())
    for k in own_keys - ckpt_keys:
        print('caution: missing keys from checkpoint {}: {}'.format(path, k))
    for k in ckpt_keys - own_keys:
        print('caution: unused keys in checkpoint {}: {}'.format(path, k))

    last_iter = checkpoint['step']
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> also loaded optimizer from checkpoint '{}' (iter {})"
              .format(path, last_iter))
    return last_iter
|