|
from collections import defaultdict |
|
from contextlib import contextmanager |
|
import dataclasses |
|
import datetime |
|
from distutils.version import LooseVersion |
|
import logging |
|
from pathlib import Path |
|
import time |
|
from typing import ContextManager |
|
from typing import Dict |
|
from typing import List |
|
from typing import Optional |
|
from typing import Sequence |
|
from typing import Tuple |
|
from typing import Union |
|
import warnings |
|
|
|
import humanfriendly |
|
import numpy as np |
|
import torch |
|
from typeguard import check_argument_types |
|
from typeguard import check_return_type |
|
import wandb |
|
|
|
# torch>=1.1 bundles a TensorBoard writer; older versions need tensorboardX
if LooseVersion(torch.__version__) >= LooseVersion("1.1.0"):
    from torch.utils.tensorboard import SummaryWriter
else:
    from tensorboardX import SummaryWriter
|
|
|
# Scalar-like values that can be reported
Num = Union[float, int, complex, torch.Tensor, np.ndarray]
|
|
|
|
|
# Keys written by the reporter itself; register() rejects them
_reserved = {"time", "total_count"}
|
|
|
|
|
def to_reported_value(v: Num, weight: Num = None) -> "ReportedValue": |
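    """Wrap a value (with an optional weight) into a ReportedValue.

    A small usage sketch (the numbers are illustrative):

    >>> to_reported_value(0.5)            # -> Average(value=0.5)
    >>> to_reported_value(0.5, weight=2)  # -> WeightedAverage(value=0.5, weight=2)
    """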
|
assert check_argument_types() |
|
if isinstance(v, (torch.Tensor, np.ndarray)): |
|
if np.prod(v.shape) != 1: |
|
raise ValueError(f"v must be 0 or 1 dimension: {len(v.shape)}") |
|
v = v.item() |
|
|
|
if isinstance(weight, (torch.Tensor, np.ndarray)): |
|
if np.prod(weight.shape) != 1: |
|
raise ValueError(f"weight must be 0 or 1 dimension: {len(weight.shape)}") |
|
weight = weight.item() |
|
|
|
if weight is not None: |
|
retval = WeightedAverage(v, weight) |
|
else: |
|
retval = Average(v) |
|
assert check_return_type(retval) |
|
return retval |
|
|
|
|
|
def aggregate(values: Sequence["ReportedValue"]) -> Num: |
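    """Aggregate a sequence of ReportedValues into a single number.

    Averages are combined with nanmean, WeightedAverages by the weighted
    mean. A small sketch of the behaviour (values are illustrative):

    >>> aggregate([Average(1.0), Average(3.0)])                       # -> 2.0
    >>> aggregate([WeightedAverage(1.0, 1), WeightedAverage(3.0, 3)])  # -> 2.5
    """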
|
assert check_argument_types() |
|
|
|
for v in values: |
|
if not isinstance(v, type(values[0])): |
|
            raise ValueError(
                f"Cannot aggregate different ReportedValue types together: "
                f"{type(v)} != {type(values[0])}"
            )
|
|
|
if len(values) == 0: |
|
warnings.warn("No stats found") |
|
retval = np.nan |
|
|
|
elif isinstance(values[0], Average): |
|
retval = np.nanmean([v.value for v in values]) |
|
|
|
elif isinstance(values[0], WeightedAverage): |
|
|
|
        # Drop entries whose value or weight is not finite
        invalid_indices = set()
|
for i, v in enumerate(values): |
|
if not np.isfinite(v.value) or not np.isfinite(v.weight): |
|
invalid_indices.add(i) |
|
values = [v for i, v in enumerate(values) if i not in invalid_indices] |
|
|
|
if len(values) != 0: |
|
|
|
            sum_weights = sum(v.weight for v in values)
            sum_value = sum(v.value * v.weight for v in values)
|
if sum_weights == 0: |
|
warnings.warn("weight is zero") |
|
retval = np.nan |
|
else: |
|
retval = sum_value / sum_weights |
|
else: |
|
warnings.warn("No valid stats found") |
|
retval = np.nan |
|
|
|
else: |
|
raise NotImplementedError(f"type={type(values[0])}") |
|
assert check_return_type(retval) |
|
return retval |
|
|
|
|
|
class ReportedValue: |
|
pass |
|
|
|
|
|
@dataclasses.dataclass(frozen=True) |
|
class Average(ReportedValue): |
|
value: Num |
|
|
|
|
|
@dataclasses.dataclass(frozen=True) |
|
class WeightedAverage(ReportedValue): |
|
    value: Num
|
weight: Num |
|
|
|
|
|
class SubReporter: |
|
"""This class is used in Reporter. |
|
|
|
See the docstring of Reporter for the usage. |
|
""" |
|
|
|
def __init__(self, key: str, epoch: int, total_count: int): |
|
assert check_argument_types() |
|
self.key = key |
|
self.epoch = epoch |
|
self.start_time = time.perf_counter() |
|
self.stats = defaultdict(list) |
|
self._finished = False |
|
self.total_count = total_count |
|
self.count = 0 |
|
self._seen_keys_in_the_step = set() |
|
|
|
def get_total_count(self) -> int: |
|
"""Returns the number of iterations over all epochs.""" |
|
return self.total_count |
|
|
|
def get_epoch(self) -> int: |
|
return self.epoch |
|
|
|
def next(self): |
|
"""Close up this step and reset state for the next step""" |
|
for key, stats_list in self.stats.items(): |
|
if key not in self._seen_keys_in_the_step: |
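                # This key wasn't registered in this step: pad with NaN
                # so that every stats list keeps one entry per step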
|
|
|
if isinstance(stats_list[0], WeightedAverage): |
|
stats_list.append(to_reported_value(np.nan, 0)) |
|
elif isinstance(stats_list[0], Average): |
|
stats_list.append(to_reported_value(np.nan)) |
|
else: |
|
raise NotImplementedError(f"type={type(stats_list[0])}") |
|
|
|
assert len(stats_list) == self.count, (len(stats_list), self.count) |
|
|
|
self._seen_keys_in_the_step = set() |
|
|
|
def register( |
|
self, |
|
stats: Dict[str, Optional[Union[Num, Dict[str, Num]]]], |
|
weight: Num = None, |
|
) -> None: |
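        """Register a step's stats, optionally weighted by ``weight``.

        A typical call inside a training loop (the key names and
        ``batch_size`` are illustrative):

        >>> sub_reporter.register({"loss": 0.2, "acc": 0.9}, weight=batch_size)
        """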
|
assert check_argument_types() |
|
if self._finished: |
|
raise RuntimeError("Already finished") |
|
        if len(self._seen_keys_in_the_step) == 0:
            # The first register() call in a step advances the step counters
            self.total_count += 1
            self.count += 1
|
|
|
for key2, v in stats.items(): |
|
if key2 in _reserved: |
|
raise RuntimeError(f"{key2} is reserved.") |
|
if key2 in self._seen_keys_in_the_step: |
|
raise RuntimeError(f"{key2} is registered twice.") |
|
if v is None: |
|
v = np.nan |
|
r = to_reported_value(v, weight) |
|
|
|
            if key2 not in self.stats:
                # First observation of this key in the epoch: pad the
                # earlier steps with NaN so that every stats list keeps
                # one entry per step, e.g.
                #   stat A: [0.4, 0.3, 0.5]
                #   stat B: [nan, nan, 0.2]
                nan = to_reported_value(np.nan, None if weight is None else 0)
|
self.stats[key2].extend( |
|
r if i == self.count - 1 else nan for i in range(self.count) |
|
) |
|
else: |
|
self.stats[key2].append(r) |
|
self._seen_keys_in_the_step.add(key2) |
|
|
|
def log_message(self, start: int = None, end: int = None) -> str: |
|
if self._finished: |
|
raise RuntimeError("Already finished") |
|
if start is None: |
|
start = 0 |
|
if start < 0: |
|
start = self.count + start |
|
if end is None: |
|
end = self.count |
|
|
|
if self.count == 0 or start == end: |
|
return "" |
|
|
|
message = f"{self.epoch}epoch:{self.key}:" f"{start + 1}-{end}batch: " |
|
|
|
for idx, (key2, stats_list) in enumerate(self.stats.items()): |
|
assert len(stats_list) == self.count, (len(stats_list), self.count) |
|
|
|
values = stats_list[start:end] |
|
            if idx != 0:
                message += ", "
|
|
|
v = aggregate(values) |
|
if abs(v) > 1.0e3: |
|
message += f"{key2}={v:.3e}" |
|
elif abs(v) > 1.0e-3: |
|
message += f"{key2}={v:.3f}" |
|
else: |
|
message += f"{key2}={v:.3e}" |
|
return message |
|
|
|
def tensorboard_add_scalar(self, summary_writer: SummaryWriter, start: int = None): |
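        """Write the stats aggregated from ``start`` onward to TensorBoard.

        A sketch, assuming a writer created elsewhere (the path is
        illustrative):

        >>> writer = SummaryWriter("exp/tensorboard")
        >>> sub_reporter.tensorboard_add_scalar(writer, start=-100)
        """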
|
if start is None: |
|
start = 0 |
|
if start < 0: |
|
start = self.count + start |
|
|
|
for key2, stats_list in self.stats.items(): |
|
assert len(stats_list) == self.count, (len(stats_list), self.count) |
|
|
|
values = stats_list[start:] |
|
v = aggregate(values) |
|
summary_writer.add_scalar(key2, v, self.total_count) |
|
|
|
def wandb_log(self, start: int = None, commit: bool = True): |
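        """Log the stats aggregated from ``start`` onward to Weights & Biases.

        For example, to log the last 100 steps (assuming ``wandb.init()``
        was called elsewhere):

        >>> sub_reporter.wandb_log(start=-100)
        """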
|
if start is None: |
|
start = 0 |
|
if start < 0: |
|
start = self.count + start |
|
|
|
d = {} |
|
for key2, stats_list in self.stats.items(): |
|
assert len(stats_list) == self.count, (len(stats_list), self.count) |
|
|
|
values = stats_list[start:] |
|
v = aggregate(values) |
|
d[key2] = v |
|
d["iteration"] = self.total_count |
|
wandb.log(d, commit=commit) |
|
|
|
def finished(self) -> None: |
|
self._finished = True |
|
|
|
@contextmanager |
|
def measure_time(self, name: str): |
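        """Measure the elapsed time of a ``with`` block and register it.

        A sketch (``forward_time``, ``model``, and ``batch`` are
        illustrative names):

        >>> with sub_reporter.measure_time("forward_time"):
        ...     loss = model(batch)
        """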
|
start = time.perf_counter() |
|
yield start |
|
t = time.perf_counter() - start |
|
self.register({name: t}) |
|
|
|
def measure_iter_time(self, iterable, name: str): |
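        """Wrap an iterable, registering the time to produce each item.

        A sketch (``loader`` and ``train_one_step`` are illustrative names):

        >>> for batch in sub_reporter.measure_iter_time(loader, "iter_time"):
        ...     train_one_step(batch)
        """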
|
iterator = iter(iterable) |
|
while True: |
|
try: |
|
start = time.perf_counter() |
|
retval = next(iterator) |
|
t = time.perf_counter() - start |
|
self.register({name: t}) |
|
yield retval |
|
except StopIteration: |
|
break |
|
|
|
|
|
class Reporter: |
|
"""Reporter class. |
|
|
|
Examples: |
|
|
|
>>> reporter = Reporter() |
|
>>> with reporter.observe('train') as sub_reporter: |
|
... for batch in iterator: |
|
... stats = dict(loss=0.2) |
|
... sub_reporter.register(stats) |
|
|
|
""" |
|
|
|
def __init__(self, epoch: int = 0): |
|
assert check_argument_types() |
|
if epoch < 0: |
|
raise ValueError(f"epoch must be 0 or more: {epoch}") |
|
self.epoch = epoch |
|
|
|
|
|
        # self.stats: {epoch: {key: {key2: value}}},
        # e.g. {1: {"train": {"loss": 0.3, ...}, "eval": {...}}}
        self.stats = {}
|
|
|
def get_epoch(self) -> int: |
|
return self.epoch |
|
|
|
def set_epoch(self, epoch: int) -> None: |
|
if epoch < 0: |
|
raise ValueError(f"epoch must be 0 or more: {epoch}") |
|
self.epoch = epoch |
|
|
|
@contextmanager |
|
def observe(self, key: str, epoch: int = None) -> ContextManager[SubReporter]: |
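        """Observe one epoch for ``key`` (e.g. "train"), yielding a SubReporter.

        Note that finish_epoch() runs only when the block exits normally,
        so the stats are not committed if the body raises.
        """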
|
sub_reporter = self.start_epoch(key, epoch) |
|
yield sub_reporter |
|
|
|
self.finish_epoch(sub_reporter) |
|
|
|
def start_epoch(self, key: str, epoch: int = None) -> SubReporter: |
|
if epoch is not None: |
|
if epoch < 0: |
|
raise ValueError(f"epoch must be 0 or more: {epoch}") |
|
self.epoch = epoch |
|
|
|
if self.epoch - 1 not in self.stats or key not in self.stats[self.epoch - 1]: |
|
|
|
|
|
            if self.epoch - 1 != 0:
                warnings.warn(
                    f"The stats of the previous epoch={self.epoch - 1} "
                    f"don't exist."
                )
|
total_count = 0 |
|
else: |
|
total_count = self.stats[self.epoch - 1][key]["total_count"] |
|
|
|
sub_reporter = SubReporter(key, self.epoch, total_count) |
|
|
|
        # Clear any stale stats for the current epoch (e.g. from a resumed run)
        self.stats.pop(self.epoch, None)
|
return sub_reporter |
|
|
|
def finish_epoch(self, sub_reporter: SubReporter) -> None: |
|
if self.epoch != sub_reporter.epoch: |
|
raise RuntimeError( |
|
f"Don't change epoch during observation: " |
|
f"{self.epoch} != {sub_reporter.epoch}" |
|
) |
|
|
|
|
|
stats = {} |
|
for key2, values in sub_reporter.stats.items(): |
|
v = aggregate(values) |
|
stats[key2] = v |
|
|
|
stats["time"] = datetime.timedelta( |
|
seconds=time.perf_counter() - sub_reporter.start_time |
|
) |
|
stats["total_count"] = sub_reporter.total_count |
|
if LooseVersion(torch.__version__) >= LooseVersion("1.4.0"): |
|
if torch.cuda.is_initialized(): |
|
stats["gpu_max_cached_mem_GB"] = ( |
|
torch.cuda.max_memory_reserved() / 2 ** 30 |
|
) |
|
else: |
|
if torch.cuda.is_available() and torch.cuda.max_memory_cached() > 0: |
|
stats["gpu_cached_mem_GB"] = torch.cuda.max_memory_cached() / 2 ** 30 |
|
|
|
self.stats.setdefault(self.epoch, {})[sub_reporter.key] = stats |
|
sub_reporter.finished() |
|
|
|
def sort_epochs_and_values( |
|
self, key: str, key2: str, mode: str |
|
) -> List[Tuple[int, float]]: |
|
"""Return the epoch which resulted the best value. |
|
|
|
Example: |
|
>>> val = reporter.sort_epochs_and_values('eval', 'loss', 'min') |
|
>>> e_1best, v_1best = val[0] |
|
>>> e_2best, v_2best = val[1] |
|
""" |
|
if mode not in ("min", "max"): |
|
raise ValueError(f"mode must min or max: {mode}") |
|
if not self.has(key, key2): |
|
raise KeyError(f"{key}.{key2} is not found: {self.get_all_keys()}") |
|
|
|
|
|
values = [(e, self.stats[e][key][key2]) for e in self.stats] |
|
|
|
if mode == "min": |
|
values = sorted(values, key=lambda x: x[1]) |
|
else: |
|
values = sorted(values, key=lambda x: -x[1]) |
|
return values |
|
|
|
def sort_epochs(self, key: str, key2: str, mode: str) -> List[int]: |
|
return [e for e, v in self.sort_epochs_and_values(key, key2, mode)] |
|
|
|
def sort_values(self, key: str, key2: str, mode: str) -> List[float]: |
|
return [v for e, v in self.sort_epochs_and_values(key, key2, mode)] |
|
|
|
def get_best_epoch(self, key: str, key2: str, mode: str, nbest: int = 0) -> int: |
|
return self.sort_epochs(key, key2, mode)[nbest] |
|
|
|
def check_early_stopping( |
|
self, |
|
patience: int, |
|
key1: str, |
|
key2: str, |
|
mode: str, |
|
epoch: int = None, |
|
logger=None, |
|
) -> bool: |
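        """Return True if ``key1.key2`` hasn't improved for ``patience`` epochs.

        A sketch of the intended use at the end of each epoch (the key
        names are illustrative):

        >>> if reporter.check_early_stopping(3, "valid", "loss", "min"):
        ...     break
        """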
|
if logger is None: |
|
logger = logging |
|
if epoch is None: |
|
epoch = self.get_epoch() |
|
|
|
best_epoch = self.get_best_epoch(key1, key2, mode) |
|
if epoch - best_epoch > patience: |
|
            logger.info(
                f"[Early stopping] {key1}.{key2} has not improved "
                f"for {epoch - best_epoch} consecutive epochs. "
                f"Training was stopped at epoch {epoch}."
            )
|
return True |
|
else: |
|
return False |
|
|
|
def has(self, key: str, key2: str, epoch: int = None) -> bool: |
|
if epoch is None: |
|
epoch = self.get_epoch() |
|
return ( |
|
epoch in self.stats |
|
and key in self.stats[epoch] |
|
and key2 in self.stats[epoch][key] |
|
) |
|
|
|
def log_message(self, epoch: int = None) -> str: |
|
if epoch is None: |
|
epoch = self.get_epoch() |
|
|
|
message = "" |
|
for key, d in self.stats[epoch].items(): |
|
_message = "" |
|
for key2, v in d.items(): |
|
if v is not None: |
|
if len(_message) != 0: |
|
_message += ", " |
|
if isinstance(v, float): |
|
if abs(v) > 1.0e3: |
|
_message += f"{key2}={v:.3e}" |
|
elif abs(v) > 1.0e-3: |
|
_message += f"{key2}={v:.3f}" |
|
else: |
|
_message += f"{key2}={v:.3e}" |
|
elif isinstance(v, datetime.timedelta): |
|
_v = humanfriendly.format_timespan(v) |
|
_message += f"{key2}={_v}" |
|
else: |
|
_message += f"{key2}={v}" |
|
if len(_message) != 0: |
|
if len(message) == 0: |
|
message += f"{epoch}epoch results: " |
|
else: |
|
message += ", " |
|
message += f"[{key}] {_message}" |
|
return message |
|
|
|
def get_value(self, key: str, key2: str, epoch: int = None): |
|
        if epoch is None:
            epoch = self.get_epoch()
        if not self.has(key, key2, epoch):
            raise KeyError(f"{key}.{key2} is not found in stats: {self.get_all_keys()}")
        return self.stats[epoch][key][key2]
|
|
|
def get_keys(self, epoch: int = None) -> Tuple[str, ...]: |
|
"""Returns keys1 e.g. train,eval.""" |
|
if epoch is None: |
|
epoch = self.get_epoch() |
|
return tuple(self.stats[epoch]) |
|
|
|
def get_keys2(self, key: str, epoch: int = None) -> Tuple[str, ...]: |
|
"""Returns keys2 e.g. loss,acc.""" |
|
if epoch is None: |
|
epoch = self.get_epoch() |
|
d = self.stats[epoch][key] |
|
        keys2 = tuple(k for k in d if k not in _reserved)
|
return keys2 |
|
|
|
def get_all_keys(self, epoch: int = None) -> Tuple[Tuple[str, str], ...]: |
|
if epoch is None: |
|
epoch = self.get_epoch() |
|
all_keys = [] |
|
for key in self.stats[epoch]: |
|
for key2 in self.stats[epoch][key]: |
|
all_keys.append((key, key2)) |
|
return tuple(all_keys) |
|
|
|
def matplotlib_plot(self, output_dir: Union[str, Path]): |
|
"""Plot stats using Matplotlib and save images.""" |
|
keys2 = set.union(*[set(self.get_keys2(k)) for k in self.get_keys()]) |
|
for key2 in keys2: |
|
keys = [k for k in self.get_keys() if key2 in self.get_keys2(k)] |
|
plt = self._plot_stats(keys, key2) |
|
            p = Path(output_dir) / f"{key2}.png"
|
p.parent.mkdir(parents=True, exist_ok=True) |
|
plt.savefig(p) |
|
|
|
def _plot_stats(self, keys: Sequence[str], key2: str): |
|
assert check_argument_types() |
|
|
|
if isinstance(keys, str): |
|
raise TypeError(f"Input as [{keys}]") |
|
|
|
import matplotlib |
|
|
|
matplotlib.use("agg") |
|
import matplotlib.pyplot as plt |
|
import matplotlib.ticker as ticker |
|
|
|
plt.clf() |
|
|
|
epochs = np.arange(1, self.get_epoch() + 1) |
|
for key in keys: |
|
y = [ |
|
self.stats[e][key][key2] |
|
if e in self.stats |
|
and key in self.stats[e] |
|
and key2 in self.stats[e][key] |
|
else np.nan |
|
for e in epochs |
|
] |
|
assert len(epochs) == len(y), "Bug?" |
|
|
|
plt.plot(epochs, y, label=key, marker="x") |
|
plt.legend() |
|
plt.title(f"epoch vs {key2}") |
|
|
|
plt.gca().get_xaxis().set_major_locator(ticker.MaxNLocator(integer=True)) |
|
plt.xlabel("epoch") |
|
plt.ylabel(key2) |
|
plt.grid() |
|
|
|
return plt |
|
|
|
def tensorboard_add_scalar(self, summary_writer: SummaryWriter, epoch: int = None): |
|
if epoch is None: |
|
epoch = self.get_epoch() |
|
|
|
for key1 in self.get_keys(epoch): |
|
for key2 in self.stats[epoch][key1]: |
|
if key2 in ("time", "total_count"): |
|
continue |
|
summary_writer.add_scalar( |
|
f"{key1}_{key2}_epoch", |
|
self.stats[epoch][key1][key2], |
|
epoch, |
|
) |
|
|
|
def wandb_log(self, epoch: int = None, commit: bool = True): |
|
if epoch is None: |
|
epoch = self.get_epoch() |
|
|
|
d = {} |
|
for key1 in self.get_keys(epoch): |
|
for key2 in self.stats[epoch][key1]: |
|
if key2 in ("time", "total_count"): |
|
continue |
|
d[f"{key1}_{key2}_epoch"] = self.stats[epoch][key1][key2] |
|
d["epoch"] = epoch |
|
wandb.log(d, commit=commit) |
|
|
|
def state_dict(self): |
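        """Return the state as a dict, e.g. for checkpointing.

        A save/restore sketch (the file name is illustrative):

        >>> torch.save(reporter.state_dict(), "reporter.pth")
        >>> reporter.load_state_dict(torch.load("reporter.pth"))
        """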
|
return {"stats": self.stats, "epoch": self.epoch} |
|
|
|
def load_state_dict(self, state_dict: dict): |
|
self.epoch = state_dict["epoch"] |
|
self.stats = state_dict["stats"] |
|
|