GeoCalib / siclib /utils /tools.py
veichta's picture
Upload folder using huggingface_hub
205a7af verified
raw
history blame
8.51 kB
"""
Various handy Python and PyTorch utils.
Author: Paul-Edouard Sarlin (skydes)
"""
import os
import random
import time
from collections.abc import Iterable
from contextlib import contextmanager
from typing import Optional
import numpy as np
import torch
# flake8: noqa
# mypy: ignore-errors
class AverageMetric:
    """Running mean of scalar values that skips NaN entries.

    Only the running sum and the number of non-NaN values are stored.
    """

    def __init__(self, elements=None):
        """Create the metric, optionally seeded with initial values.

        Args:
            elements: optional array-like of numbers (list, tuple or NumPy
                array). NaN entries are dropped before accumulating.
        """
        self._sum = 0
        self._num_examples = 0
        if elements is not None:
            # Convert first: a plain Python list cannot be indexed with a
            # boolean mask, which made list inputs raise a TypeError.
            values = np.asarray(elements, dtype=float)
            valid = values[~np.isnan(values)]
            self._sum = valid.sum()
            self._num_examples = len(valid)

    def update(self, tensor):
        """Accumulate the non-NaN entries of a 1-D tensor."""
        assert tensor.dim() == 1, tensor.shape
        tensor = tensor[~torch.isnan(tensor)]
        self._sum += tensor.sum().item()
        self._num_examples += len(tensor)

    def compute(self):
        """Return the mean of all accumulated values, or NaN if there are none."""
        if self._num_examples == 0:
            return np.nan
        return self._sum / self._num_examples
# same as AverageMetric, but tracks all elements
class FAverageMetric:
    """Running NaN-skipping mean that also keeps every value it was given.

    `_elements` stores all values passed to `update` (NaN included), while the
    mean is computed from the non-NaN values only.
    """

    def __init__(self):
        self._sum = 0
        self._num_examples = 0
        self._elements = []

    def update(self, tensor):
        """Record a 1-D tensor and accumulate its non-NaN entries."""
        # Validate before touching any state so that a rejected tensor does
        # not leave partial (possibly nested) data in `_elements`.
        assert tensor.dim() == 1, tensor.shape
        self._elements += tensor.cpu().numpy().tolist()
        tensor = tensor[~torch.isnan(tensor)]
        self._sum += tensor.sum().item()
        self._num_examples += len(tensor)

    def compute(self):
        """Return the mean of the non-NaN values, or NaN if there are none."""
        if self._num_examples == 0:
            return np.nan
        return self._sum / self._num_examples
class MedianMetric:
    """Median of all recorded values, with NaN entries counted as +inf."""

    def __init__(self, elements=None):
        """Create the metric, optionally seeded with initial values.

        Args:
            elements: optional iterable of numbers. It is copied into a new
                list so that later updates never modify the caller's object.
        """
        self._elements = [] if elements is None else list(elements)

    def update(self, tensor):
        """Append the entries of a 1-D tensor to the recorded values."""
        assert tensor.dim() == 1, tensor.shape
        self._elements += tensor.cpu().numpy().tolist()

    def compute(self):
        """Return the median of the recorded values, or NaN if there are none."""
        if len(self._elements) == 0:
            return np.nan
        # Work on a local array: replacing `self._elements` with an ndarray
        # would turn the `+=` in `update` into a broadcast addition.
        values = np.array(self._elements, dtype=float)
        # set nan to inf to avoid error
        values[np.isnan(values)] = np.inf
        return np.nanmedian(values)
class PRMetric:
    """Accumulates labels and predictions and returns them as NumPy arrays."""

    def __init__(self):
        self.labels = []
        self.predictions = []

    @torch.no_grad()
    def update(self, labels, predictions, mask=None):
        """Store a batch of labels and predictions of identical shape.

        Args:
            labels: tensor of labels.
            predictions: tensor of predictions, same shape as `labels`.
            mask: optional index applied to both tensors before storing.
        """
        assert labels.shape == predictions.shape
        kept_labels = labels if mask is None else labels[mask]
        self.labels.extend(kept_labels.cpu().numpy().tolist())
        kept_predictions = predictions if mask is None else predictions[mask]
        self.predictions.extend(kept_predictions.cpu().numpy().tolist())

    @torch.no_grad()
    def compute(self):
        """Return the stored `(labels, predictions)` as a pair of arrays."""
        label_array = np.array(self.labels)
        prediction_array = np.array(self.predictions)
        return label_array, prediction_array

    def reset(self):
        """Discard everything stored so far."""
        self.labels, self.predictions = [], []
class QuantileMetric:
    """Quantile `q` of all recorded values, ignoring NaN entries."""

    def __init__(self, q=0.05):
        self.q = q
        self._elements = []

    def update(self, tensor):
        """Append the entries of a 1-D tensor to the recorded values."""
        assert tensor.dim() == 1
        self._elements.extend(tensor.cpu().numpy().tolist())

    def compute(self):
        """Return the NaN-ignoring quantile, or NaN if nothing was recorded."""
        if len(self._elements) == 0:
            return np.nan
        return np.nanquantile(self._elements, self.q)
