import os
import json
import time
import collections
import collections.abc

import torch


class TotalAverage():
    """Mass-weighted running mean of every value seen since the last reset."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard all accumulated values."""
        self.last_value = 0.
        self.mass = 0.
        self.sum = 0.

    def update(self, value, mass=1):
        """Add ``value`` with weight ``mass`` to the running mean."""
        self.last_value = value
        self.mass += mass
        self.sum += value * mass

    def get(self):
        """Return the weighted mean, or ``None`` if no mass has been accumulated.

        Returning ``None`` (rather than raising ZeroDivisionError) mirrors
        ``MovingAverage.get`` before its first update.
        """
        if self.mass == 0:
            return None
        return self.sum / self.mass


class MovingAverage():
    """Exponential moving average.

    After the first update, each new value is folded in as
    ``average = inertia * average + (1 - inertia) * value``.
    """

    def __init__(self, inertia=0.9):
        self.inertia = inertia
        self.reset()

    def reset(self):
        """Discard the current average."""
        self.last_value = None
        self.average = None

    def update(self, value, mass=1):
        """Fold ``value`` into the average.

        ``mass`` is accepted so this class is interchangeable with
        ``TotalAverage``, but it is ignored.
        """
        self.last_value = value
        if self.average is None:
            self.average = value
        else:
            self.average = self.inertia * self.average + (1 - self.inertia) * value

    def get(self):
        """Return the current average, or ``None`` before the first update."""
        return self.average


class MetricsTrace:
    """History of metric snapshots, grouped by dataset name.

    ``data`` maps a dataset name to a list of dicts. ``append`` adds one
    dict per call, obtained from ``metric.get_data_dict()``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Drop all recorded snapshots."""
        self.data = {}

    def append(self, dataset, metric):
        """Record a snapshot of ``metric`` under ``dataset``."""
        self.data.setdefault(dataset, []).append(metric.get_data_dict())

    def load(self, path):
        """Load the metrics trace from the specified JSON file."""
        with open(path, 'r') as f:
            self.data = json.load(f)

    def save(self, path):
        """Save the metrics trace to the specified JSON file.

        Does nothing when ``path`` is None. Parent directories are created
        when ``path`` has a directory component.
        """
        if path is None:
            return
        directory = os.path.dirname(path)
        # os.makedirs('') raises, so only create a directory if one is named.
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(path, 'w') as f:
            json.dump(self.data, f, indent=2)

    def plot(self, pdf_path=None):
        """Plot the metrics trace and optionally save it as a PDF."""
        plot_metrics(self.data, pdf_path=pdf_path)

    def get(self):
        """Return the underlying ``{dataset: [snapshot, ...]}`` dict."""
        return self.data

    def __str__(self):
        summary = ', '.join(f'{name}: {len(trace)} entries'
                            for name, trace in self.data.items())
        return f'MetricsTrace({summary})'


class Metrics():
    """Base metrics object that times the interval between ``update`` calls."""

    def __init__(self):
        self.iteration_time = MovingAverage(inertia=0.9)
        # perf_counter is monotonic and high-resolution, unlike time.time().
        self.now = time.perf_counter()

    def update(self, prediction=None, ground_truth=None):
        """Record the time elapsed since the previous update (or construction).

        ``prediction`` and ``ground_truth`` are unused here. They are part of
        the signature so that subclasses can accept them.
        """
        now = time.perf_counter()
        self.iteration_time.update(now - self.now)
        self.now = now

    def get_data_dict(self):
        """Return the tracked values as a plain dict.

        ``objective`` is included only if a subclass has set an
        ``objective`` accumulator. This base class never defines one.
        """
        data = {}
        objective = getattr(self, 'objective', None)
        if objective is not None:
            data["objective"] = objective.get()
        data["iteration_time"] = self.iteration_time.get()
        return data


