|
""" |
|
Various handy Python and PyTorch utils. |
|
|
|
Author: Paul-Edouard Sarlin (skydes) |
|
""" |
|
|
|
import os |
|
import random |
|
import time |
|
from collections.abc import Iterable |
|
from contextlib import contextmanager |
|
from typing import Optional |
|
|
|
import numpy as np |
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
class AverageMetric:
    """Running mean over scalar values that ignores NaN entries.

    Can be seeded with an initial collection of elements (list or array);
    NaNs are dropped both at construction and in `update`.
    """

    def __init__(self, elements=None):
        if elements is None:
            self._sum = 0
            self._num_examples = 0
        else:
            # Fix: convert to ndarray so boolean-mask indexing also works
            # when a plain Python list is passed.
            elements = np.asarray(elements)
            mask = ~np.isnan(elements)
            self._sum = sum(elements[mask])
            self._num_examples = len(elements[mask])

    def update(self, tensor):
        """Accumulate a 1D tensor of values; NaNs are discarded."""
        assert tensor.dim() == 1, tensor.shape
        tensor = tensor[~torch.isnan(tensor)]
        self._sum += tensor.sum().item()
        self._num_examples += len(tensor)

    def compute(self):
        """Return the mean of all non-NaN values seen, or NaN if none."""
        return np.nan if self._num_examples == 0 else self._sum / self._num_examples
|
|
|
|
|
|
|
class FAverageMetric:
    """Running mean ignoring NaNs, while also storing every raw element."""

    def __init__(self):
        self._sum = 0
        self._num_examples = 0
        self._elements = []

    def update(self, tensor):
        """Accumulate a 1D tensor; raw values (including NaNs) are stored."""
        self._elements += tensor.cpu().numpy().tolist()
        assert tensor.dim() == 1, tensor.shape
        valid = tensor[~torch.isnan(tensor)]
        self._sum += valid.sum().item()
        self._num_examples += len(valid)

    def compute(self):
        """Return the mean of the non-NaN values seen, or NaN if none."""
        if self._num_examples == 0:
            return np.nan
        return self._sum / self._num_examples
|
|
|
|
|
class MedianMetric:
    """Median over streamed 1D tensors; NaNs are counted as +inf (failures)."""

    def __init__(self, elements=None):
        if elements is None:
            elements = []

        self._elements = elements

    def update(self, tensor):
        """Accumulate a 1D tensor of values."""
        assert tensor.dim() == 1, tensor.shape
        self._elements += tensor.cpu().numpy().tolist()

    def compute(self):
        """Return the median of all values seen, or NaN if empty.

        NaN elements are treated as +inf so they count as failures instead
        of being ignored.
        """
        if len(self._elements) == 0:
            return np.nan

        # Fix: operate on a copy. The original rebound self._elements to an
        # ndarray, which made a subsequent update()'s `+=` perform broadcast
        # addition instead of list concatenation, silently corrupting state.
        elements = np.array(self._elements, dtype=float)
        elements[np.isnan(elements)] = np.inf
        return np.nanmedian(elements)
|
|
|
|
|
class PRMetric:
    """Accumulates labels and predictions for precision-recall evaluation."""

    def __init__(self):
        self.labels = []
        self.predictions = []

    @torch.no_grad()
    def update(self, labels, predictions, mask=None):
        """Append a batch of labels/predictions, optionally filtered by mask."""
        assert labels.shape == predictions.shape
        if mask is not None:
            labels = labels[mask]
            predictions = predictions[mask]
        self.labels += labels.cpu().numpy().tolist()
        self.predictions += predictions.cpu().numpy().tolist()

    @torch.no_grad()
    def compute(self):
        """Return all accumulated labels and predictions as numpy arrays."""
        return np.array(self.labels), np.array(self.predictions)

    def reset(self):
        """Discard all accumulated labels and predictions."""
        self.labels = []
        self.predictions = []
|
|
|
|
|
class QuantileMetric:
    """Tracks the q-quantile of streamed 1D tensors, ignoring NaNs."""

    def __init__(self, q=0.05):
        self._elements = []
        self.q = q  # quantile level in [0, 1]

    def update(self, tensor):
        """Accumulate a 1D tensor of values."""
        assert tensor.dim() == 1
        self._elements.extend(tensor.cpu().numpy().tolist())

    def compute(self):
        """Return the q-quantile of all values seen (NaNs ignored), or NaN."""
        if not self._elements:
            return np.nan
        return np.nanquantile(self._elements, self.q)
