# P-DFD / trainer /utils.py
# mrneuralnet's picture
# Initial commit
# 982865f
import os
import sys
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from collections import OrderedDict
import numpy as np
from sklearn.metrics import roc_auc_score, roc_curve
from scipy.optimize import brentq
from scipy.interpolate import interp1d
# Maps each model name to the path of the source file that defines it.
MODELS_PATH = {
    "Recce": "model/network/Recce.py",
}
def exp_recons_loss(recons, x):
    """Sum the L1 reconstruction error over the real samples of a batch.

    Args:
        recons: Iterable of reconstruction tensors. Each one is indexed along
            dim 0 with the same indices as the images and is bilinearly
            resized to the images' last two dimensions before comparison.
        x: Pair ``(images, labels)``. Samples whose ``1 - label`` is non-zero
            are treated as real (presumably labels are 0 = real, 1 = fake --
            confirm against the data pipeline).

    Returns:
        Scalar tensor on the labels' device holding, summed over ``recons``,
        the mean absolute difference between each resized reconstruction and
        the images, restricted to real samples. It is zero when the batch
        contains no real sample.
    """
    images, labels = x
    loss = torch.tensor(0., device=labels.device)
    real_index = torch.where(1 - labels)[0]
    if real_index.numel() == 0:
        # Nothing to reconstruct: every sample in the batch is fake.
        return loss

    # The real targets do not depend on the reconstruction, so select them
    # once instead of once per element of ``recons``.
    real_images = torch.index_select(images, dim=0, index=real_index)
    target_size = images.shape[-2:]
    for rec in recons:
        real_rec = torch.index_select(rec, dim=0, index=real_index)
        real_rec = F.interpolate(
            real_rec, size=target_size, mode='bilinear', align_corners=True
        )
        loss += torch.mean(torch.abs(real_rec - real_images))
    return loss
def center_print(content, around='*', repeat_around=10):
    """Print ``content`` framed on both sides by a repeated marker.

    Args:
        content: Object to display; it is rendered with its string form.
        around: Marker string repeated on each side of the content.
        repeat_around: Number of times ``around`` is repeated per side.
    """
    border = repeat_around * around
    # An f-string is used instead of ``'%s' % content`` because the
    # %-operator unpacks tuples: a 1-tuple would print its element and a
    # longer tuple would raise TypeError.
    print(f"{border} {content} {border}")
def reduce_tensor(t):
    """Return a copy of ``t`` all-reduced across processes and divided by the world size.

    The input tensor is left untouched; the reduction runs on a clone.
    ``dist.all_reduce`` is called with its default reduction op (a sum,
    per the torch.distributed documentation), so the result is the mean
    over all participating processes.
    """
    averaged = t.clone()
    dist.all_reduce(averaged)
    world_size = float(dist.get_world_size())
    averaged /= world_size
    return averaged
def tensor2image(tensor):
    """Convert a 3-D tensor into a NumPy image min-max scaled to [0, 1].

    Args:
        tensor: Tensor whose dimensions are reordered with ``[1, 2, 0]``
            (presumably channel-first ``(C, H, W)`` -> ``(H, W, C)``).

    Returns:
        NumPy array in which the smallest input value maps to 0 and the
        largest to 1. A constant image maps to all zeros.
    """
    image = tensor.permute([1, 2, 0]).cpu().detach().numpy()
    low = np.min(image)
    span = np.max(image) - low
    if span == 0:
        # A constant image would otherwise compute 0 / 0 and return NaNs;
        # dividing by 1 keeps the result dtype and yields zeros.
        span = 1
    return (image - low) / span
