lmzjms's picture
Upload 1162 files
0b32ad6 verified
"""
Benchmark the timing of a block of code
Authors
* Leo 2022
"""
import logging
from collections import defaultdict
from contextlib import ContextDecorator
from time import time
from time import perf_counter
from typing import Any
import numpy as np
import torch
logger = logging.getLogger(__name__)

# Every duration (in seconds) measured so far, keyed by benchmark name.
# Shared by all `benchmark` instances in the process.
_history = defaultdict(list)

__all__ = ["benchmark"]


class benchmark(ContextDecorator):
    """Time a block of code; usable as a context manager or as a decorator.

    On exit, the elapsed time is appended to the module-level history for
    ``name``. Every ``freq``-th measurement recorded under that name (counted
    across all instances sharing the name), the latest duration and the mean
    of all recorded durations are logged at WARNING level. Exceptions raised
    inside the block still record a measurement and then propagate.

    When a CUDA device is available, it is synchronized before the clock is
    started and before it is stopped, so queued asynchronous GPU work is
    included in the measurement. On CPU-only machines no sync is attempted.

    Example::

        with benchmark("forward", freq=10):
            model(x)

        @benchmark("step")
        def step(): ...

    Args:
        name: Key under which measurements are accumulated.
        freq: Log once every ``freq`` measurements for ``name``; must be >= 1.

    Raises:
        ValueError: If ``freq`` is less than 1.
    """

    def __init__(self, name: str, freq: int = 1) -> None:
        super().__init__()
        if freq < 1:
            # A zero/negative value would otherwise fail later with a
            # ZeroDivisionError (or never log) inside __exit__.
            raise ValueError(f"freq must be a positive integer, got {freq!r}")
        self.name = name
        self.freq = freq
        # Stack of start times. ContextDecorator reuses this same instance for
        # every call of a decorated function, so a single attribute would be
        # overwritten by recursive or nested use and corrupt outer timings.
        self._starts: list[float] = []

    @staticmethod
    def _sync() -> None:
        """Wait for pending CUDA work when a CUDA device is available."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def __enter__(self) -> "benchmark":
        self._sync()
        # perf_counter is monotonic and high-resolution, unlike time.time().
        self._starts.append(perf_counter())
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self._sync()
        seconds = perf_counter() - self._starts.pop()
        history = _history[self.name]
        history.append(seconds)
        if len(history) % self.freq == 0:
            logger.warning(
                "%s: %s secs, avg %s secs", self.name, seconds, np.mean(history)
            )