|
|
|
|
|
class RecallMetric:
    """Fraction of streamed values strictly below given threshold(s).

    NaN values are counted as +inf, i.e. as failures at every threshold.
    `ths` may be a single scalar threshold or an iterable of thresholds.
    """

    def __init__(self, ths, elements=None):
        if elements is None:
            elements = []

        self._elements = elements
        self.ths = ths

    def update(self, tensor):
        """Accumulate a 1D tensor of error values."""
        assert tensor.dim() == 1, tensor.shape
        self._elements += tensor.cpu().numpy().tolist()

    def compute(self):
        """Return the recall per threshold (list) or a single recall (scalar ths).

        Fix: operates on local data instead of rebinding self._elements to an
        ndarray (which corrupted later update() calls), and handles a scalar
        `ths` correctly (the original indexed a non-subscriptable scalar).
        """
        if isinstance(self.ths, Iterable):
            return [self.compute_(th) for th in self.ths]
        else:
            return self.compute_(self.ths)

    def compute_(self, th):
        """Recall at a single threshold th; NaN elements count as failures."""
        if len(self._elements) == 0:
            return np.nan

        elements = np.array(self._elements, dtype=float)
        elements[np.isnan(elements)] = np.inf
        s = (elements < th).sum()
        return s / len(elements)
|
|
|
|
|
def compute_recall(errors):
    """Compute the cumulative recall curve of a set of errors.

    Returns the errors sorted in ascending order and, for each one, the
    fraction of elements with an error at most that value.
    """
    num_elements = len(errors)
    sorted_errors = np.sort(np.asarray(errors))
    recall = np.arange(1, num_elements + 1) / num_elements
    return sorted_errors, recall
|
|
|
|
|
def compute_auc(errors, thresholds, min_error: Optional[float] = None):
    """Compute the area under the error-recall curve at each threshold.

    The AUC is normalized by the threshold so each value lies in [0, 1].
    If min_error is given, the curve is clamped: all errors below min_error
    are collapsed into a single step at min_error. Returns a list of AUCs
    (rounded to 4 decimals), one per threshold.
    """
    errors, recall = compute_recall(errors)

    if min_error is not None:
        # Index of the first error strictly greater than min_error.
        min_index = np.searchsorted(errors, min_error, side="right")
        # Recall already achieved at min_error.
        min_score = min_index / len(errors)
        # Start the curve with a flat segment [0, min_error] at min_score.
        recall = np.r_[min_score, min_score, recall[min_index:]]
        errors = np.r_[0, min_error, errors[min_index:]]
    else:
        # Anchor the curve at the origin.
        recall = np.r_[0, recall]
        errors = np.r_[0, errors]

    aucs = []
    for t in thresholds:
        # Truncate the curve at threshold t and extend the last recall value
        # horizontally up to t before integrating.
        last_index = np.searchsorted(errors, t, side="right")
        r = np.r_[recall[:last_index], recall[last_index - 1]]
        e = np.r_[errors[:last_index], t]
        # Trapezoidal integration, normalized by t so the AUC is in [0, 1].
        auc = np.trapz(r, x=e) / t
        aucs.append(np.round(auc, 4))
    return aucs
|
|
|
|
|
class AUCMetric:
    """Area under the error-recall curve over streamed 1D error tensors.

    NaN values are counted as +inf (failures). `thresholds` may be a scalar
    or a list; `min_error` is forwarded to compute_auc.
    """

    def __init__(self, thresholds, elements=None, min_error: Optional[float] = None):
        # Fix: default elements to an empty list. The original kept None,
        # making update() fail with `None += list` on the default path.
        self._elements = [] if elements is None else elements
        self.thresholds = thresholds
        self.min_error = min_error
        if not isinstance(thresholds, list):
            self.thresholds = [thresholds]

    def update(self, tensor):
        """Accumulate a 1D tensor of error values."""
        assert tensor.dim() == 1, tensor.shape
        self._elements += tensor.cpu().numpy().tolist()

    def compute(self):
        """Return the list of AUCs (one per threshold), or NaN if empty."""
        if len(self._elements) == 0:
            return np.nan

        # Operate on a copy so self._elements stays a list and later
        # update() calls keep appending instead of broadcasting.
        elements = np.array(self._elements, dtype=float)
        elements[np.isnan(elements)] = np.inf
        return compute_auc(elements, self.thresholds, self.min_error)
