"""Benchmark the timing of a block of code.

``benchmark`` works both as a context manager and as a function decorator.
Before starting and after stopping the timer it synchronizes CUDA (when CUDA
is available) so that asynchronously queued GPU work is included in the
measured duration.

Authors
 * Leo 2022
"""

import logging
from collections import defaultdict
from contextlib import ContextDecorator
from time import perf_counter, time  # noqa: F401  (``time`` kept for backward compatibility)
from typing import Any, DefaultDict, List

import numpy as np
import torch

logger = logging.getLogger(__name__)

# Process-wide record of every measured duration (in seconds), keyed by
# benchmark name. Instances that share a name share one history.
_history: DefaultDict[str, List[float]] = defaultdict(list)

__all__ = ["benchmark"]


class benchmark(ContextDecorator):
    """Time a code block and periodically log its latest and mean duration.

    Usage::

        with benchmark("forward"):
            model(x)

        @benchmark("step", freq=100)
        def train_step(): ...

    Arguments
    ---------
    name : str
        Key under which durations are accumulated in the module history.
    freq : int
        Log every ``freq``-th measurement recorded for ``name``. Must be >= 1.

    Raises
    ------
    ValueError
        If ``freq`` is smaller than 1.
    """

    def __init__(self, name: str, freq: int = 1) -> None:
        super().__init__()
        if freq < 1:
            # ``freq`` is used as a modulus in ``__exit__``; 0 would otherwise
            # surface later as a confusing ZeroDivisionError.
            raise ValueError(f"freq must be a positive integer, got {freq!r}")
        self.name = name
        self.freq = freq
        # Decide once, outside the timed region, whether CUDA needs syncing;
        # this also keeps CPU-only PyTorch builds from failing in synchronize().
        self._use_cuda = torch.cuda.is_available()
        # Stack of start times: a decorated function reuses this single
        # instance for every call, so nested/recursive entries must not
        # overwrite each other's start time.
        self._starts: List[float] = []

    def _synchronize(self) -> None:
        """Block until queued CUDA work has finished, if CUDA is in use."""
        if self._use_cuda:
            torch.cuda.synchronize()

    def __enter__(self) -> "benchmark":
        self._synchronize()
        # ``self.start`` is kept for callers that read it; the stack is the
        # source of truth for computing the duration.
        self.start = perf_counter()
        self._starts.append(self.start)
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self._synchronize()
        seconds = perf_counter() - self._starts.pop()
        history = _history[self.name]
        history.append(seconds)
        if len(history) % self.freq == 0:
            logger.warning(
                "%s: %s secs, avg %s secs", self.name, seconds, np.mean(history)
            )
        # Returning None lets any exception raised inside the block propagate.