class RecallMetric:
    """Fraction of recorded values strictly below one or more thresholds.

    NaN entries are counted as +inf, i.e. they never fall below a threshold.
    """

    def __init__(self, ths, elements=None):
        """Create the metric.

        Args:
            ths: a single threshold or an iterable of thresholds.
            elements: optional iterable of initial values. It is copied into
                a new list so that later updates never modify the caller's
                object.
        """
        self._elements = [] if elements is None else list(elements)
        self.ths = ths

    def update(self, tensor):
        """Append the entries of a 1-D tensor to the recorded values."""
        assert tensor.dim() == 1, tensor.shape
        self._elements += tensor.cpu().numpy().tolist()

    def compute(self):
        """Return the recall for `ths`.

        Returns:
            A list with one recall per threshold when `ths` is iterable,
            otherwise a single recall value. Each value is NaN when nothing
            was recorded.
        """
        values = self._clean_elements()
        if isinstance(self.ths, Iterable):
            return [self._recall(values, th) for th in self.ths]
        # A scalar threshold cannot be indexed; use it directly.
        return self._recall(values, self.ths)

    def compute_(self, th):
        """Return the recall at the single threshold `th`."""
        return self._recall(self._clean_elements(), th)

    def _clean_elements(self):
        """Return the recorded values as a float array with NaN set to +inf."""
        # Build a fresh array instead of replacing `self._elements`, so that
        # `update` can keep appending to the list after `compute` was called.
        values = np.array(self._elements, dtype=float)
        # set nan to inf to avoid error
        values[np.isnan(values)] = np.inf
        return values

    @staticmethod
    def _recall(values, th):
        """Return the fraction of `values` strictly below `th` (NaN if empty)."""
        if len(values) == 0:
            return np.nan
        return (values < th).sum() / len(values)
def compute_recall(errors):
    """Sort the errors and compute the cumulative recall at each of them.

    Args:
        errors: sequence of error values (list, tuple or NumPy array). The
            input is never modified.

    Returns:
        A pair `(sorted_errors, recall)` of arrays with the same length, where
        `recall[i] = (i + 1) / len(errors)`.
    """
    # `np.asarray` accepts any sequence (the previous `errors.copy()` required
    # a `.copy` method); the fancy indexing below already returns a new array.
    values = np.asarray(errors)
    num_elements = len(values)
    sort_idx = np.argsort(values)
    sorted_errors = values[sort_idx]
    recall = (np.arange(num_elements) + 1) / num_elements
    return sorted_errors, recall
def compute_auc(errors, thresholds, min_error: Optional[float] = None):
    """Compute the normalized area under the recall-vs-error curve.

    Args:
        errors: sequence of error values.
        thresholds: iterable of error thresholds; one AUC is computed per
            threshold by integrating the curve from 0 to the threshold and
            dividing by the threshold.
        min_error: if given, the curve is held flat on `[0, min_error]` at the
            fraction of errors that are `<= min_error`.

    Returns:
        A list with one AUC per threshold, each rounded to 4 decimals.
    """
    # `np.trapz` is deprecated since NumPy 2.0 in favour of `np.trapezoid`;
    # fall back to the old name on versions that predate the rename.
    integrate = getattr(np, "trapezoid", None) or np.trapz

    errors, recall = compute_recall(errors)

    if min_error is not None:
        min_index = np.searchsorted(errors, min_error, side="right")
        min_score = min_index / len(errors)
        recall = np.r_[min_score, min_score, recall[min_index:]]
        errors = np.r_[0, min_error, errors[min_index:]]
    else:
        # Start the curve at the origin.
        recall = np.r_[0, recall]
        errors = np.r_[0, errors]

    aucs = []
    for t in thresholds:
        last_index = np.searchsorted(errors, t, side="right")
        # Truncate the curve at `t`, extending the last recall value up to it.
        r = np.r_[recall[:last_index], recall[last_index - 1]]
        e = np.r_[errors[:last_index], t]
        auc = integrate(r, x=e) / t
        aucs.append(np.round(auc, 4))
    return aucs
