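# Spaces: Sleeping
# (This appears to be a status banner from the hosting page that was captured
# together with the file; it is not part of the module.)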
import os | |
import sys | |
import time | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.distributed as dist | |
from collections import OrderedDict | |
import numpy as np | |
from sklearn.metrics import roc_auc_score, roc_curve | |
from scipy.optimize import brentq | |
from scipy.interpolate import interp1d | |
# Maps each supported model name to the source file that defines it.
MODELS_PATH = {
    "Recce": "model/network/Recce.py",
}
def exp_recons_loss(recons, x):
    """Sum the mean absolute reconstruction error over the selected samples.

    Args:
        recons: Iterable of reconstruction tensors. Each one is indexed along
            dim 0 and bilinearly resized to the spatial size of the images.
        x: Pair ``(images, labels)``. Samples where ``1 - labels`` is non-zero
            are selected (presumably label 0 marks a real sample -- confirm
            against the caller).

    Returns:
        A scalar tensor on the labels' device: one mean-absolute-error term
        per reconstruction, added together. Zero when no sample is selected.
    """
    images, labels = x
    total = torch.tensor(0., device=labels.device)
    keep = torch.where(1 - labels)[0]

    # Nothing selected: every reconstruction would contribute zero.
    if keep.numel() == 0:
        return total

    # The reference images do not depend on the reconstruction, select once.
    reference = torch.index_select(images, dim=0, index=keep)
    for recon in recons:
        picked = torch.index_select(recon, dim=0, index=keep)
        picked = F.interpolate(picked, size=images.shape[-2:], mode='bilinear', align_corners=True)
        total += torch.mean(torch.abs(picked - reference))
    return total
def center_print(content, around='*', repeat_around=10):
    """Print ``content`` with a repeated marker string on both sides.

    Args:
        content: Value rendered with ``%s`` between the two borders.
        around: String repeated to build each border.
        repeat_around: Number of repetitions of ``around`` per border.
    """
    border = repeat_around * around
    print(border + ' %s ' % content + border)
def reduce_tensor(t):
    """Return a copy of ``t`` all-reduced across processes and divided by the world size.

    The input tensor is left untouched; the reduction runs on a clone.
    Requires an initialised ``torch.distributed`` process group.
    """
    averaged = t.clone()
    dist.all_reduce(averaged)
    world_size = float(dist.get_world_size())
    averaged.div_(world_size)
    return averaged
def tensor2image(tensor):
    """Convert a tensor to a min-max normalised NumPy image.

    The first axis is moved to the end (``permute([1, 2, 0])``), so the input
    must be 3-dimensional (presumably channel-first -- confirm with callers).

    Args:
        tensor: 3-D tensor; it is detached and copied to the CPU.

    Returns:
        NumPy array rescaled so its minimum maps to 0 and its maximum to 1.
        A constant-valued input maps to all zeros instead of NaN.
    """
    image = tensor.permute([1, 2, 0]).detach().cpu().numpy()
    low = np.min(image)
    span = np.max(image) - low
    # A flat image has zero range; dividing by 1 keeps the result finite
    # (all zeros) and preserves the dtype the division would normally produce.
    return (image - low) / (span if span != 0 else 1)
def state_dict(state_dict):
    """Remove 'module' path components from the keys of a state dictionary.

    Only whole dot-separated components equal to ``module`` that are followed
    by further components are dropped, e.g. ``module.conv.weight`` becomes
    ``conv.weight`` and ``net.module.fc.bias`` becomes ``net.fc.bias``.
    Names that merely contain the text, such as ``encoder.submodule.bias``,
    are left intact (a plain substring replace would corrupt them).

    Args:
        state_dict: Mapping from parameter names to values.

    Returns:
        A new ``OrderedDict`` with rewritten keys in the original order.
    """
    weights = OrderedDict()
    for key, value in state_dict.items():
        *parents, leaf = key.split(".")
        kept = [part for part in parents if part != "module"]
        weights[".".join(kept + [leaf])] = value
    return weights
class Logger(object):
    """Tee-like stream that writes every message to the terminal and a log file.

    The terminal is whatever ``sys.stdout`` was when the logger was created;
    the log file is opened in append mode. Instances can be used as a context
    manager so that the file handle is released deterministically.
    """

    def __init__(self, filename):
        """Capture the current ``sys.stdout`` and open ``filename`` for appending."""
        self.terminal = sys.stdout
        self.log = open(filename, "a")

    def write(self, message):
        """Write ``message`` to both streams; the log file is flushed immediately."""
        self.terminal.write(message)
        self.log.write(message)
        self.log.flush()

    def flush(self):
        """Flush both streams so ``print(..., flush=True)`` reaches the terminal."""
        self.terminal.flush()
        if not self.log.closed:
            self.log.flush()

    def close(self):
        """Close the log file; the terminal stream is left open. Safe to call twice."""
        if not self.log.closed:
            self.log.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
class Timer(object):
    """Stopwatch that starts when the instance is created."""

    def __init__(self):
        # Start time in seconds, as reported by time.time().
        self.o = time.time()

    def measure(self, p=1):
        """Return the elapsed time divided by ``p`` as a short string.

        The quotient is truncated to whole seconds, then rendered as
        ``'<n>s'`` below one minute, ``'<n>m'`` (rounded minutes) below one
        hour, and ``'<x.x>h'`` otherwise.
        """
        seconds = int((time.time() - self.o) / p)
        if seconds < 60:
            return '{}s'.format(seconds)
        if seconds < 3600:
            return '{}m'.format(round(seconds / 60))
        return '{:.1f}h'.format(seconds / 3600)
class MLLoss(nn.Module):
    """Pairwise margin loss computed from batch-by-batch matrices.

    For every matrix in ``input`` the loss is
    ``max(0, 1 - mean_over_sim_pairs + mean_over_diff_pairs)``, where "diff"
    pairs are positions (i, j) whose labels differ and "sim" pairs are
    off-diagonal positions whose labels are both 0.
    """

    def __init__(self):
        super(MLLoss, self).__init__()

    def forward(self, input, target, eps=1e-6):
        """Accumulate the hinge term over all matrices in ``input``.

        Args:
            input: Iterable of tensors broadcastable against a
                ``(batch, batch)`` mask (assumed to be pairwise matrices of
                that shape -- confirm with the caller).
            target: 1-D label tensor of length ``batch``.
            eps: Added to each normaliser to avoid division by zero when a
                mask is empty.

        Returns:
            Scalar tensor on ``target``'s device.
        """
        # 0 - real; 1 - fake.
        loss = torch.tensor(0., device=target.device)
        batch_size = target.shape[0]

        # Broadcasting yields the same (batch, batch) label grids as tiling:
        # rows[i, j] == target[i] and cols[i, j] == target[j].
        rows = target.unsqueeze(1)
        cols = target.unsqueeze(0)
        diff_mat = torch.logical_xor(rows, cols).float()

        eye = torch.eye(batch_size, device=target.device)
        # Exclude the diagonal and every pair that involves a non-zero label.
        or_mat = torch.logical_or(torch.logical_or(rows, cols), eye).float()
        sim_mat = 1. - or_mat

        # The normalisers depend only on the labels, so compute them once.
        diff_norm = torch.sum(diff_mat, dim=[0, 1]) + eps
        sim_norm = torch.sum(sim_mat, dim=[0, 1]) + eps

        for mat in input:
            diff = torch.sum(mat * diff_mat, dim=[0, 1]) / diff_norm
            sim = torch.sum(mat * sim_mat, dim=[0, 1]) / sim_norm
            # Hinge at zero on-device instead of Python's builtin max().
            loss = loss + torch.clamp(1. - sim + diff, min=0.)
        return loss
class AccMeter(object):
    """Running classification accuracy over successive batches."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything accumulated so far."""
        self.nums = 0  # number of samples seen
        self.acc = 0   # number of correct predictions (a tensor after update)

    def update(self, pred, target, use_bce=False):
        """Add one batch of predictions.

        Args:
            pred: With ``use_bce`` the values are thresholded at 0.5;
                otherwise the predicted class is the argmax over dim 1.
            target: Ground-truth labels; ``target.shape[0]`` samples are
                counted.
            use_bce: Selects between the two decoding modes above.
        """
        if use_bce:
            pred = (pred >= 0.5).int()
        else:
            pred = pred.argmax(1)
        self.nums += target.shape[0]
        # Flatten both sides: comparing e.g. (B, 1) predictions with (B,)
        # labels would otherwise broadcast to (B, B) and overcount matches.
        self.acc += torch.sum(pred.reshape(-1) == target.reshape(-1))

    def mean_acc(self):
        """Return correct / seen. Raises ZeroDivisionError before any update."""
        return self.acc / self.nums
class AUCMeter(object):
    """Collects scores and labels across batches for ROC-based metrics."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Drop all accumulated scores and labels."""
        self.score = None  # 1-D NumPy array of scores, or None when empty
        self.true = None   # 1-D NumPy array of labels, or None when empty

    def update(self, score, true, use_bce=False):
        """Append one batch.

        Args:
            score: With ``use_bce`` the values are used as-is; otherwise a
                softmax over the last dim is applied and column 1 is kept.
            true: Ground-truth labels; flattened before being stored.
            use_bce: Selects between the two decoding modes above.
        """
        if use_bce:
            # Flatten like the labels so (B, 1) and (B,) batches can be
            # concatenated with each other.
            score = score.detach().cpu().numpy().reshape(-1)
        else:
            score = torch.softmax(score.detach(), dim=-1)
            score = torch.select(score, 1, 1).cpu().numpy()
        true = true.flatten().cpu().numpy()
        self.score = score if self.score is None else np.concatenate([self.score, score])
        self.true = true if self.true is None else np.concatenate([self.true, true])

    def mean_auc(self):
        """Return the ROC AUC of everything accumulated so far."""
        return roc_auc_score(self.true, self.score)

    def curve(self, prefix):
        """Print the equal error rate and save the ROC curve under ``prefix``.

        The curve ``[fpr, tpr, thresholds]`` is written with ``torch.save`` to
        ``<prefix>/roc_curve.pickle``.
        """
        fpr, tpr, thresholds = roc_curve(self.true, self.score, pos_label=1)
        # Build the interpolator once instead of on every root-finder call.
        tpr_at = interp1d(fpr, tpr)
        # EER is the false-positive rate x at which 1 - x equals tpr(x).
        eer = brentq(lambda x: 1. - x - tpr_at(x), 0., 1.)
        thresh = interp1d(fpr, thresholds)(eer)
        print(f"# EER: {eer:.4f}(thresh: {thresh:.4f})")
        torch.save([fpr, tpr, thresholds], os.path.join(prefix, "roc_curve.pickle"))
class AverageMeter(object):
    """Tracks the latest value and the weighted running mean of a series."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Return every statistic to zero."""
        self.val = 0    # most recent value passed to update()
        self.avg = 0    # running mean: sum / count
        self.sum = 0    # weighted sum of all values
        self.count = 0  # total weight seen so far

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count