import os
import sys
import time
from collections import OrderedDict

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from scipy.interpolate import interp1d
from scipy.optimize import brentq
from sklearn.metrics import roc_auc_score, roc_curve

# Tracking the path to the definition of the model.
MODELS_PATH = {
    "Recce": "model/network/Recce.py"
}


def exp_recons_loss(recons, x):
    """Reconstruction loss computed on real samples only.

    Args:
        recons: iterable of reconstruction tensors. Each one is row-selected
            to the real samples and bilinearly resized to the spatial size
            (last two dims) of the input images.
        x: ``(images, labels)`` pair. Samples whose label is 0 are treated as
            real and are the only ones that contribute to the loss.

    Returns:
        Scalar tensor on the labels' device: the sum over ``recons`` of the
        mean absolute error between the resized reconstruction and the real
        input images. Zero when the batch contains no real samples.
    """
    x, y = x
    loss = torch.tensor(0., device=y.device)
    # `y == 0` selects the same rows as the original `1 - y` for 0/1 labels,
    # and unlike subtraction it also works for boolean label tensors.
    real_index = torch.where(y == 0)[0]
    if real_index.numel() == 0:
        return loss
    # The real-sample slice of the input does not depend on the
    # reconstruction, so select it once instead of once per reconstruction.
    real_x = torch.index_select(x, dim=0, index=real_index)
    for r in recons:
        real_rec = torch.index_select(r, dim=0, index=real_index)
        real_rec = F.interpolate(real_rec, size=x.shape[-2:], mode='bilinear', align_corners=True)
        loss += torch.mean(torch.abs(real_rec - real_x))
    return loss


def center_print(content, around='*', repeat_around=10):
    """Print ``content`` framed by ``repeat_around`` copies of ``around`` on each side."""
    border = repeat_around * around
    # f-string instead of `%` so tuple content is printed rather than raising.
    print(f'{border} {content} {border}')


def reduce_tensor(t):
    """Average ``t`` across all processes of the default distributed group.

    Requires ``torch.distributed`` to be initialised. Returns a new tensor;
    ``t`` itself is not modified.
    """
    rt = t.clone()
    dist.all_reduce(rt)  # default reduce op is SUM
    rt /= float(dist.get_world_size())
    return rt


def tensor2image(tensor):
    """Convert a 3-D tensor to a numpy image min-max scaled to [0, 1].

    The first dimension is moved to the end (e.g. CHW -> HWC). A constant
    image (zero value range) is returned as all zeros instead of the NaNs a
    0/0 division would produce.
    """
    image = tensor.permute([1, 2, 0]).cpu().detach().numpy()
    lo = np.min(image)
    span = np.max(image) - lo
    if span == 0:
        # Avoid 0/0 -> NaN; (image - lo) is all zeros in this case anyway.
        span = 1
    return (image - lo) / span


def state_dict(state_dict):
    """
    Remove the 'module' wrapper keyword from the keys of a state dictionary.

    Every dotted key segment that is exactly ``module`` and is followed by
    another segment is dropped, e.g. ``"module.conv.weight"`` ->
    ``"conv.weight"``. Keys that merely contain the text (such as
    ``"submodule.weight"``) are left intact.

    Returns:
        A new ``OrderedDict`` with the cleaned keys and the original values.
    """
    weights = OrderedDict()
    for k, v in state_dict.items():
        parts = k.split(".")
        last = len(parts) - 1
        # A plain `k.replace("module.", "")` would also turn
        # "submodule.weight" into "subweight"; only whole segments are removed.
        kept = [p for i, p in enumerate(parts) if not (p == "module" and i < last)]
        weights[".".join(kept)] = v
    return weights


class Logger(object):
    """File-like object that tees writes to ``sys.stdout`` and a log file.

    The log file is opened in append mode and flushed after every write.
    """

    def __init__(self, filename):
        # Capture whatever sys.stdout is at construction time.
        self.terminal = sys.stdout
        self.log = open(filename, "a")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)
        self.log.flush()

    def flush(self):
        # Forward explicit flushes (e.g. print(..., flush=True)) to both
        # streams; a no-op here would leave terminal output buffered.
        self.terminal.flush()
        self.log.flush()


class Timer(object):
    """The class for timer: measures wall-clock time since construction."""

    def __init__(self):
        self.o = time.time()

    def measure(self, p=1):
        """Return elapsed seconds divided by ``p`` as a short string.

        The value is truncated to whole seconds, then formatted as hours
        with one decimal (``'1.5h'``) when >= 3600, rounded minutes
        (``'12m'``) when >= 60, and seconds (``'42s'``) otherwise.
        """
        x = int((time.time() - self.o) / p)
        if x >= 3600:
            return '{:.1f}h'.format(x / 3600)
        if x >= 60:
            return '{}m'.format(round(x / 60))
        return '{}s'.format(x)


