# File size: 4,590 Bytes
# 23bd7af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import torch.nn.functional as F
import torch
from megatron import print_rank_0, get_args, mpu
from megatron.data.vit_dataset import ClassificationTransform
from megatron.data.image_folder import ImageFolder

# Module-level cache written by compute_feature_bank() and read by
# get_feature_bank(). Once populated it holds the tuple
# (feature_banks [D, N], feature_labels [N], number_of_classes).
_FEATURE_BANK = None


def build_data_loader(dataset, drop_last=True, shuffle=False,
                      micro_batch_size=16):
    """Build a data loader that shards ``dataset`` across the data-parallel group.

    Note that batch-size is the local (per GPU) batch-size.

    Args:
        dataset: map-style dataset to iterate.
        drop_last: forwarded to ``DistributedSampler``. With ``True`` the
            sampler drops the tail so that every rank sees the same number of
            samples. With ``False`` it pads by repeating samples. The
            ``DataLoader`` itself receives ``not drop_last``. With the default
            this means the final partial batch of each rank's (already evenly
            sized) shard is kept rather than discarded. This inversion is
            preserved as-is.
        shuffle: forwarded to ``DistributedSampler``.
        micro_batch_size: local (per GPU) batch size. It defaults to 16, the
            value that was previously hard-coded here.

    Returns:
        A ``torch.utils.data.DataLoader`` over this rank's shard of ``dataset``.
    """
    args = get_args()
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()

    # Sampler: each data-parallel rank gets its own disjoint shard.
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank,
        drop_last=drop_last, shuffle=shuffle
    )

    # Data loader. Shuffling is delegated to the sampler; DataLoader forbids
    # passing shuffle=True together with a sampler.
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=micro_batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=args.num_workers,
        drop_last=not drop_last,
        pin_memory=True,
    )


def compute_feature_bank(model):
    """Build the global k-NN feature bank from the training image folder.

    The images under ``args.data_path[0]`` are loaded with the evaluation
    transform (``train=False``) and passed through ``model[0]``. The teacher
    features it returns are L2-normalized. Features and labels are then
    all-gathered across the data-parallel group. The result is stored in the
    module-level ``_FEATURE_BANK`` as ``(feature_banks, feature_labels,
    classes)``:

    * ``feature_banks``: ``[D, N]`` tensor, transposed for ``knn_predict``;
    * ``feature_labels``: ``[N]`` tensor of labels;
    * ``classes``: number of classes in the image folder.

    Args:
        model: list of model chunks. Every chunk is switched to eval mode for
            the pass. Only ``model[0]`` is called, and it must return a pair
            ``(student_feature, teacher_feature)``.

    Each chunk's previous train/eval mode is restored afterwards, even if the
    forward pass raises.
    """
    args = get_args()
    global _FEATURE_BANK
    feature_bank = []
    feature_label = []

    train_ds = ImageFolder(
        root=args.data_path[0],
        transform=ClassificationTransform((args.img_h, args.img_w), train=False),
        data_per_class_fraction=1.0
    )
    classes = len(train_ds.classes)
    dataloader = build_data_loader(train_ds)

    # Remember each chunk's mode so it is restored rather than forced into
    # train mode (and not left in eval mode if an exception escapes).
    was_training = [m.training for m in model]
    for m in model:
        m.eval()
    try:
        with torch.no_grad():
            for batch in dataloader:
                images = batch[0].cuda().contiguous()
                labels = batch[1].cuda().contiguous()
                # Only the teacher output contributes to the bank.
                _, teacher_feature = model[0](images)
                feature = F.normalize(teacher_feature.float(), dim=1)
                feature_bank.append(feature)
                feature_label.append(labels)
    finally:
        for m, mode in zip(model, was_training):
            m.train(mode)

    # Local shard: [N', D] features and [N'] labels.
    feature_bank = torch.cat(feature_bank, dim=0).contiguous()
    feature_label = torch.cat(feature_label, dim=0).contiguous()

    group = mpu.get_data_parallel_group()
    world_size = mpu.get_data_parallel_world_size()

    # all_gather needs identically shaped tensors on every rank. The
    # DistributedSampler used by build_data_loader gives each rank the same
    # number of samples.
    feature_banks = [torch.zeros_like(feature_bank) for _ in range(world_size)]
    torch.distributed.all_gather(feature_banks, feature_bank, group=group)

    # Sanity check: this rank's slot must hold exactly its own features.
    assert torch.all(torch.eq(feature_banks[mpu.get_data_parallel_rank()],
                              feature_bank))

    feature_labels = [torch.zeros_like(feature_label) for _ in range(world_size)]
    torch.distributed.all_gather(feature_labels, feature_label, group=group)

    # [D, N]: transposed so queries can be scored as [B, D] @ [D, N].
    feature_banks = torch.cat(feature_banks, dim=0).t().contiguous()
    # [N]
    feature_labels = torch.cat(feature_labels, dim=0).contiguous()
    print_rank_0("feature_banks size is {}".format(feature_banks.size()))
    print_rank_0("feature labels size is {}".format(feature_labels.size()))

    _FEATURE_BANK = (feature_banks, feature_labels, classes)