def state_dict(state_dict):
    """Return a copy of a state dictionary with ``module.`` wrappers removed.

    Every dotted path component that is exactly ``module`` and is followed by
    another component is dropped from each key (presumably the prefix added
    by wrappers such as ``nn.DataParallel`` -- confirm against how the
    checkpoints are saved). Names that merely contain the text, such as
    ``submodule.weight``, are left intact.

    Args:
        state_dict: Mapping from parameter names to values.

    Returns:
        OrderedDict with the cleaned keys, in the original iteration order.
    """
    weights = OrderedDict()
    for key, value in state_dict.items():
        parts = key.split(".")
        last = len(parts) - 1
        # Only a whole "module" component that has something after it
        # corresponds to a "module." prefix; a trailing "module" is kept.
        kept = [
            part for index, part in enumerate(parts)
            if part != "module" or index == last
        ]
        weights[".".join(kept)] = value
    return weights
class Logger(object):
    """Stream-like object that duplicates every write to stdout and a file.

    The log file is opened in append mode and flushed after each write.
    """

    def __init__(self, filename):
        """Capture the current ``sys.stdout`` and open ``filename`` for appending."""
        self.terminal = sys.stdout
        self.log = open(filename, "a")

    def write(self, message):
        """Write ``message`` to the terminal and to the log file."""
        self.terminal.write(message)
        self.log.write(message)
        self.log.flush()

    def flush(self):
        """Flush the terminal and, while it is still open, the log file."""
        # Previously a no-op, so flush requests never reached the terminal.
        self.terminal.flush()
        if not self.log.closed:
            self.log.flush()

    def close(self):
        """Close the log file; calling this more than once is harmless."""
        if not self.log.closed:
            self.log.close()
class Timer(object):
    """Stopwatch that reports elapsed time as a short human-readable string."""

    def __init__(self):
        # Reference instant, taken with time.time() at construction.
        self.o = time.time()

    def measure(self, p=1):
        """Return the elapsed time divided by ``p`` as seconds, minutes or hours.

        The quotient is truncated to whole seconds, then rendered as
        ``'<n>s'`` below one minute, ``'<n>m'`` (rounded minutes) below one
        hour, and ``'<x.y>h'`` otherwise.
        """
        elapsed = int((time.time() - self.o) / p)
        if elapsed < 60:
            return f'{elapsed}s'
        if elapsed < 3600:
            return f'{round(elapsed / 60)}m'
        return f'{elapsed / 3600:.1f}h'
class MLLoss(nn.Module):
    """Hinge loss contrasting real/real pair scores with real/fake pair scores.

    For every score matrix the loss adds
    ``max(0, 1 - mean(real-real scores) + mean(real-fake scores))``.
    """

    def __init__(self):
        super(MLLoss, self).__init__()

    def forward(self, input, target, eps=1e-6):
        """Compute the summed hinge loss over all score matrices.

        Args:
            input: Iterable of tensors that are multiplied elementwise with
                ``(B, B)`` pair masks and summed over dims 0 and 1
                (presumably pairwise similarity matrices -- confirm with the
                model that produces them).
            target: Labels with 0 for real and 1 for fake; expected to be a
                1-D tensor of length ``B``.
            eps: Added to the pair counts to avoid division by zero when a
                mask is empty.

        Returns:
            Scalar tensor on the target's device.
        """
        # 0 - real; 1 - fake.
        loss = torch.tensor(0., device=target.device)
        batch_size = target.shape[0]

        # Broadcasting a column against a row yields the (B, B) pair grid.
        row_labels = target.unsqueeze(-1)
        col_labels = target.unsqueeze(0)
        # Pairs with different labels (one real, one fake).
        diff_mat = torch.logical_xor(row_labels, col_labels).float()
        # Real/real pairs, excluding each sample paired with itself.
        any_fake = torch.logical_or(row_labels, col_labels)
        eye = torch.eye(batch_size, dtype=torch.bool, device=target.device)
        sim_mat = 1. - torch.logical_or(any_fake, eye).float()

        # Pair counts are the same for every score matrix.
        diff_count = torch.sum(diff_mat, dim=[0, 1]) + eps
        sim_count = torch.sum(sim_mat, dim=[0, 1]) + eps

        for scores in input:
            diff = torch.sum(scores * diff_mat, dim=[0, 1]) / diff_count
            sim = torch.sum(scores * sim_mat, dim=[0, 1]) / sim_count
            partial_loss = 1. - sim + diff
            # torch.clamp keeps the hinge on-device; the builtin max() would
            # convert a tensor comparison to a Python bool.
            loss += torch.clamp(partial_loss, min=0.)
        return loss
