File size: 2,434 Bytes
1fd4e9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import time


class TimeEstimator:
    """Estimate the per-iteration time and the remaining time of an iterative job.

    Call ``update()`` once every ``step_size`` iterations. Each call measures the
    time elapsed since the previous call (or since construction, for the first
    call) and records it in:

    * ``avg_time_window`` -- the raw interval durations, averaged and cleared by
      ``get_and_reset_avg_time()``;
    * ``exp_avg_time`` -- an exponential moving average (EMA) of the interval
      duration, used by ``get_est_remaining()``.

    The first interval starts at construction time and may not be an accurate
    ``step_size``-iteration interval. Its value only seeds the EMA and is
    overwritten (not blended) by the second measurement. It is still included
    in the first window average.

    Durations are measured with ``time.perf_counter()`` (a monotonic clock), so
    system clock adjustments cannot produce negative or inflated intervals.

    Args:
        total_iter: Total number of iterations in the job.
        step_size: Number of iterations between consecutive ``update()`` calls.
        ema_alpha: Weight of the previous EMA value when blending in a new
            measurement; the new measurement gets weight ``1 - ema_alpha``.
    """

    def __init__(self, total_iter: int, step_size: int, ema_alpha: float = 0.7):
        self.avg_time_window = []  # raw interval durations since the last reset
        self.exp_avg_time = None  # EMA of the interval duration; None until the first update()
        self.alpha = ema_alpha  # weight of the previous EMA value

        # Start of the current interval; the first interval therefore covers
        # whatever happens between construction and the first update().
        self.last_time = time.perf_counter()
        self.total_iter = total_iter
        self.step_size = step_size

        # True until the first (possibly inaccurate) measurement has been replaced.
        self._buffering_exp = True

    def update(self):
        """Record the time elapsed since the previous call (or construction).

        Intended to be called at a fixed interval of ``step_size`` iterations;
        it does not have to be called every iteration.
        """
        curr_time = time.perf_counter()
        time_per_step = curr_time - self.last_time
        self.last_time = curr_time

        self.avg_time_window.append(time_per_step)

        if self._buffering_exp:
            if self.exp_avg_time is not None:
                # Second measurement: overwrite the first one instead of
                # blending it in, so it does not pollute the EMA.
                self._buffering_exp = False
            self.exp_avg_time = time_per_step
        else:
            self.exp_avg_time = self.alpha * self.exp_avg_time + (1 - self.alpha) * time_per_step

    def get_est_remaining(self, it: int):
        """Return the estimated time in seconds to finish, given current iteration ``it``.

        Returns 0 before the first ``update()``, and 0 once ``it`` has reached
        or passed ``total_iter``.
        """
        if self.exp_avg_time is None:
            return 0

        # Clamp so an iteration count past the end never yields a negative ETA.
        remaining_iter = max(self.total_iter - it, 0)
        return remaining_iter * self.exp_avg_time / self.step_size

    def get_and_reset_avg_time(self):
        """Return the mean per-iteration time over the current window, then clear it.

        Returns 0 if nothing was recorded since the last reset, instead of
        raising ``ZeroDivisionError``.
        """
        window = self.avg_time_window
        self.avg_time_window = []
        if not window:
            return 0
        return sum(window) / len(window) / self.step_size


class PartialTimeEstimator(TimeEstimator):
    """Time estimator for work whose timed intervals are not back-to-back.

    Instead of ``update()``, bracket each measured interval with ``start()`` and
    ``end()``. Time spent between an ``end()`` and the next ``start()`` is not
    counted. Each ``end()`` appends the interval duration to ``avg_time_window``
    and updates ``exp_avg_time``. As in ``TimeEstimator.update()``, the first
    measurement only seeds the EMA and is overwritten by the second.
    """

    def __init__(self, total_iter: int, step_size: int, ema_alpha: float = 0.7):
        super().__init__(total_iter, step_size, ema_alpha)
        # No interval is open until start() is called. This makes end() without
        # a preceding start() fail loudly instead of silently timing from
        # construction.
        self.last_time = None

    def update(self):
        """Not supported; use ``start()`` / ``end()`` instead."""
        raise RuntimeError('Please use start() and end() for PartialTimeEstimator')

    def start(self):
        """Mark the beginning of a timed interval."""
        self.last_time = time.perf_counter()

    def end(self):
        """Close the interval opened by ``start()`` and record its duration.

        Raises:
            RuntimeError: If ``start()`` was not called since construction or
                since the previous ``end()``.
        """
        if self.last_time is None:
            # Explicit raise rather than ``assert`` so the check survives ``python -O``.
            raise RuntimeError('Please call start() before calling end()')
        curr_time = time.perf_counter()
        time_per_step = curr_time - self.last_time
        self.last_time = None  # close the interval; the next end() needs a new start()

        self.avg_time_window.append(time_per_step)

        if self._buffering_exp:
            if self.exp_avg_time is not None:
                # Second measurement: overwrite the first one instead of
                # blending it in, so it does not pollute the EMA.
                self._buffering_exp = False
            self.exp_avg_time = time_per_step
        else:
            self.exp_avg_time = self.alpha * self.exp_avg_time + (1 - self.alpha) * time_per_step