class AUCMetric:
    """Area under the recall-vs-error curve of all recorded values.

    NaN entries are counted as +inf before the AUC is computed.
    """

    def __init__(self, thresholds, elements=None, min_error: Optional[float] = None):
        """Create the metric.

        Args:
            thresholds: a list of thresholds, or a single threshold which is
                wrapped into a one-element list.
            elements: optional iterable of initial values. It is copied into
                a new list so that later updates never modify the caller's
                object.
            min_error: forwarded unchanged to `compute_auc`.
        """
        # The default `None` must become an empty list, otherwise `update`
        # and `compute` fail on a freshly constructed metric.
        self._elements = [] if elements is None else list(elements)
        self.thresholds = thresholds
        self.min_error = min_error
        if not isinstance(thresholds, list):
            self.thresholds = [thresholds]

    def update(self, tensor):
        """Append the entries of a 1-D tensor to the recorded values."""
        assert tensor.dim() == 1, tensor.shape
        self._elements += tensor.cpu().numpy().tolist()

    def compute(self):
        """Return the `compute_auc` result, or NaN if nothing was recorded."""
        if len(self._elements) == 0:
            return np.nan
        # Work on a local array: replacing `self._elements` with an ndarray
        # would turn the `+=` in `update` into a broadcast addition.
        errors = np.array(self._elements, dtype=float)
        # set nan to inf to avoid error
        errors[np.isnan(errors)] = np.inf
        return compute_auc(errors, self.thresholds, self.min_error)
class Timer(object):
    """Context manager that measures the wall-clock time spent in its body.

    The elapsed time in seconds is stored in `duration` on exit and is also
    printed when the timer was given a name.

    Usage:
    ```
    > with Timer('mytimer'):
    >   # some computations
    [mytimer] Elapsed: X
    ```
    """

    def __init__(self, name=None):
        self.name = name

    def __enter__(self):
        # Remember when the block started.
        self.tstart = time.time()
        return self

    def __exit__(self, type, value, traceback):
        elapsed = time.time() - self.tstart
        self.duration = elapsed
        if self.name is None:
            # Anonymous timers only record the duration.
            return
        print(f"[{self.name}] Elapsed: {self.duration}")
def get_class(mod_path, BaseClass):
    """Return the unique subclass of BaseClass defined in module `mod_path`.

    The module is imported by its dotted path. Classes that are merely
    imported into it are ignored, and exactly one remaining class must derive
    from BaseClass, otherwise an AssertionError listing the candidates is
    raised.
    """
    import inspect

    module = __import__(mod_path, fromlist=[""])
    candidates = [
        (name, cls)
        for name, cls in inspect.getmembers(module, inspect.isclass)
        # Keep only classes defined in the module itself that inherit from
        # BaseClass.
        if cls.__module__ == mod_path and issubclass(cls, BaseClass)
    ]
    assert len(candidates) == 1, candidates
    return candidates[0][1]
def set_num_threads(nt):
    """Limit the number of threads used by numerical libraries.

    Sets the MKL thread count to `nt` when the optional `mkl` module can be
    imported, and writes `nt` to the OpenBLAS, NumExpr, OpenMP and MKL
    environment variables.

    NOTE(review): PyTorch is always set to a single thread, independently of
    `nt` -- confirm that this is intended.
    """
    try:
        import mkl  # type: ignore
    except ImportError:
        mkl = None
    if mkl is not None:
        mkl.set_num_threads(nt)

    torch.set_num_threads(1)
    os.environ["IPC_ENABLE"] = "1"

    thread_count = str(nt)
    for variable in (
        "OPENBLAS_NUM_THREADS",
        "NUMEXPR_NUM_THREADS",
        "OMP_NUM_THREADS",
        "MKL_NUM_THREADS",
    ):
        os.environ[variable] = thread_count
def set_seed(seed):
    """Seed the Python, PyTorch and NumPy random number generators.

    When CUDA is available, the CUDA generators are seeded as well.
    """
    for seed_fn in (random.seed, torch.manual_seed, np.random.seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
def get_random_state(with_cuda):
    """Capture the current state of all random number generators.

    Returns:
        A tuple `(torch_state, numpy_state, python_state, cuda_state)`, where
        `cuda_state` is None unless `with_cuda` is set and CUDA is available.
    """
    return (
        torch.get_rng_state(),
        np.random.get_state(),
        random.getstate(),
        (
            torch.cuda.get_rng_state_all()
            if torch.cuda.is_available() and with_cuda
            else None
        ),
    )
def set_random_state(state):
    """Restore generator states from a `(torch, numpy, python, cuda)` tuple.

    The CUDA states are only restored when they are present, CUDA is
    available and there is exactly one state per visible device.
    """
    torch_state, numpy_state, python_state, cuda_states = state
    torch.set_rng_state(torch_state)
    np.random.set_state(numpy_state)
    random.setstate(python_state)

    if cuda_states is None or not torch.cuda.is_available():
        return
    if len(cuda_states) == torch.cuda.device_count():
        torch.cuda.set_rng_state_all(cuda_states)
@contextmanager
def fork_rng(seed=None, with_cuda=True):
    """Run a block of code without disturbing the surrounding RNG state.

    The generator states are captured on entry and restored on exit, even if
    the block raises. When `seed` is given, the generators are re-seeded with
    it for the duration of the block.
    """
    saved_state = get_random_state(with_cuda)
    if seed is not None:
        set_seed(seed)
    try:
        yield
    finally:
        set_random_state(saved_state)
def get_device() -> str:
    """Return the preferred device name: "cuda", then "mps", otherwise "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"