onescotch
clean up for zero gpus
010a8bc
raw
history blame
3.74 kB
# Copyright (c) OpenMMLab. All rights reserved.
from collections import defaultdict, deque
from contextlib import contextmanager
from functools import partial

import numpy as np
from mmengine import Timer
class RunningAverage():
    r"""A helper class to calculate running average in a sliding window.

    Args:
        window (int): The size of the sliding window, i.e. the maximum number
            of most recent samples that are averaged. Must be at least 1.
            Default: 1.

    Raises:
        ValueError: If ``window`` is smaller than 1.
    """

    def __init__(self, window: int = 1):
        # A non-positive window is meaningless: with list slicing ``[-0:]``
        # it used to keep every sample (unbounded growth), and a negative
        # value dropped the oldest samples instead of bounding the window.
        if window < 1:
            raise ValueError(
                f'window must be a positive integer, but got {window}')
        self.window = window
        # A bounded deque evicts the oldest sample in O(1) once full, rather
        # than copying the whole list on every update.
        self._data = deque(maxlen=window)

    def update(self, value):
        """Update a new data sample.

        Args:
            value: The new sample. Once ``window`` samples are stored, the
                oldest one is discarded.
        """
        self._data.append(value)

    def average(self):
        """Get the average value of current window.

        Returns:
            The mean (via ``np.mean``) of the samples currently in the
            window. If no sample has been added yet, numpy returns nan and
            emits a RuntimeWarning.
        """
        return np.mean(self._data)
class StopWatch:
    r"""A helper class to measure FPS and detailed time consuming of each phase
    in a video processing loop or similar scenarios.

    Args:
        window (int): The sliding window size to calculate the running average
            of the time consuming.

    Example:
        >>> from mmpose.utils import StopWatch
        >>> import time
        >>> stop_watch = StopWatch(window=10)
        >>> with stop_watch.timeit('total'):
        ...     time.sleep(0.1)
        ...     # 'timeit' supports nested use
        ...     with stop_watch.timeit('phase1'):
        ...         time.sleep(0.1)
        ...     with stop_watch.timeit('phase2'):
        ...         time.sleep(0.2)
        ...     time.sleep(0.2)
        >>> report = stop_watch.report()
    """

    def __init__(self, window=1):
        self.window = window
        self._record = self._new_record()
        # Stack of (timer_name, Timer) pairs for the currently open
        # ``timeit`` blocks; the innermost block is on top.
        self._timer_stack = []

    def _new_record(self):
        """Create an empty mapping from timer name to RunningAverage."""
        return defaultdict(partial(RunningAverage, window=self.window))

    @contextmanager
    def timeit(self, timer_name='_FPS_'):
        """Timing a code snippet with an assigned name.

        Args:
            timer_name (str): The unique name of the interested code snippet to
                handle multiple timers and generate reports. Note that '_FPS_'
                is a special key that the measurement will be in `fps` instead
                of `millisecond`. Also see `report` and `report_strings`.
                Default: '_FPS_'.

        Note:
            This function should always be used in a `with` statement, as shown
            in the example.
        """
        self._timer_stack.append((timer_name, Timer()))
        try:
            yield
        finally:
            # Nested ``timeit`` blocks exit in LIFO order, so the top of the
            # stack is the timer that was pushed on entry to this block.
            name, timer = self._timer_stack.pop()
            self._record[name].update(timer.since_start())

    def report(self, key=None):
        """Report timing information.

        Args:
            key (str, optional): If given, return only the value recorded for
                this timer name. Default: None.

        Returns:
            dict | float: If ``key`` is None, a dict whose key is the timer
            name and whose value is the corresponding average time consuming
            (the raw average multiplied by 1000, i.e. milliseconds assuming
            the timer reports seconds). The '_FPS_' entry, if present, is
            converted to frames per second. Otherwise the single value for
            ``key``.

        Raises:
            KeyError: If ``key`` is given but no timer with that name has
                been recorded.
        """
        result = {
            name: r.average() * 1000.
            for name, r in self._record.items()
        }
        if '_FPS_' in result:
            result['_FPS_'] = 1000. / result.pop('_FPS_')
        if key is None:
            return result
        return result[key]

    def report_strings(self):
        """Report timing information in texture strings.

        Returns:
            list(str): Each element is the information string of a timed \
                event, in format of '{timer_name}: {time_in_ms}'. \
                Specially, if timer_name is '_FPS_', the result is \
                converted to fps, reported once as 'FPS: {fps}' and placed \
                first.
        """
        result = self.report()
        strings = []
        if '_FPS_' in result:
            strings.append(f'FPS: {result["_FPS_"]:>5.1f}')
        # '_FPS_' has already been reported above; do not list it twice.
        strings += [
            f'{name}: {val:>3.0f}' for name, val in result.items()
            if name != '_FPS_'
        ]
        return strings

    def reset(self):
        """Discard all recorded timing statistics.

        Timers of ``timeit`` blocks that are still open are kept on the stack
        on purpose: each block pops its own timer when it exits, so clearing
        the stack here would make that pop fail. Such a block records its
        measurement into the fresh statistics.
        """
        # Must use the same factory as ``__init__``: ``timeit`` calls
        # ``.update`` on the record entries, which a plain list lacks.
        self._record = self._new_record()