def get_feature_bank():
    """Return the cached ``(feature_banks, feature_labels, classes)`` tuple.

    Raises:
        AssertionError: if ``compute_feature_bank`` has not populated the
            cache yet. It is raised explicitly rather than via ``assert`` so
            that the check still runs under ``python -O``.
    """
    # Read-only access to the module global needs no ``global`` statement.
    if _FEATURE_BANK is None:
        raise AssertionError(
            "feature bank is not initialized; call compute_feature_bank() first")
    return _FEATURE_BANK


# knn monitor as in InstDisc https://arxiv.org/abs/1805.01978
# implementation follows http://github.com/zhirongw/lemniscate.pytorch and
# https://github.com/leftthomas/SimCLR
def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
    """Rank classes for each query by temperature-weighted k-NN voting.

    Args:
        feature: ``[B, D]`` query features. They should be L2-normalized, like
            the bank built by ``compute_feature_bank``, so that the dot
            product is a cosine similarity.
        feature_bank: ``[D, N]`` bank features (one column per bank entry).
        feature_labels: ``[N]`` integer class index of each bank entry.
        classes: number of classes ``C``.
        knn_k: number of neighbours to vote. Values larger than ``N`` are
            clamped to ``N``.
        knn_t: temperature. Each neighbour votes with weight
            ``exp(similarity / knn_t)``.

    Returns:
        ``[B, C]`` tensor of class indices, sorted by descending vote score,
        so that column 0 is the top-1 prediction.
    """
    batch_size = feature.size(0)
    # topk() raises if k exceeds the number of bank entries; clamp instead.
    k = min(knn_k, feature_bank.size(-1))

    # Cosine similarity between each query and every bank entry ---> [B, N].
    sim_matrix = torch.mm(feature, feature_bank)
    # [B, K]
    sim_weight, sim_indices = sim_matrix.topk(k=k, dim=-1)
    # [B, K] labels of the selected neighbours.
    sim_labels = torch.gather(feature_labels.expand(batch_size, -1),
                              dim=-1,
                              index=sim_indices)
    sim_weight = (sim_weight / knn_t).exp()

    # Weighted score per class ---> [B, C]. scatter_add_ sums each
    # neighbour's weight into its class bucket directly. This avoids
    # materializing a [B*K, C] one-hot tensor. The score dtype follows the
    # same promotion as a default-dtype one-hot multiplied by sim_weight.
    score_dtype = torch.promote_types(sim_weight.dtype, torch.get_default_dtype())
    pred_scores = torch.zeros(batch_size, classes,
                              dtype=score_dtype,
                              device=sim_labels.device)
    pred_scores.scatter_add_(1, sim_labels, sim_weight.to(score_dtype))

    pred_labels = pred_scores.argsort(dim=-1, descending=True)
    return pred_labels