|
|
|
|
|
class Timer(object):
    """A simpler timer context object.
    Usage:
    ```
    > with Timer('mytimer'):
    > # some computations
    [mytimer] Elapsed: X
    ```
    """

    def __init__(self, name=None):
        # If a name is given, the elapsed time is printed on exit.
        self.name = name

    def __enter__(self):
        self.tstart = time.time()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Elapsed wall-clock time is kept on the instance for later use.
        self.duration = time.time() - self.tstart
        if self.name is not None:
            print(f"[{self.name}] Elapsed: {self.duration}")
|
|
|
|
|
def get_class(mod_path, BaseClass):
    """Get the class object which inherits from BaseClass and is defined in
    the module named mod_name, child of base_path.
    """
    import inspect

    mod = __import__(mod_path, fromlist=[""])
    members = inspect.getmembers(mod, inspect.isclass)

    # Keep only classes actually defined in this module (not merely
    # imported into it) that derive from BaseClass.
    candidates = [
        (name, cls)
        for name, cls in members
        if cls.__module__ == mod_path and issubclass(cls, BaseClass)
    ]

    # Exactly one matching class must exist.
    assert len(candidates) == 1, candidates
    return candidates[0][1]
|
|
|
|
|
def set_num_threads(nt):
    """Force numpy and other libraries to use a limited number of threads."""
    try:
        import mkl
    except ImportError:
        pass
    else:
        mkl.set_num_threads(nt)
    # NOTE(review): torch is pinned to a single thread regardless of nt —
    # presumably intentional, but worth confirming.
    torch.set_num_threads(1)
    os.environ["IPC_ENABLE"] = "1"
    # Env vars honored by OpenBLAS / numexpr / OpenMP / MKL workers.
    thread_vars = (
        "OPENBLAS_NUM_THREADS",
        "NUMEXPR_NUM_THREADS",
        "OMP_NUM_THREADS",
        "MKL_NUM_THREADS",
    )
    for var in thread_vars:
        os.environ[var] = str(nt)
|
|
|
|
|
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (incl. all CUDA devices) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
|
|
|
|
|
def get_random_state(with_cuda):
    """Snapshot the RNG states of torch, numpy, random and (optionally) CUDA.

    Returns a 4-tuple (torch, numpy, python, cuda); the CUDA entry is None
    when CUDA is unavailable or with_cuda is False.
    """
    cuda_state = None
    if with_cuda and torch.cuda.is_available():
        cuda_state = torch.cuda.get_rng_state_all()
    return (
        torch.get_rng_state(),
        np.random.get_state(),
        random.getstate(),
        cuda_state,
    )
|
|
|
|
|
def set_random_state(state):
    """Restore a (torch, numpy, python, cuda) snapshot from get_random_state."""
    pth_state, np_state, py_state, cuda_state = state
    torch.set_rng_state(pth_state)
    np.random.set_state(np_state)
    random.setstate(py_state)
    # Only restore CUDA states if the snapshot matches the current number
    # of visible devices; otherwise skip silently.
    restore_cuda = (
        cuda_state is not None
        and torch.cuda.is_available()
        and len(cuda_state) == torch.cuda.device_count()
    )
    if restore_cuda:
        torch.cuda.set_rng_state_all(cuda_state)
|
|
|
|
|
@contextmanager
def fork_rng(seed=None, with_cuda=True):
    """Context manager that restores all RNG states on exit.

    Optionally reseeds every RNG on entry via set_seed.
    """
    saved = get_random_state(with_cuda)
    if seed is not None:
        set_seed(seed)
    try:
        yield
    finally:
        set_random_state(saved)
|
|
|
|
|
def get_device() -> str:
    """Return the best available device: 'cuda', then 'mps', else 'cpu'."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"
|
|