# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.

# copied from https://github.com/facebookresearch/dino/blob/main/main_dino.py
# reworked/refactored some parts to make it run in Megatron.
import math
import apex
import einops
import torch
import numpy as np
import torch.nn.functional as F
from torch.nn.init import trunc_normal_
from megatron import get_args, print_rank_0
from megatron.model.utils import get_linear_layer
from megatron.model.vision.vit_backbone import VitBackbone
from megatron.model.module import MegatronModule
from megatron.model.vision.mit_backbone import mit_b5_avg
from megatron.model.vision.esvit_swin_backbone import get_swin


class DINOLoss(torch.nn.Module):
    def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp,
                 warmup_teacher_temp_epochs, nepochs, student_temp=0.1,
                 center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.ncrops = ncrops
        self.register_buffer("center", torch.zeros(1, out_dim))
        # we apply a warm-up for the teacher temperature because
        # too high a temperature makes training unstable at the beginning
        self.teacher_temp_schedule = np.concatenate((
            np.linspace(warmup_teacher_temp,
                        teacher_temp, warmup_teacher_temp_epochs),
            np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
        ))
        self.teacher_temp = teacher_temp

    def forward(self, student_output, teacher_output, iteration):
        """
        Cross-entropy between softmax outputs of the teacher
        and student network.
        """
        args = get_args()
        student_out = student_output / self.student_temp
        student_out = student_out.chunk(self.ncrops)

        epoch = iteration // args.iter_per_epoch

        # teacher centering and sharpening
        temp = self.teacher_temp_schedule[epoch]
        teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)

        teacher_out = teacher_out.detach().chunk(2)

        total_loss = 0
        n_loss_terms = 0
        for iq, q in enumerate(teacher_out):
            for v in range(len(student_out)):
                if v == iq:
                    # we skip cases where student and teacher operate on the same view
                    continue
                loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
                total_loss += loss.mean()
                n_loss_terms += 1
        total_loss /= n_loss_terms
        self.update_center(teacher_output)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        """
        Update center used for teacher output.
        """
        batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
        torch.distributed.all_reduce(batch_center)
        batch_center = batch_center / (len(teacher_output) * torch.distributed.get_world_size())
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
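
# Illustrative usage sketch (not part of the original file): how DINOLoss is
# typically invoked. The temperatures and crop counts below are assumptions
# for the example only; the real values come from the Megatron args.
#
#   dino_loss = DINOLoss(out_dim=65536, ncrops=10,
#                        warmup_teacher_temp=0.04, teacher_temp=0.07,
#                        warmup_teacher_temp_epochs=30, nepochs=300)
#   # student_output: (ncrops * batch, out_dim); teacher_output: (2 * batch, out_dim)
#   loss = dino_loss(student_output, teacher_output, iteration)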

class DINOHead(torch.nn.Module):
    def __init__(self, in_dim, out_dim, norm_last_layer=True, nlayers=3):
        super().__init__()
        args = get_args()
        hidden_dim = args.dino_head_hidden_size
        bottleneck_dim = args.dino_bottleneck_size
        nlayers = max(nlayers, 1)
        if nlayers == 1:
            self.mlp = torch.nn.Linear(in_dim, bottleneck_dim)
        else:
            layers = [torch.nn.Linear(in_dim, hidden_dim)]
            layers.append(torch.nn.GELU())
            for _ in range(nlayers - 2):
                layers.append(torch.nn.Linear(hidden_dim, hidden_dim))
                layers.append(torch.nn.GELU())
            layers.append(torch.nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = torch.nn.Sequential(*layers)
        self.apply(self._init_weights)
        self.last_layer = torch.nn.utils.weight_norm(torch.nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        if isinstance(m, torch.nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        x = torch.nn.functional.normalize(x, dim=-1, p=2)
        x = self.last_layer(x)
        return x
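
# Shape sketch for DINOHead (illustrative, not part of the original file): for
# an input x of shape (N, in_dim),
#   mlp -> (N, bottleneck_dim) -> L2-normalize -> weight-normed linear
#   -> (N, out_dim), e.g. out_dim = 65536 as used by DINOPretrainModel below.
# hidden_dim and bottleneck_dim are read from the Megatron args at runtime.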


class MultiCropWrapper(MegatronModule):

    """
    Perform forward pass separately on each resolution input.
    The inputs corresponding to a single resolution are grouped together and a
    single forward pass is run on them, so the number of backbone forward
    passes equals the number of distinct resolutions used. The output features
    of all crops are then concatenated and the head is run once on the
    concatenated features.
    """
    def __init__(self, backbone, head):
        super(MultiCropWrapper, self).__init__()
        # disable layers dedicated to ImageNet labels classification
        #backbone.fc, backbone.head = torch.nn.Identity(), torch.nn.Identity()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        # convert to list
        if not isinstance(x, list):
            x = [x]
        idx_crops = torch.cumsum(torch.unique_consecutive(
            torch.tensor([inp.shape[-1] for inp in x]),
            return_counts=True,
        )[1], 0)

        start_idx = 0
        for end_idx in idx_crops:
            _out = self.backbone(torch.cat(x[start_idx: end_idx]))
            if start_idx == 0:
                output = _out
            else:
                output = torch.cat((output, _out))
            start_idx = end_idx
        # Run the head forward on the concatenated features.
        if self.training:
            return self.head(output)
        else:
            return output
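
# Illustrative sketch (not part of the original file): with the usual DINO crop
# setup of 2 global crops plus n_local local crops, e.g. input side lengths
# [224, 224, 96, ..., 96] (the concrete sizes are an assumption here),
# unique_consecutive(...) returns counts [2, n_local], so idx_crops == [2, 2 + n_local]
# and the backbone runs once per distinct resolution before the head is applied
# to the concatenated features.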


def cosine_scheduler(base_value, final_value, epochs, niter_per_ep,
                     warmup_epochs=0, start_warmup_value=0):
    warmup_schedule = np.array([])
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_epochs > 0:
        warmup_schedule = \
                np.linspace(start_warmup_value, base_value, warmup_iters)

    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = final_value + 0.5 * (base_value - final_value) \
        * (1 + np.cos(np.pi * iters / len(iters)))

    schedule = np.concatenate((warmup_schedule, schedule))
    assert len(schedule) == epochs * niter_per_ep
    return schedule
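
# Illustrative note (not part of the original file): the returned array has
# exactly epochs * niter_per_ep entries: a linear warm-up over
# warmup_epochs * niter_per_ep iterations followed by a cosine interpolation
# from base_value to final_value. DINOPretrainModel below uses it roughly as
#   cosine_scheduler(0.996, 1.0, train_iters // iter_per_epoch, iter_per_epoch)
# to ramp the teacher EMA momentum towards 1 over training.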


def get_student_backbone_and_num_features(pre_process=True, post_process=True):
    args = get_args()

    if args.vision_backbone_type == 'vit':
        student = VitBackbone(pre_process=pre_process,
                              post_process=post_process,
                              drop_path_rate=0.1,
                              single_token_output=True)
        num_features = args.hidden_size
    elif args.vision_backbone_type == 'mit':
        student = mit_b5_avg(drop_path_rate=0.1)
        num_features = 512
    elif args.vision_backbone_type == 'swin':
        student = get_swin()
        num_features = student.num_features
    else:
        raise Exception('{} vision backbone is not supported.'.format(
                              args.vision_backbone_type))
 
    return student, num_features

def get_teacher_backbone_and_num_features(pre_process=True, post_process=True):
    args = get_args()

    if args.vision_backbone_type == 'vit':
        teacher = VitBackbone(pre_process=pre_process,
                              post_process=post_process,
                              single_token_output=True)
        num_features = args.hidden_size
    elif args.vision_backbone_type == 'mit':
        teacher = mit_b5_avg(drop_path_rate=0.0)
        num_features = 512
    elif args.vision_backbone_type == 'swin':
        teacher = get_swin(is_teacher=True)
        num_features = teacher.num_features
    else:
        raise Exception('{} vision backbone is not supported.'.format(
                              args.vision_backbone_type))
    return teacher, num_features


class DINOPretrainModel(MegatronModule):
    def __init__(self, pre_process=True, post_process=True):
        super(DINOPretrainModel, self).__init__()
        args = get_args()
        self.out_dim = 65536

        self.dino_loss = DINOLoss(
            self.out_dim,
            args.dino_local_crops_number + 2,
            args.dino_warmup_teacher_temp,
            args.dino_teacher_temp,
            args.dino_warmup_teacher_temp_epochs,
            300,
        )

        self.pre_process = pre_process
        self.post_process = post_process
        self.momentum_teacher = 0.996

        student_backbone, num_features = \
            get_student_backbone_and_num_features(pre_process, post_process)

        self.student = MultiCropWrapper(
            student_backbone,
            DINOHead(num_features, self.out_dim,
                     norm_last_layer=args.dino_norm_last_layer)
        )

        self.momentum_schedule = cosine_scheduler(
            self.momentum_teacher, 1,
            args.train_iters // args.iter_per_epoch,
            args.iter_per_epoch
        )

        teacher_backbone, num_features = \
            get_teacher_backbone_and_num_features(pre_process, post_process)
        self.teacher = MultiCropWrapper(
            teacher_backbone,
            DINOHead(num_features, self.out_dim)
        )
        self.teacher.load_state_dict(self.student.state_dict())

        # the teacher receives no gradient updates; it is only updated via the
        # EMA of the student weights in update_momentum()
        for p in self.teacher.parameters():
            p.requires_grad = False

    def set_input_tensor(self, tensor):
        pass

    def forward(self, input):
        student_output = None
        if self.training:
            student_output = self.student(input)
            teacher_output = self.teacher(input[:2])
        else:
            teacher_output = self.teacher(input)
        return student_output, teacher_output

    def cancel_gradients_last_layer(self, iteration):
        args = get_args()
        epoch = iteration // args.iter_per_epoch
        if epoch < args.dino_freeze_last_layer:
            for n, p in self.student.named_parameters():
                if "last_layer" in n:
                    p.grad = None

    def update_momentum(self, iteration):
        with torch.no_grad():
            m = self.momentum_schedule[iteration]
            for param_q, param_k in zip(self.student.parameters(), self.teacher.parameters()):
                param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
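
# Illustrative sketch (not part of the original file): the call order a
# training step is expected to follow, assuming a Megatron-style loop that
# provides `model`, `optimizer` and the global `iteration`; the actual
# integration lives in the pretraining script, not in this module.
#
#   student_output, teacher_output = model(images)
#   loss = model.dino_loss(student_output, teacher_output, iteration)
#   loss.backward()
#   model.cancel_gradients_last_layer(iteration)
#   optimizer.step()
#   model.update_momentum(iteration)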