class MLLoss(nn.Module):
    """Pairwise hinge loss over per-pair score matrices.

    For each score matrix in ``input`` the term ``max(0, 1 - sim + diff)`` is
    added, where ``sim`` is the mean score over pairs of distinct real
    samples and ``diff`` is the mean score over pairs with differing labels.
    """

    def __init__(self):
        super(MLLoss, self).__init__()

    def forward(self, input, target, eps=1e-6):
        """Compute the loss.

        Args:
            input: iterable of score tensors; presumably each is a
                (batch, batch) pairwise matrix — TODO confirm with the caller.
            target: 1-D label tensor. 0 - real; 1 - fake.
            eps: added to pair counts to avoid division by zero.

        Returns:
            Scalar tensor: the sum of the per-matrix hinge terms.
        """
        loss = torch.tensor(0., device=target.device)
        batch_size = target.shape[0]
        # mat_1[i, j] = target[i], mat_2[i, j] = target[j].
        mat_1 = torch.hstack([target.unsqueeze(-1)] * batch_size)
        mat_2 = torch.vstack([target] * batch_size)
        # Pairs whose labels differ.
        diff_mat = torch.logical_xor(mat_1, mat_2).float()
        # Pairs where both are real and i != j: everything NOT (either fake
        # or on the diagonal).
        or_mat = torch.logical_or(mat_1, mat_2)
        eye = torch.eye(batch_size, device=target.device)
        or_mat = torch.logical_or(or_mat, eye).float()
        sim_mat = 1. - or_mat
        for score_mat in input:
            diff = torch.sum(score_mat * diff_mat, dim=[0, 1]) / (torch.sum(diff_mat, dim=[0, 1]) + eps)
            sim = torch.sum(score_mat * sim_mat, dim=[0, 1]) / (torch.sum(sim_mat, dim=[0, 1]) + eps)
            partial_loss = 1. - sim + diff
            # Hinge at zero. torch.clamp matches the builtin max() for 0-dim
            # tensors (value and gradient) and also works element-wise.
            loss += torch.clamp(partial_loss, min=0.)
        return loss


class AccMeter(object):
    """Accumulates classification accuracy over batches."""

    def __init__(self):
        self.nums = 0
        self.acc = 0

    def reset(self):
        self.nums = 0
        self.acc = 0

    def update(self, pred, target, use_bce=False):
        """Add a batch.

        With ``use_bce`` the predictions are probabilities thresholded at
        0.5; otherwise they are class logits reduced by ``argmax`` over dim 1.
        """
        if use_bce:
            pred = (pred >= 0.5).int()
        else:
            pred = pred.argmax(1)
        self.nums += target.shape[0]
        # Flatten both sides so (B, 1) vs (B,) shapes compare element-wise
        # instead of broadcasting into a (B, B) matrix.
        self.acc += torch.sum(pred.flatten() == target.flatten())

    def mean_acc(self):
        return self.acc / self.nums


class AUCMeter(object):
    """Accumulates scores and labels to compute ROC AUC / EER."""

    def __init__(self):
        self.score = None
        self.true = None

    def reset(self):
        self.score = None
        self.true = None

    def update(self, score, true, use_bce=False):
        """Add a batch.

        With ``use_bce`` ``score`` is used as-is (flattened); otherwise a
        softmax over the last dim is applied and column 1 is kept as the
        positive-class score.
        """
        if use_bce:
            # Flatten like `true` below so (B, 1) outputs yield a 1-D array.
            score = score.detach().flatten().cpu().numpy()
        else:
            score = torch.softmax(score.detach(), dim=-1)
            score = torch.select(score, 1, 1).cpu().numpy()
        true = true.flatten().cpu().numpy()
        self.score = score if self.score is None else np.concatenate([self.score, score])
        self.true = true if self.true is None else np.concatenate([self.true, true])

    def mean_auc(self):
        return roc_auc_score(self.true, self.score)

    def curve(self, prefix):
        """Print the equal error rate and save the ROC curve under ``prefix``."""
        fpr, tpr, thresholds = roc_curve(self.true, self.score, pos_label=1)
        # EER: the FPR at which FPR == 1 - TPR.
        eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
        thresh = interp1d(fpr, thresholds)(eer)
        print(f"# EER: {eer:.4f}(thresh: {thresh:.4f})")
        torch.save([fpr, tpr, thresholds], os.path.join(prefix, "roc_curve.pickle"))


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` samples and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count