class StandardMetrics(Metrics):
    """Metrics fed from dicts of named scalar values.

    An accumulator is created the first time each metric name is seen.
    - Names containing ``moving_average`` get a ``MovingAverage``. The
      inertia is parsed from the text after the last ``_`` (e.g.
      ``loss_moving_average_0.95`` -> 0.95); it falls back to 0.9 when
      that text is not a float.
    - All other names get a ``TotalAverage``.

    ``speed`` is a moving average of ``mass / iteration_time``.
    """

    def __init__(self, m=None):
        super(StandardMetrics, self).__init__()
        self.metrics = m or {}
        self.speed = MovingAverage(inertia=0.9)

    @staticmethod
    def _make_meter(name):
        """Create the accumulator for a metric name seen for the first time."""
        if 'moving_average' not in name:
            return TotalAverage()
        suffix = name.split('moving_average')[-1].split('_')[-1]
        try:
            inertia = float(suffix)
        except ValueError:
            inertia = 0.9
        return MovingAverage(inertia)

    def update(self, metric_dict, mass=1):
        """Fold each ``{name: value}`` entry into its accumulator with weight ``mass``.

        Tensor values are converted with ``.item()``.
        """
        super(StandardMetrics, self).update()
        for metric, val in metric_dict.items():
            if torch.is_tensor(val):
                val = val.item()
            if metric not in self.metrics:
                self.metrics[metric] = self._make_meter(metric)
            self.metrics[metric].update(val, mass)
        elapsed = self.iteration_time.last_value
        # Guard against a zero interval (possible with a coarse clock).
        if elapsed > 0:
            self.speed.update(mass / elapsed)

    def get_data_dict(self):
        """Return every metric's current value plus ``speed``."""
        data_dict = {k: v.get() for k, v in self.metrics.items()}
        data_dict['speed'] = self.speed.get()
        return data_dict

    def __str__(self):
        def _num(x):
            # Accumulators report None before any data; %f cannot format None.
            return float('nan') if x is None else x

        pstr = '%7.1fHz\t' % _num(self.speed.get())
        pstr += '\t'.join(['%s: %6.5f' % (k, _num(v.get()))
                           for k, v in self.metrics.items()])
        return pstr


def plot_metrics(stats, pdf_path=None, fig=1, datasets=None, metrics=None):
    """Plot metrics.

    `stats` should be a dictionary of type

          stats[dataset][t][metric][i]

    where dataset is the dataset name (e.g. `train` or `val`), t is an
    iteration number, metric is the name of a metric (e.g. `loss` or
    `top1`), and i is a loss dimension. Alternatively, if a loss has a
    single dimension, `stats[dataset][t][metric]` can be a scalar.

    Each dataset is plotted over its own length, so datasets recorded a
    different number of times can share a figure. Datasets that lack a
    metric are skipped for that metric.

    The supported options are:

    - pdf_path: path to a PDF file to store the figure (default: None)
    - fig: MatPlotLib figure index (default: 1)
    - datasets: list of dataset names to plot (default: None)
    - metrics: list of metrics to plot (default: None)

    Does nothing if every selected dataset is empty.
    """
    datasets = list(stats.keys()) if datasets is None else datasets
    # Filter out empty datasets
    datasets = [d for d in datasets if len(stats[d]) > 0]
    if not datasets:
        return

    # Imported lazily so the rest of the module works without matplotlib.
    import matplotlib.pyplot as plt

    plt.figure(fig)
    plt.clf()
    linestyles = ['-', '--', '-.', ':']
    metrics = list(stats[datasets[0]][0].keys()) if metrics is None else metrics

    for m, metric in enumerate(metrics):
        plt.subplot(len(metrics), 1, m + 1)
        legend_content = []
        for d, dataset in enumerate(datasets):
            ls = linestyles[d % len(linestyles)]
            trace = stats[dataset]
            if metric not in trace[0]:
                continue
            first = trace[0][metric]
            if isinstance(first, collections.abc.Iterable):
                for sl in range(len(first)):
                    plt.plot([entry[metric][sl] for entry in trace], linestyle=ls)
                    legend_content.append(f'{dataset} {metric}[{sl}]')
            else:
                plt.plot([entry[metric] for entry in trace], linestyle=ls)
                legend_content.append(f'{dataset} {metric}')
        plt.legend(legend_content, loc=(1.04, 0))
        plt.grid(True)

    if pdf_path is not None:
        plt.savefig(pdf_path, format='pdf', bbox_inches='tight')
    plt.draw()
    plt.pause(0.0001)