Spaces:
Running
Running
import os | |
import torch | |
import numpy as np | |
import matplotlib | |
matplotlib.use("Agg") | |
import matplotlib.pyplot as plt | |
def save_img_and_npy(path, matrix):
    """Save ``matrix`` as a PNG image and as a raw ``.npy`` array.

    Args:
        path: Output path *without* extension; ``".png"`` and ``".npy"`` are
            appended to it.
        matrix: Array-like accepted by ``plt.imsave`` (and by ``np.save``).
    """
    # origin="lower" places row 0 of the matrix at the bottom of the image.
    plt.imsave(path + ".png", matrix, origin="lower")
    # The function name promises a .npy copy: keep the exact values, which a
    # rendered PNG (colormapped for 2-D input) does not preserve.
    np.save(path + ".npy", matrix)
def save_checkpoint(state, state_dict_only, path, target):
    """Serialize a training checkpoint with ``torch.save``.

    The complete ``state`` object is always written to
    ``<path>/<target>.chkpnt``. Despite its name, ``state_dict_only`` does not
    suppress that file: when truthy, ``state["state_dict"]`` is written
    *additionally* to ``<path>/<target>.pth``.
    """
    checkpoint_file = os.path.join(path, target + ".chkpnt")
    torch.save(state, checkpoint_file)
    if not state_dict_only:
        return
    # Weights alone, so they can be loaded without the rest of the checkpoint.
    weights_file = os.path.join(path, target + ".pth")
    torch.save(state["state_dict"], weights_file)
class AverageMeter:
    """Track the latest value and a count-weighted running mean.

    Attributes:
        val: The value passed to the most recent ``update`` call.
        sum: Accumulated ``val * n`` over all updates since ``reset``.
        count: Accumulated ``n`` over all updates since ``reset``.
        avg: ``sum / count`` as of the most recent update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean.

        Raises ``ZeroDivisionError`` if the accumulated count is still zero
        (e.g. the first update after a reset uses ``n=0``).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
class EarlyStopping(object):
    """Decide when to stop training because a monitored metric has stalled.

    Feed one metric value per epoch to :meth:`step`; it returns ``True`` once
    training should stop.

    Args:
        mode: ``"min"`` if lower metric values are better, ``"max"`` if higher
            values are better.
        min_delta: Absolute margin by which a value must beat the best value
            so far to count as an improvement.
        patience: Number of consecutive non-improving steps after which
            :meth:`step` returns ``True``. ``0`` disables patience-based
            stopping (a NaN metric still stops).

    Raises:
        ValueError: If ``mode`` is neither ``"min"`` nor ``"max"``.
    """

    def __init__(self, mode="min", min_delta=0, patience=10):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None  # best metric so far; None until the first step
        self.num_bad_epochs = 0  # consecutive steps without improvement
        self.is_better = None
        self._init_is_better(mode, min_delta)

        if patience == 0:
            # Every value counts as an improvement, so num_bad_epochs stays 0.
            self.is_better = lambda a, b: True

    def step(self, metrics):
        """Record one metric value and report whether training should stop.

        Args:
            metrics: The monitored value for this epoch (must be a scalar,
                since it is tested with ``if np.isnan(...)``).

        Returns:
            ``True`` if ``metrics`` is NaN, or if it has failed to improve for
            ``patience`` consecutive steps; ``False`` otherwise.
        """
        # Check NaN before seeding ``best``: a NaN best makes every later
        # comparison False, turning an immediate stop into a slow countdown.
        if np.isnan(metrics):
            return True

        if self.best is None:
            self.best = metrics
            return False

        if self.is_better(metrics, self.best):
            self.num_bad_epochs = 0
            self.best = metrics
        else:
            self.num_bad_epochs += 1

        # patience == 0 means "never stop for lack of improvement"; without this
        # explicit check, 0 >= 0 would stop on the second call.
        if self.patience != 0 and self.num_bad_epochs >= self.patience:
            return True
        return False

    def _init_is_better(self, mode, min_delta):
        """Set ``self.is_better(a, best)`` for ``mode`` using margin ``min_delta``.

        Raises:
            ValueError: If ``mode`` is neither ``"min"`` nor ``"max"``.
        """
        if mode == "min":
            self.is_better = lambda a, best: a < best - min_delta
        elif mode == "max":
            self.is_better = lambda a, best: a > best + min_delta
        else:
            # f-string so a non-str mode still yields ValueError, not TypeError.
            raise ValueError(f"mode {mode} is unknown!")