class AccMeter(object):
    """Accumulates classification accuracy over successive batches."""

    def __init__(self):
        self.nums = 0  # number of samples seen (sum of target.shape[0])
        self.acc = 0   # number of correct predictions seen

    def reset(self):
        """Discard everything accumulated so far."""
        self.nums = 0
        self.acc = 0

    def update(self, pred, target, use_bce=False):
        """Add one batch of predictions to the running totals.

        Args:
            pred: With ``use_bce`` the values are thresholded at 0.5;
                otherwise the class is taken as the argmax over dim 1.
            target: Ground-truth labels; its first dimension is the number
                of samples added to the count.
            use_bce: Selects thresholding instead of argmax.
        """
        if use_bce:
            pred = (pred >= 0.5).int()
        else:
            pred = pred.argmax(1)
        self.nums += target.shape[0]
        # Flatten both sides so that e.g. (B, 1) predictions against (B,)
        # labels are compared element by element instead of being broadcast
        # into a (B, B) grid, which would inflate the count.
        self.acc += torch.sum(pred.reshape(-1) == target.reshape(-1))

    def mean_acc(self):
        """Return correct / seen; raises ZeroDivisionError if nothing was added."""
        return self.acc / self.nums
class AUCMeter(object):
    """Accumulates scores and labels to compute ROC AUC and the ROC curve."""

    def __init__(self):
        self.score = None  # 1-D array of accumulated scores
        self.true = None   # 1-D array of accumulated labels

    def reset(self):
        """Discard everything accumulated so far."""
        self.score = None
        self.true = None

    def update(self, score, true, use_bce=False):
        """Append one batch of scores and labels.

        Args:
            score: With ``use_bce`` the values are used as given; otherwise a
                softmax over the last dimension is applied and index 1 of
                dim 1 is kept as the score.
            true: Ground-truth labels; flattened before being stored.
            use_bce: Selects the raw-score path instead of softmax.
        """
        if use_bce:
            # Flatten like the labels so that (B, 1) or 0-d outputs
            # accumulate as a 1-D array aligned with ``self.true``.
            score = score.detach().flatten().cpu().numpy()
        else:
            score = torch.softmax(score.detach(), dim=-1)
            score = torch.select(score, 1, 1).cpu().numpy()
        true = true.flatten().cpu().numpy()
        self.score = score if self.score is None else np.concatenate([self.score, score])
        self.true = true if self.true is None else np.concatenate([self.true, true])

    def mean_auc(self):
        """Return the ROC AUC of everything accumulated so far."""
        return roc_auc_score(self.true, self.score)

    def curve(self, prefix):
        """Print the equal error rate and save the ROC curve under ``prefix``.

        The equal error rate is the false-positive rate ``x`` at which
        ``1 - x`` equals the linearly interpolated true-positive rate. The
        list ``[fpr, tpr, thresholds]`` is written with ``torch.save`` to
        ``<prefix>/roc_curve.pickle``.
        """
        fpr, tpr, thresholds = roc_curve(self.true, self.score, pos_label=1)
        # Build the interpolator once rather than on every brentq evaluation.
        tpr_at = interp1d(fpr, tpr)
        eer = brentq(lambda x: 1. - x - tpr_at(x), 0., 1.)
        thresh = interp1d(fpr, thresholds)(eer)
        print(f"# EER: {eer:.4f}(thresh: {thresh:.4f})")
        torch.save([fpr, tpr, thresholds], os.path.join(prefix, "roc_curve.pickle"))
class AverageMeter(object):
    """Tracks the latest value and the weighted running mean of a quantity."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running mean, weighted sum and count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean."""
        self.val = val
        self.sum += n * val
        self.count += n
        self.avg = self.sum / self.count