File size: 13,249 Bytes
edcf5ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
"""
Schedulers   Script  ver: Feb 15th 17:00

patch_scheduler arranges the puzzle patch size for multi-scale learning;
ratio_scheduler drives the fix-position ratio by loss and epoch

ref
lr_scheduler from MAE code.
https://github.com/facebookresearch/mae
"""

import math
import random


def factor(num):
    """
    Collect every positive divisor of ``num``.

    Candidates from 1 up to ``int(sqrt(num))`` are tried. Each candidate that
    divides ``num`` contributes itself and its co-divisor (the co-divisor is
    skipped when it equals the candidate, i.e. for a perfect square root).

    The result is in discovery order (small divisor, then its partner), so it
    is NOT sorted.
    """
    upper = int(math.sqrt(num))
    divisors = []
    for candidate in range(1, upper + 1):
        if num % candidate != 0:
            continue  # not a divisor
        divisors.append(candidate)
        partner = int(num / candidate)
        if partner != candidate:
            divisors.append(partner)
    return divisors


def defactor(num_list, basic_num):  # check multiples
    """
    Keep only the entries of ``num_list`` that are whole multiples of ``basic_num``.

    Returns a new list sorted in ascending order; ``num_list`` itself is not modified.
    """
    return sorted(n for n in num_list if (n // basic_num) * basic_num == n)


def adjust_learning_rate(optimizer, epoch, args):
    """
    Apply a linear warm-up followed by a half-cycle cosine decay to the optimizer.

    ``epoch`` may be fractional for finer-grained updates,
    e.g. ``data_iter_step / len(data_loader) + epoch``.

    Reads ``args.lr``, ``args.min_lr``, ``args.warmup_epochs`` and ``args.epochs``.
    Every entry of ``optimizer.param_groups`` gets its ``"lr"`` overwritten; a group
    carrying an ``"lr_scale"`` key receives the learning rate multiplied by that scale.

    Returns the (unscaled) learning rate computed for this call.
    """
    warmup = args.warmup_epochs

    if epoch < warmup:
        # warm-up: ramp linearly from zero up to the configured lr
        lr = args.lr * epoch / warmup
    else:
        # after warm-up: cosine decay from args.lr down to args.min_lr
        cosine = math.cos(math.pi * (epoch - warmup) / (args.epochs - warmup))
        lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * (1. + cosine)

    # push the new value into every parameter group of the optimizer
    for group in optimizer.param_groups:
        group["lr"] = lr * group["lr_scale"] if "lr_scale" in group else lr
    return lr


class patch_scheduler:
    """
    Pick the puzzle patch size for an epoch, driven by training progress and (optionally) loss.

    The list of legal patch sizes is built automatically: every divisor of
    ``edge_size`` that is a multiple of ``basic_patch``, sorted ascending, with
    the largest candidate dropped when more than one size is available.

    Supported ``strategy`` values:
        'linear'    walk through the patch list from small to big as training progresses
        'reverse'   same walk, but over the list sorted from big (easy) to small (complex)
        'loop'      cycle through the list, staying ``int(threshold)`` epochs on each size
        'random'    draw a random size from the list on every call (no warm-up handling)
        'loss_back' progress-based index on the big-to-small list; one step ahead when the
                    loss is under the threshold, one step back otherwise
        'loss_hold' like 'loss_back', but stays on the progress-based index instead of
                    stepping back
        anything else (None, 'fixed', ...) returns ``fix_patch_size``, or the first
                    entry of the patch list when ``fix_patch_size`` is not set

    All strategies except 'random' and the fixed fallback return ``warmup_patch_size``
    while ``epoch < warmup_epochs``.
    """

    def __init__(self, total_epoches=200, warmup_epochs=20, edge_size=384, basic_patch=16, strategy=None,
                 threshold=3.0, reducing_factor=0.933, fix_patch_size=None, patch_size_jump=None,
                 warmup_patch_size=32):
        """
        :param total_epoches: total number of training epochs
        :param warmup_epochs: number of warm-up epochs at the start of training
        :param edge_size: image edge length; patch sizes are taken from its divisors
        :param basic_patch: every legal patch size is a multiple of this value
        :param strategy: scheduling strategy, see the class docstring
        :param threshold: loss threshold of the loss-driven strategies; the 'loop'
            strategy re-uses ``int(threshold)`` as the number of epochs per patch size
        :param reducing_factor: the threshold is multiplied by this factor each time
            the loss is found below it
        :param fix_patch_size: patch size returned by the fixed fallback
        :param patch_size_jump: 'odd' or 'even' keeps every second patch size
            (both keep the smallest one); any other value keeps all sizes
        :param warmup_patch_size: patch size returned during warm-up (default 32)
        """
        super().__init__()

        self.strategy = strategy

        self.total_epoches = total_epoches
        self.warmup_epochs = warmup_epochs

        self.threshold = threshold
        self.reducing_factor = reducing_factor
        self.fix_patch_size = fix_patch_size
        self.warmup_patch_size = warmup_patch_size

        # automatically build the legal patch list, from small to big size
        patch_list = defactor(factor(edge_size), basic_patch)

        # drop the largest candidate: no need for a patch at whole-figure level
        if len(patch_list) > 1:
            patch_list = patch_list[:-1]

        # optionally keep every second size; both variants keep the smallest patch size
        if patch_size_jump == 'odd':  # 384: [16, 48, 96, 192]
            patch_list = patch_list[0::2]
        elif patch_size_jump == 'even':  # 384: [16, 32, 64, 128]
            patch_list = [patch_list[0]] + patch_list[1::2]

        # these strategies start from big (easy) and move to small (complex)
        if self.strategy in ('reverse', 'loss_back', 'loss_hold'):
            patch_list.sort(reverse=True)

        self.patch_list = patch_list

        if self.strategy is None or self.strategy == 'fixed':
            puzzle_patch_size = self.fix_patch_size or self.patch_list[0]
            print('patch_list:', puzzle_patch_size)
        else:
            print('patch_list:', self.patch_list)

    def _progress_index(self, epoch, shift=0):
        """Map the post-warm-up training progress (plus ``shift``) to a clamped patch_list index."""
        span = self.total_epoches - self.warmup_epochs
        if span <= 0:
            # no epochs left after warm-up: treat training as finished instead of dividing by zero
            progress = 1.0
        else:
            progress = (epoch - self.warmup_epochs) / span
        index = int(progress * len(self.patch_list)) + shift
        return min(max(index, 0), len(self.patch_list) - 1)

    def __call__(self, epoch, loss=0.0):
        """
        Return the patch size to use at ``epoch``.

        :param epoch: current epoch
        :param loss: latest loss for the loss-driven strategies; the default 0.0
            means "no loss available" and selects the purely progress-based size
        """
        strategy = self.strategy

        if strategy == 'random':  # random size strategy, also during warm-up
            return random.choice(self.patch_list)

        if strategy not in ('linear', 'reverse', 'loop', 'loss_back', 'loss_hold'):
            # strategy is None, 'fixed', 'ratio-decay', ...
            return self.fix_patch_size or self.patch_list[0]  # basic_patch

        if epoch < self.warmup_epochs:
            return self.warmup_patch_size  # fixed size during warm-up

        if strategy == 'loop':
            # cycle through the patch sizes, switching after every [group_size] epochs;
            # at least 1 so a threshold below 1 cannot cause a modulo by zero
            group_size = max(int(self.threshold), 1)
            group_idx = (epoch - self.warmup_epochs) % (len(self.patch_list) * group_size)
            return self.patch_list[min(int(group_idx / group_size), len(self.patch_list) - 1)]

        shift = 0  # 'linear' / 'reverse', or no loss available
        if strategy in ('loss_back', 'loss_hold') and loss != 0.0:
            if loss < self.threshold:
                # loss is good enough: move one step ahead and tighten the threshold
                shift = 1
                self.threshold *= self.reducing_factor
            elif strategy == 'loss_back':
                # loss too high: fall back one step ('loss_hold' keeps the current step)
                shift = -1

        return self.patch_list[self._progress_index(epoch, shift)]


class ratio_scheduler:
    """
    Drive the fix-position ratio by loss and epoch.

    After warm-up the ratio decays linearly with the remaining share of training,
    starting from ``max_ratio = min(3 * basic_ratio, upper_limit)`` and never dropping
    below ``min_ratio = max(basic_ratio * ratio_floor_factor, lower_limit)``.

    Supported ``strategy`` values:
        'ratio-decay' / 'decay'  plain linear decay
        'loss_back'   decay scaled by 0.9 when the loss is under the threshold,
                      by 1.1 otherwise
        'loss_hold'   decay scaled by 0.9 when the loss is under the threshold,
                      unscaled otherwise
        anything else returns ``fix_position_ratio`` when it is set, else ``basic_ratio``

    The decaying strategies return ``basic_ratio`` while ``epoch < warmup_epochs``.
    """
    def __init__(self, total_epoches=200, warmup_epochs=20, basic_ratio=0.25, strategy=None, fix_position_ratio=None,
                 threshold=4.0, loss_reducing_factor=0.933, ratio_floor_factor=0.5, upper_limit=0.9, lower_limit=0.2):
        """
        :param total_epoches: total number of training epochs
        :param warmup_epochs: number of warm-up epochs at the start of training
        :param basic_ratio: ratio used during warm-up and as the base of the decay limits
        :param strategy: scheduling strategy, see the class docstring
        :param fix_position_ratio: ratio returned by the fixed fallback (None: use basic_ratio)
        :param threshold: loss threshold of the loss-driven strategies
        :param loss_reducing_factor: the threshold is multiplied by this factor each time
            the loss is found below it
        :param ratio_floor_factor: min_ratio is at least basic_ratio * ratio_floor_factor
        :param upper_limit: cap applied to max_ratio
        :param lower_limit: floor applied to min_ratio
        """
        # fixme basic_ratio and fix_position_ratio(when stage is fixed) is a bit conflicting, not good enough
        super().__init__()
        self.strategy = strategy

        self.total_epoches = total_epoches
        self.warmup_epochs = warmup_epochs

        self.basic_ratio = basic_ratio

        self.threshold = threshold
        self.loss_reducing_factor = loss_reducing_factor

        self.fix_position_ratio = fix_position_ratio

        self.upper_limit = upper_limit
        self.lower_limit = lower_limit
        self.ratio_floor_factor = ratio_floor_factor

    def _decayed_ratio(self, epoch, scale=1.0):
        """Return the linearly decayed ratio at ``epoch``, multiplied by ``scale`` and clamped."""
        max_ratio = min(3 * self.basic_ratio, self.upper_limit)
        min_ratio = max(self.basic_ratio * self.ratio_floor_factor, self.lower_limit)

        span = self.total_epoches - self.warmup_epochs
        if span <= 0:
            # no epochs left after warm-up: treat training as finished instead of dividing by zero
            remaining = 0.0
        else:
            remaining = (span - (epoch - self.warmup_epochs)) / span

        return min(max(remaining * max_ratio * scale, min_ratio), max_ratio)

    def __call__(self, epoch, loss=0.0):
        """
        Return the fix-position ratio to use at ``epoch``.

        :param epoch: current epoch
        :param loss: latest loss for the loss-driven strategies; the default 0.0
            means "no loss available" and selects the unscaled decay
        """
        strategy = self.strategy

        if strategy not in ('ratio-decay', 'decay', 'loss_back', 'loss_hold'):
            # fixed ratio; `is not None` so that an explicit 0.0 is honoured
            if self.fix_position_ratio is not None:
                return self.fix_position_ratio
            return self.basic_ratio

        if epoch < self.warmup_epochs:
            return self.basic_ratio  # fixed ratio during warm-up

        scale = 1.0  # plain decay, or no loss available
        if strategy in ('loss_back', 'loss_hold') and loss != 0.0:
            if loss < self.threshold:
                # loss is good enough: lower the ratio and tighten the threshold
                scale = 0.9
                self.threshold *= self.loss_reducing_factor
            elif strategy == 'loss_back':
                # loss too high: raise the ratio ('loss_hold' keeps the plain decay)
                scale = 1.1

        return self._decayed_ratio(epoch, scale)


'''
scheduler = ratio_scheduler(strategy='decay')
epoch = 102
fix_position_ratio = scheduler(epoch)
print(fix_position_ratio)
'''