File size: 3,742 Bytes
2de1f98
 
 
 
 
 
010a8bc
2de1f98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
010a8bc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# Copyright (c) OpenMMLab. All rights reserved.
from collections import defaultdict
from contextlib import contextmanager
from functools import partial

import numpy as np
from mmengine import Timer


class RunningAverage:
    r"""Keep the most recent samples and report their mean.

    Only the last ``window`` samples passed to :meth:`update` are retained;
    older samples are discarded.

    Args:
        window (int): The number of most recent samples kept for averaging.
            Default: 1.
    """

    def __init__(self, window: int = 1):
        self._data = []
        self.window = window

    def update(self, value):
        """Append a new sample and trim the history to the window size."""
        samples = self._data
        samples.append(value)
        # Slicing from the tail keeps at most `window` newest samples.
        self._data = samples[-self.window:]

    def average(self):
        """Return the mean of the samples currently inside the window."""
        window_samples = self._data
        return np.mean(window_samples)


class StopWatch:
    r"""A helper class to measure FPS and the time consumed by each phase
    in a video processing loop or similar scenarios.

    Args:
        window (int): The sliding window size used to calculate the running
            average of the recorded time. Default: 1.

    Example:
        >>> from mmpose.utils import StopWatch
        >>> import time
        >>> stop_watch = StopWatch(window=10)
        >>> with stop_watch.timeit('total'):
        >>>     time.sleep(0.1)
        >>>     # 'timeit' support nested use
        >>>     with stop_watch.timeit('phase1'):
        >>>         time.sleep(0.1)
        >>>     with stop_watch.timeit('phase2'):
        >>>         time.sleep(0.2)
        >>>     time.sleep(0.2)
        >>> report = stop_watch.report()
    """

    def __init__(self, window=1):
        self.window = window
        # Every timer name lazily gets its own sliding-window average.
        self._record = defaultdict(partial(RunningAverage, window=self.window))
        # (timer_name, Timer) pairs of the `timeit` contexts that are
        # currently open; using a stack allows the contexts to be nested.
        self._timer_stack = []

    @contextmanager
    def timeit(self, timer_name='_FPS_'):
        """Time a code snippet under an assigned name.

        Args:
            timer_name (str): The unique name of the interested code snippet to
                handle multiple timers and generate reports. Note that '_FPS_'
                is a special key that the measurement will be in `fps` instead
                of `millisecond`. Also see `report` and `report_strings`.
                Default: '_FPS_'.
        Note:
            This function should always be used in a `with` statement, as shown
            in the example.
        """
        # NOTE(review): Timer is presumably started on construction and
        # `since_start()` presumably returns elapsed seconds (`report`
        # scales the values by 1000 to get milliseconds) -- confirm against
        # the mmengine Timer documentation.
        self._timer_stack.append((timer_name, Timer()))
        try:
            yield
        finally:
            # Record the measurement even if the timed snippet raised.
            timer_name, timer = self._timer_stack.pop()
            self._record[timer_name].update(timer.since_start())

    def report(self, key=None):
        """Report timing information.

        Args:
            key (str, optional): If given, only the value of this timer is
                returned. Default: None.

        Returns:
            dict: The key is the timer name and the value is the \
                corresponding windowed average scaled by 1000; the special \
                '_FPS_' entry is converted to 1000 divided by that value. \
                If ``key`` is given, only the value stored under ``key`` \
                is returned.

        Raises:
            KeyError: If ``key`` is given but no such timer was recorded.
        """
        result = {
            name: r.average() * 1000.
            for name, r in self._record.items()
        }

        if '_FPS_' in result:
            # Popping and re-inserting moves '_FPS_' to the end of the dict.
            result['_FPS_'] = 1000. / result.pop('_FPS_')

        if key is None:
            return result
        return result[key]

    def report_strings(self):
        """Report timing information in text strings.

        Returns:
            list(str): Each element is the information string of a timed \
                event, in format of '{timer_name}: {time_in_ms}'. \
                Specially, if a '_FPS_' timer exists, an extra leading \
                'FPS: {fps}' string is added; the '_FPS_' entry itself \
                (already converted to fps) is also kept among the \
                per-timer strings.
        """
        result = self.report()
        strings = []
        if '_FPS_' in result:
            strings.append(f'FPS: {result["_FPS_"]:>5.1f}')
        strings += [f'{name}: {val:>3.0f}' for name, val in result.items()]
        return strings

    def reset(self):
        """Discard all recorded measurements and all pending timers.

        Both containers are cleared in place, so ``_record`` keeps the
        default factory it was created with and the stop watch stays usable
        after the reset.
        """
        self._record.clear()
        self._timer_stack.clear()