from __future__ import print_function
from __future__ import division

import math

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn import Parameter


def l2_norm(input, axis=1):
    """L2-normalise `input` along `axis`."""
    norm = torch.norm(input, p=2, dim=axis, keepdim=True)
    output = torch.div(input, norm)
    return output


def calc_logits(embeddings, kernel):
    """Calculate the original (margin-free) cosine logits."""
    embeddings = l2_norm(embeddings, axis=1)
    kernel_norm = l2_norm(kernel, axis=0)
    cos_theta = torch.mm(embeddings, kernel_norm)
    cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability
    with torch.no_grad():
        origin_cos = cos_theta.clone()
    return cos_theta, origin_cos


@torch.no_grad()
def all_gather_tensor(input_tensor):
    """All-gather tensors (possibly of different sizes in dim 0) from all workers."""
    world_size = dist.get_world_size()

    # Exchange the local sizes so every worker knows how much padding to strip later.
    tensor_size = torch.tensor([input_tensor.shape[0]], dtype=torch.int64).cuda()
    tensor_size_list = [torch.ones_like(tensor_size) for _ in range(world_size)]
    dist.all_gather(tensor_list=tensor_size_list, tensor=tensor_size, async_op=False)
    max_size = torch.cat(tensor_size_list, dim=0).max()

    # Pad every local tensor to the largest size before the all_gather.
    padded = torch.empty(max_size.item(), *input_tensor.shape[1:],
                         dtype=input_tensor.dtype).cuda()
    padded[:input_tensor.shape[0]] = input_tensor
    padded_list = [torch.ones_like(padded) for _ in range(world_size)]
    dist.all_gather(tensor_list=padded_list, tensor=padded, async_op=False)

    # Strip the padding and concatenate the per-worker slices.
    slices = []
    for ts, t in zip(tensor_size_list, padded_list):
        slices.append(t[:ts.item()])
    return torch.cat(slices, dim=0)


def calc_top1_acc(original_logits, label, ddp=False):
    """
    Compute the top-1 accuracy during training.
    :param original_logits: logits w/o margin, [bs, C]
    :param label: labels, [bs]
    :param ddp: if True, average the accuracy across all GPUs
    :return: top-1 accuracy (averaged over all GPUs when ddp=True)
    """
    assert original_logits.size(0) == label.size(0)
    with torch.no_grad():
        _, max_index = torch.max(original_logits, dim=1, keepdim=False)  # local max logit
        count = (max_index == label).sum()
        if ddp:
            dist.all_reduce(count, dist.ReduceOp.SUM)
            return count.item() / (original_logits.size(0) * dist.get_world_size())
        else:
            return count.item() / original_logits.size(0)


class FC_ddp2(nn.Module):
    """
    Implementation of "Consistent Instance False Positive Improves Fairness in
    Face Recognition" (CVPR 2021). No model parallelism is used.
    """

    def __init__(self, in_features, out_features, scale=64.0, margin=0.4,
                 mode='cosface', use_cifp=False, reduction='mean', ddp=False):
        """
        Args:
            in_features: size of each input feature
            out_features: number of classes
            scale: norm of the input feature
            margin: margin
        """
        super(FC_ddp2, self).__init__()
        self.in_features = in_features
        self.out_features = out_features  # number of classes
        self.scale = scale
        self.margin = margin
        self.mode = mode
        self.use_cifp = use_cifp
        self.kernel = Parameter(torch.Tensor(in_features, out_features))
        self.ddp = ddp
        nn.init.normal_(self.kernel, std=0.01)
        self.criteria = torch.nn.CrossEntropyLoss(reduction=reduction)

    def apply_margin(self, target_cos_theta):
        assert self.mode in ['cosface', 'arcface'], 'Please check the mode'
        if self.mode == 'arcface':
            cos_m = math.cos(self.margin)
            sin_m = math.sin(self.margin)
            theta = math.cos(math.pi - self.margin)
            sinmm = math.sin(math.pi - self.margin) * self.margin
            sin_theta = torch.sqrt(1.0 - torch.pow(target_cos_theta, 2))
            cos_theta_m = target_cos_theta * cos_m - sin_theta * sin_m  # cos(theta + m)
            target_cos_theta_m = torch.where(
                target_cos_theta > theta, cos_theta_m, target_cos_theta - sinmm)
        elif self.mode == 'cosface':
            target_cos_theta_m = target_cos_theta - self.margin
        return target_cos_theta_m

    def forward(self, embeddings, label, return_logits=False):
        """
        :param embeddings: local GPU embeddings, [bs, 512]
        :param label: local labels, [bs]
        :param return_logits: bool
        :return:
            loss: local loss, with or without CIFP
            acc: local accuracy on one GPU
            output: local logits with margins, with gradients, scaled, [bs, C]
        """
        sample_num = embeddings.size(0)
        if not self.use_cifp:
            cos_theta, origin_cos = calc_logits(embeddings, self.kernel)
            target_cos_theta = cos_theta[torch.arange(0, sample_num), label].view(-1, 1)
            target_cos_theta_m = self.apply_margin(target_cos_theta)
            cos_theta.scatter_(1, label.view(-1, 1).long(), target_cos_theta_m)
        else:
            # NOTE: the CIFP branch assumes torch.distributed has been initialised,
            # since all_gather_tensor / get_world_size are called unconditionally.
            cos_theta, origin_cos = calc_logits(embeddings, self.kernel)
            cos_theta_, _ = calc_logits(embeddings, self.kernel.detach())
            mask = torch.zeros_like(cos_theta)  # [bs, C]
            mask.scatter_(1, label.view(-1, 1).long(), 1.0)  # one-hot label / gt mask

            # Push the ground-truth column below -1 so only negative classes remain.
            tmp_cos_theta = cos_theta - 2 * mask
            tmp_cos_theta_ = cos_theta_ - 2 * mask
            target_cos_theta = cos_theta[torch.arange(0, sample_num), label].view(-1, 1)
            target_cos_theta_ = cos_theta_[torch.arange(0, sample_num), label].view(-1, 1)
            target_cos_theta_m = self.apply_margin(target_cos_theta)

            far = 1 / (self.out_features - 1)  # ru+ value
            # far = 1e-5
            topk_mask = torch.greater(tmp_cos_theta, target_cos_theta)
            topk_sum = torch.sum(topk_mask.to(torch.int32))
            if self.ddp:
                dist.all_reduce(topk_sum)
            far_rank = math.ceil(
                far * (sample_num * (self.out_features - 1) * dist.get_world_size() - topk_sum))
            cos_theta_neg_topk = torch.topk(
                (tmp_cos_theta - 2 * topk_mask.to(torch.float32)).flatten(),
                k=far_rank)[0]  # [far_rank]
            cos_theta_neg_topk = all_gather_tensor(cos_theta_neg_topk.contiguous())  # top-k across all GPUs
            cos_theta_neg_th = torch.topk(cos_theta_neg_topk, k=far_rank)[0][-1]

            cond = torch.mul(torch.bitwise_not(topk_mask),
                             torch.greater(tmp_cos_theta, cos_theta_neg_th))
            cos_theta_neg_topk = torch.mul(cond.to(torch.float32), tmp_cos_theta)
            cos_theta_neg_topk_ = torch.mul(cond.to(torch.float32), tmp_cos_theta_)
            cond = torch.greater(target_cos_theta_m, cos_theta_neg_topk)
            cos_theta_neg_topk = torch.where(cond, cos_theta_neg_topk, cos_theta_neg_topk_)
            cos_theta_neg_topk = torch.pow(cos_theta_neg_topk, 2)  # F = z^p = cos^2

            times = torch.sum(torch.greater(cos_theta_neg_topk, 0).to(torch.float32),
                              dim=1, keepdim=True)
            times = torch.where(torch.greater(times, 0), times, torch.ones_like(times))
            cos_theta_neg_topk = torch.sum(cos_theta_neg_topk, dim=1, keepdim=True) / times  # ri+/ru+
            target_cos_theta_m = target_cos_theta_m - (1 + target_cos_theta_) * cos_theta_neg_topk
            cos_theta.scatter_(1, label.view(-1, 1).long(), target_cos_theta_m)

        output = cos_theta * self.scale
        loss = self.criteria(output, label)
        acc = calc_top1_acc(origin_cos * self.scale, label, self.ddp)
        if return_logits:
            return loss, acc, output
        return loss, acc
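

# Illustrative sketch (not part of the original training code): how this head is
# typically driven from a training loop. `backbone`, `images`, `labels` and
# `optimizer` are hypothetical placeholders for whatever the surrounding loop
# provides; with use_cifp=True / ddp=True, torch.distributed must be initialised
# and all tensors placed on the current GPU.
def _example_train_step(head, backbone, images, labels, optimizer):
    """One hypothetical optimisation step using the margin head above."""
    embeddings = backbone(images)          # [bs, 512] features from the trunk
    loss, acc = head(embeddings, labels)   # margin softmax loss (+ CIFP if enabled)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item(), acc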


class FC_ddp(nn.Module):
    """
    Implementation of "Consistent Instance False Positive Improves Fairness in
    Face Recognition" (CVPR 2021). No model parallelism is used.
    """

    def __init__(self, in_features, out_features, scale=8.0, margin=0.2,
                 mode='cosface', use_cifp=False, reduction='mean'):
        """
        Args:
            in_features: size of each input feature
            out_features: number of classes
            scale: norm of the input feature
            margin: margin
        """
        super(FC_ddp, self).__init__()
        self.in_features = in_features
        self.out_features = out_features  # number of classes
        self.scale = scale
        self.margin = margin
        self.mode = mode
        self.use_cifp = use_cifp
        # self.kernel = Parameter(torch.Tensor(in_features, out_features))
        # nn.init.normal_(self.kernel, std=0.01)
        self.criteria = torch.nn.CrossEntropyLoss(reduction=reduction)
        self.sig = torch.nn.Sigmoid()

    def apply_margin(self, target_cos_theta):
        assert self.mode in ['cosface', 'arcface'], 'Please check the mode'
        if self.mode == 'arcface':
            cos_m = math.cos(self.margin)
            sin_m = math.sin(self.margin)
            theta = math.cos(math.pi - self.margin)
            sinmm = math.sin(math.pi - self.margin) * self.margin
            sin_theta = torch.sqrt(1.0 - torch.pow(target_cos_theta, 2))
            cos_theta_m = target_cos_theta * cos_m - sin_theta * sin_m  # cos(theta + m)
            target_cos_theta_m = torch.where(
                target_cos_theta > theta, cos_theta_m, target_cos_theta - sinmm)
        elif self.mode == 'cosface':
            target_cos_theta_m = target_cos_theta - self.margin
        return target_cos_theta_m

    def forward(self, embeddings, label, return_logits=False):
        """
        :param embeddings: local inputs, [bs, C]; with the kernel commented out,
            each column is treated directly as a per-class score passed through a sigmoid
        :param label: local labels, [bs]
        :param return_logits: unused, kept for interface compatibility
        :return: loss: margin cross-entropy loss on the sigmoid-activated scores
        """
        sample_num = embeddings.size(0)
        cos_theta = self.sig(embeddings)
        target_cos_theta = cos_theta[torch.arange(0, sample_num), label].view(-1, 1)
        # target_cos_theta_m = target_cos_theta - self.margin
        target_cos_theta = target_cos_theta - self.margin
        # cos_theta.scatter_(1, label.view(-1, 1).long(), target_cos_theta_m)
        out = cos_theta.clone()
        out.scatter_(1, label.view(-1, 1).long(), target_cos_theta)
        out = out * self.scale
        loss = self.criteria(out, label)
        return loss
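

# Minimal smoke test (illustrative only, not from the original repository): a single
# process without DDP, CIFP disabled, running on CPU. The shapes are assumptions --
# a batch of 8 512-d embeddings over 1000 classes.
if __name__ == "__main__":
    torch.manual_seed(0)
    head = FC_ddp2(in_features=512, out_features=1000, scale=64.0, margin=0.4,
                   mode='cosface', use_cifp=False, ddp=False)
    embeddings = torch.randn(8, 512)
    labels = torch.randint(0, 1000, (8,))
    loss, acc, logits = head(embeddings, labels, return_logits=True)
    print('loss: %.4f  acc: %.4f  logits: %s' % (loss.item(), acc, tuple(logits.shape)))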