import torch
import torch.nn.functional as F

from megatron import print_rank_0, get_args, mpu
from megatron.data.vit_dataset import ClassificationTransform
from megatron.data.image_folder import ImageFolder

# Cached (feature_banks, feature_labels, classes) tuple, filled once by
# compute_feature_bank() and read back via get_feature_bank().
_FEATURE_BANK = None


def build_data_loader(dataset, drop_last=True, shuffle=False):
    """Build a distributed data loader.

    Note that the batch size is the local (per-GPU) micro-batch size.
    """
    args = get_args()
    micro_batch_size = 16  # fixed per-GPU batch size for feature extraction
    num_workers = args.num_workers

    # Shard the dataset across the data-parallel group.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank,
        drop_last=drop_last, shuffle=shuffle
    )

    # The sampler already drops the uneven remainder across ranks, so the
    # loader itself keeps the final (possibly partial) micro-batch.
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=micro_batch_size,
        sampler=sampler,
        shuffle=False,  # shuffling, if any, is handled by the sampler
        num_workers=num_workers,
        drop_last=not drop_last,
        pin_memory=True,
    )
    return data_loader
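

# Editor's sketch (not in the original file): when shuffle=True, the
# DistributedSampler above must be re-seeded every epoch via set_epoch(),
# otherwise each epoch replays the same per-rank ordering. `dataset` and
# `num_epochs` are placeholders for illustration.
def _example_shuffled_epochs(dataset, num_epochs):
    loader = build_data_loader(dataset, drop_last=True, shuffle=True)
    for epoch in range(num_epochs):
        loader.sampler.set_epoch(epoch)  # reshuffle the shards for this epoch
        for _images, _labels in loader:
            pass  # iterate this rank's shard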


def compute_feature_bank(model):
    """Extract features for the whole training set and cache them globally."""
    args = get_args()
    global _FEATURE_BANK
    feature_bank = []
    feature_label = []

    train_ds = ImageFolder(
        root=args.data_path[0],
        transform=ClassificationTransform((args.img_h, args.img_w), train=False),
        data_per_class_fraction=1.0
    )
    classes = len(train_ds.classes)
    dataloader = build_data_loader(train_ds)

    # Switch every model chunk to eval mode for feature extraction.
    for m in model:
        m.eval()

    with torch.no_grad():
        for batch in dataloader:
            images = batch[0].cuda().contiguous()
            labels = batch[1].cuda().contiguous()
            # Use the teacher branch's output for the bank.
            student_feature, teacher_feature = model[0](images)
            # L2-normalize so that a dot product is a cosine similarity.
            feature = F.normalize(teacher_feature.float(), dim=1)
            feature_bank.append(feature)
            feature_label.append(labels)

    for m in model:
        m.train()

    # Local shard: [n_local, D] features and [n_local] labels.
    feature_bank = torch.cat(feature_bank, dim=0).contiguous()
    feature_label = torch.cat(feature_label, dim=0).contiguous()

    # Gather the shards from every data-parallel rank; all_gather places the
    # tensor from rank r at index r of the output list, which the assert below
    # double-checks for this rank's slot.
    feature_banks = [torch.zeros_like(feature_bank)
                     for i in range(mpu.get_data_parallel_world_size())]
    torch.distributed.all_gather(feature_banks,
                                 feature_bank,
                                 group=mpu.get_data_parallel_group())
    assert torch.all(torch.eq(feature_banks[mpu.get_data_parallel_rank()],
                              feature_bank))

    feature_labels = [torch.zeros_like(feature_label)
                      for i in range(mpu.get_data_parallel_world_size())]
    torch.distributed.all_gather(feature_labels,
                                 feature_label,
                                 group=mpu.get_data_parallel_group())

    # [D, N] after the transpose, so knn_predict can apply torch.mm directly.
    feature_banks = torch.cat(feature_banks, dim=0).t().contiguous()
    feature_labels = torch.cat(feature_labels, dim=0).contiguous()
    print_rank_0("feature_banks size is {}".format(feature_banks.size()))
    print_rank_0("feature labels size is {}".format(feature_labels.size()))

    _FEATURE_BANK = (feature_banks, feature_labels, classes)


def get_feature_bank():
    """Return the cached (feature_banks, feature_labels, classes) tuple."""
    global _FEATURE_BANK
    assert _FEATURE_BANK is not None, \
        "call compute_feature_bank() before get_feature_bank()"
    return _FEATURE_BANK
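

# Editor's sketch (not part of the original module): one way the cached bank
# and knn_predict (below) fit together for monitoring. The defaults knn_k=200
# and knn_t=0.07 are common choices for kNN monitors, not values taken from
# this file; `model` is the same list of model chunks used above, and knn_k
# must not exceed the number of images in the bank.
def knn_accuracy(model, data_loader, knn_k=200, knn_t=0.07):
    feature_bank, feature_labels, classes = get_feature_bank()
    for m in model:
        m.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in data_loader:
            images = images.cuda().contiguous()
            labels = labels.cuda().contiguous()
            _, teacher_feature = model[0](images)
            feature = F.normalize(teacher_feature.float(), dim=1)
            pred_labels = knn_predict(feature, feature_bank, feature_labels,
                                      classes, knn_k, knn_t)
            # Column 0 holds each query's highest-scoring class.
            correct += (pred_labels[:, 0] == labels).sum().item()
            total += labels.size(0)
    for m in model:
        m.train()
    return correct / total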


def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
    """Similarity-weighted kNN prediction.

    feature:        [B, D] L2-normalized query features.
    feature_bank:   [D, N] L2-normalized bank features (transposed).
    feature_labels: [N] labels of the bank images.
    Returns a [B, classes] tensor of class indices, sorted per row by
    descending kNN score.
    """
    # Cosine similarity of each query against every bank image: [B, N].
    sim_matrix = torch.mm(feature, feature_bank)
    # Keep the k most similar bank images per query: [B, K].
    sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
    # Labels of those neighbours: [B, K].
    sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1),
                              dim=-1,
                              index=sim_indices)
    # Temperature-scaled similarity weights.
    sim_weight = (sim_weight / knn_t).exp()

    # One-hot encode the neighbour labels: [B*K, classes].
    one_hot_label = torch.zeros(feature.size(0) * knn_k,
                                classes,
                                device=sim_labels.device)
    one_hot_label = one_hot_label.scatter(dim=-1,
                                          index=sim_labels.view(-1, 1),
                                          value=1.0)

    # Weighted vote over the k neighbours: [B, classes].
    pred_scores = torch.sum(
        one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1),
        dim=1)
    pred_labels = pred_scores.argsort(dim=-1, descending=True)
    return pred_labels
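

# Editor's toy check (illustrative, not in the original file): a CPU-only
# sanity test of knn_predict with a hand-built 2-class bank. All tensors
# here are made up for the example.
def _toy_knn_example():
    bank = F.normalize(torch.tensor([[1.0, 0.0], [0.9, 0.1],
                                     [0.0, 1.0], [0.1, 0.9]]), dim=1).t()
    labels = torch.tensor([0, 0, 1, 1])
    query = F.normalize(torch.tensor([[1.0, 0.05]]), dim=1)
    # The two nearest bank vectors belong to class 0, so class 0 ranks first.
    pred = knn_predict(query, bank, labels, classes=2, knn_k=2, knn_t=0.1)
    assert pred[0, 0].item() == 0
    return pred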