# Source: De-limiter/utils/logging.py
# Committed by jeonchangbin49 (commit a00b67a, "first commit")
import os
import torch
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def save_img_and_npy(path, matrix, *, save_npy=False):
    """Save ``matrix`` as a PNG image and, optionally, as a ``.npy`` array.

    Args:
        path: Output path *without* extension (str or os.PathLike). The
            suffixes ``".png"`` / ``".npy"`` are appended to it.
        matrix: Array-like accepted by ``matplotlib.pyplot.imsave``
            (presumably a 2-D array such as a spectrogram -- confirm with callers).
        save_npy: If True, also write the raw values with ``np.save`` to
            ``path + ".npy"``. Defaults to False, which keeps the historical
            behaviour of writing only the PNG despite the function's name.
    """
    # os.fspath lets callers pass pathlib.Path objects; `path + ".png"`
    # raises TypeError for them.
    base = os.fspath(path)
    # origin="lower" places row 0 of the matrix at the bottom of the image.
    plt.imsave(base + ".png", matrix, origin="lower")
    if save_npy:
        np.save(base + ".npy", np.asarray(matrix))
def save_checkpoint(state, state_dict_only, path, target):
    """Serialize a training checkpoint with ``torch.save``.

    The full ``state`` object is always written to ``<path>/<target>.chkpnt``.
    When ``state_dict_only`` is truthy, ``state["state_dict"]`` is additionally
    written on its own to ``<path>/<target>.pth`` (a weights-only copy). Despite
    the flag's name, the full checkpoint is written in both cases.

    Args:
        state: Object to serialize; must contain a ``"state_dict"`` key when
            ``state_dict_only`` is truthy.
        state_dict_only: Whether to also write the weights-only ``.pth`` file.
        path: Directory to write into (it is not created here).
        target: Base file name, without extension.
    """
    def _dest(extension):
        # Build "<path>/<target><extension>".
        return os.path.join(path, target + extension)

    torch.save(state, _dest(".chkpnt"))
    if not state_dict_only:
        return
    torch.save(state["state_dict"], _dest(".pth"))
class AverageMeter:
    """Computes and stores the average and current value.

    Attributes (all reset to 0 by ``reset``):
        val: Value passed to the most recent ``update`` call.
        sum: Running sum of ``val * n`` over all updates.
        count: Running sum of the weights ``n``.
        avg: ``sum / count``; left unchanged while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero all statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` in the running average.

        Args:
            val: Value to record; must support multiplication by ``n`` and
                addition to ``sum``.
            n: Weight of ``val`` (e.g. the number of samples ``val`` was
                averaged over). Defaults to 1.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard: an update with n=0 on a fresh/reset meter would otherwise
        # raise ZeroDivisionError.
        if self.count:
            self.avg = self.sum / self.count
class EarlyStopping:
    """Signal when a monitored metric has stopped improving.

    Call ``step(metric)`` once per epoch; it returns True when training
    should stop.

    Args:
        mode: ``"min"`` if lower metric values are better, ``"max"`` if
            higher values are better.
        min_delta: Minimum change beyond ``best`` that counts as an
            improvement.
        patience: Number of consecutive non-improving steps after which
            ``step`` returns True. ``patience=0`` disables patience-based
            stopping: every value counts as an improvement, so only a NaN
            metric stops training.

    Raises:
        ValueError: If ``mode`` is not ``"min"`` or ``"max"``.
    """

    def __init__(self, mode="min", min_delta=0, patience=10):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None  # best metric so far; None until the first valid step
        self.num_bad_epochs = 0  # consecutive steps without improvement
        self.is_better = None
        self._init_is_better(mode, min_delta)
        if patience == 0:
            # Disable early stopping: every metric counts as an improvement.
            self.is_better = lambda a, b: True

    def step(self, metrics):
        """Record one metric value and report whether training should stop.

        Returns:
            True if ``metrics`` is NaN or the patience budget is exhausted,
            otherwise False.
        """
        # Check NaN first: otherwise a NaN on the very first step would be
        # stored as ``best``, after which every comparison is False and no
        # later value could ever count as an improvement.
        if np.isnan(metrics):
            return True
        if self.best is None:
            self.best = metrics
            return False
        if self.is_better(metrics, self.best):
            self.num_bad_epochs = 0
            self.best = metrics
            # Return directly: falling through to the patience check would
            # evaluate 0 >= 0 and stop immediately when patience == 0.
            return False
        self.num_bad_epochs += 1
        return self.num_bad_epochs >= self.patience

    def _init_is_better(self, mode, min_delta):
        """Install ``self.is_better(a, best)`` for the given mode and delta.

        Raises:
            ValueError: If ``mode`` is not ``"min"`` or ``"max"``.
        """
        if mode == "min":
            self.is_better = lambda a, best: a < best - min_delta
        elif mode == "max":
            self.is_better = lambda a, best: a > best + min_delta
        else:
            # f-string so non-string modes also raise ValueError, not TypeError.
            raise ValueError(f"mode {mode} is unknown!")