Spaces:
Runtime error
Runtime error
#Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# | |
#Licensed under the Apache License, Version 2.0 (the "License"); | |
#you may not use this file except in compliance with the License. | |
#You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
#Unless required by applicable law or agreed to in writing, software | |
#distributed under the License is distributed on an "AS IS" BASIS, | |
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
#See the License for the specific language governing permissions and | |
#limitations under the License. | |
# This code is adapted from: https://github.com/KaiyangZhou/pytorch-center-loss
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import os | |
import pickle | |
import paddle | |
import paddle.nn as nn | |
import paddle.nn.functional as F | |
class CenterLoss(nn.Layer):
    """Center loss: squared distance between features and their class center.

    Reference: Wen et al. A Discriminative Feature Learning Approach for Deep
    Face Recognition. ECCV 2016.

    Attributes:
        num_classes: Number of class centers (rows of ``centers``).
        feat_dim: Length of each center vector (columns of ``centers``).
        centers: float64 tensor of shape ``[num_classes, feat_dim]``, assigned
            as a plain tensor attribute in ``__init__``.
    """

    def __init__(self, num_classes=6625, feat_dim=96, center_file_path=None):
        """Build the center table, optionally overriding rows from a file.

        Args:
            num_classes: Number of class centers to allocate.
            feat_dim: Dimension of each center vector.
            center_file_path: Optional path to a pickle file. The unpickled
                object is walked via ``keys()``; each key indexes a row of
                ``centers`` and the matching value is written into that row.

        Raises:
            AssertionError: If ``center_file_path`` is given but the path does
                not exist.
        """
        super().__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim

        # Random initial centers, stored in double precision.
        initial_centers = paddle.randn(
            shape=[self.num_classes, self.feat_dim])
        self.centers = initial_centers.astype("float64")

        if center_file_path is None:
            return

        assert os.path.exists(
            center_file_path
        ), f"center path({center_file_path}) must exist when it is not None."
        # NOTE(review): pickle.load can execute arbitrary code while
        # unpickling; only point center_file_path at trusted files.
        with open(center_file_path, 'rb') as center_file:
            stored_centers = pickle.load(center_file)
            for class_idx in stored_centers.keys():
                self.centers[class_idx] = paddle.to_tensor(
                    stored_centers[class_idx])

    def __call__(self, predicts, batch):
        """Compute the center loss for one batch.

        Args:
            predicts: A list or tuple of two tensors ``(features, scores)``.
                ``features`` is flattened to ``[-1, features.shape[-1]]``;
                ``scores`` is arg-maxed over axis 2 to pick one class index
                per position, which is then flattened over its first two
                axes.
            batch: Not read by this loss.

        Returns:
            dict: ``{'loss_center': loss}`` where ``loss`` is the clipped,
            masked distance matrix summed and divided by the number of
            flattened feature rows.

        Raises:
            AssertionError: If ``predicts`` is not a list or tuple.
        """
        assert isinstance(predicts, (list, tuple))
        features, scores = predicts

        # One row per position, in float64 to match the centers.
        flat_feats = paddle.reshape(
            features, [-1, features.shape[-1]]).astype("float64")
        num_rows = flat_feats.shape[0]

        # The predicted class per position serves as the label.
        pred_label = paddle.argmax(scores, axis=2)
        pred_label = paddle.reshape(
            pred_label, [pred_label.shape[0] * pred_label.shape[1]])

        # Squared L2 distance via ||f||^2 + ||c||^2 - 2 * f . c
        feat_sq = paddle.sum(
            paddle.square(flat_feats), axis=1, keepdim=True)
        feat_sq = paddle.expand(feat_sq, [num_rows, self.num_classes])

        center_sq = paddle.sum(
            paddle.square(self.centers), axis=1, keepdim=True)
        center_sq = paddle.expand(
            center_sq, [self.num_classes, num_rows]).astype("float64")
        center_sq = paddle.transpose(center_sq, [1, 0])

        cross_term = paddle.matmul(
            flat_feats, paddle.transpose(self.centers, [1, 0]))
        distmat = paddle.add(feat_sq, center_sq) - 2.0 * cross_term

        # One-hot mask selecting, for every row, the column of its label.
        class_ids = paddle.arange(self.num_classes).astype("int64")
        class_grid = paddle.expand(class_ids, [num_rows, self.num_classes])
        label_grid = paddle.expand(
            paddle.unsqueeze(pred_label, 1), (num_rows, self.num_classes))
        mask = paddle.equal(class_grid, label_grid).astype("float64")

        masked_dist = paddle.multiply(distmat, mask)
        # NOTE(review): the clip runs after masking, so every masked-out
        # (zero) entry is raised to 1e-12 before the sum.
        clipped_dist = paddle.clip(masked_dist, min=1e-12, max=1e+12)
        loss = paddle.sum(clipped_dist) / num_rows
        return {'loss_center': loss}