# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import nn
import paddle.nn.functional as F

from paddleseg.cvlibs import manager


@manager.LOSSES.add_component
class CrossEntropyLoss(nn.Layer):
""" | |
Implements the cross entropy loss function. | |
Args: | |
weight (tuple|list|ndarray|Tensor, optional): A manual rescaling weight | |
given to each class. Its length must be equal to the number of classes. | |
Default ``None``. | |
ignore_index (int64, optional): Specifies a target value that is ignored | |
and does not contribute to the input gradient. Default ``255``. | |
top_k_percent_pixels (float, optional): the value lies in [0.0, 1.0]. | |
When its value < 1.0, only compute the loss for the top k percent pixels | |
(e.g., the top 20% pixels). This is useful for hard pixel mining. Default ``1.0``. | |
data_format (str, optional): The tensor format to use, 'NCHW' or 'NHWC'. Default ``'NCHW'``. | |
""" | |

    def __init__(self,
                 weight=None,
                 ignore_index=255,
                 top_k_percent_pixels=1.0,
                 data_format='NCHW'):
        super().__init__()
        self.ignore_index = ignore_index
        self.top_k_percent_pixels = top_k_percent_pixels
        self.EPS = 1e-8
        self.data_format = data_format
        if weight is not None:
            self.weight = paddle.to_tensor(weight, dtype='float32')
        else:
            self.weight = None

    def forward(self, logit, label, semantic_weights=None):
        """
        Forward computation.

        Args:
            logit (Tensor): Logit tensor, the data type is float32 or float64.
                Shape is (N, C), where C is the number of classes; if the shape
                is more than 2D, it is (N, C, D1, D2, ..., Dk), k >= 1.
            label (Tensor): Label tensor, the data type is int64. Shape is (N),
                where each value satisfies 0 <= label[i] <= C-1; if the shape is
                more than 2D, it is (N, D1, D2, ..., Dk), k >= 1.
            semantic_weights (Tensor, optional): Per-pixel loss weights, with
                the same shape as label. Default: None.

        Returns:
            (Tensor): The average loss.
        """
        channel_axis = 1 if self.data_format == 'NCHW' else -1
        if self.weight is not None and logit.shape[channel_axis] != len(
                self.weight):
            raise ValueError(
                'The number of weights = {} must be the same as the number of classes = {}.'
                .format(len(self.weight), logit.shape[channel_axis]))

        if channel_axis == 1:  # Convert NCHW to NHWC so the class axis is last.
            logit = paddle.transpose(logit, [0, 2, 3, 1])
        label = label.astype('int64')

        loss = F.cross_entropy(
            logit,
            label,
            ignore_index=self.ignore_index,
            reduction='none',
            weight=self.weight)

        return self._post_process_loss(logit, label, semantic_weights, loss)

    def _post_process_loss(self, logit, label, semantic_weights, loss):
        """
        Apply the ignore mask, optional per-pixel weights, and top-k hard pixel
        mining to compute the final loss.

        Args:
            logit (Tensor): Logit tensor, the data type is float32 or float64.
                Shape is (N, C), where C is the number of classes; if the shape
                is more than 2D, it is (N, C, D1, D2, ..., Dk), k >= 1.
            label (Tensor): Label tensor, the data type is int64. Shape is (N),
                where each value satisfies 0 <= label[i] <= C-1; if the shape is
                more than 2D, it is (N, D1, D2, ..., Dk), k >= 1.
            semantic_weights (Tensor, optional): Per-pixel loss weights, with
                the same shape as label.
            loss (Tensor): Loss tensor, the output of cross_entropy. If soft_label
                is False in cross_entropy, the shape of loss should be the same as
                that of label. If soft_label is True, the shape should be
                (N, D1, D2, ..., Dk, 1).

        Returns:
            (Tensor): The average loss.
        """
        mask = label != self.ignore_index
        mask = paddle.cast(mask, 'float32')
        label.stop_gradient = True
        mask.stop_gradient = True

        # With soft labels, cross_entropy keeps a trailing singleton class
        # axis; drop it so the loss aligns with the mask.
        if loss.ndim > mask.ndim:
            loss = paddle.squeeze(loss, axis=-1)
        loss = loss * mask
        if semantic_weights is not None:
            loss = loss * semantic_weights

        if self.weight is not None:
            # F.one_hot expects an integer input, so cast the float mask back
            # to int64; ignored pixels map to class 0 here, but their loss has
            # already been zeroed by the mask above.
            _one_hot = F.one_hot(label * mask.astype('int64'), logit.shape[-1])
            coef = paddle.sum(_one_hot * self.weight, axis=-1)
        else:
            # Cast to float32 so `coef` composes with the float32 mask below.
            coef = paddle.ones_like(label).astype('float32')

        if self.top_k_percent_pixels == 1.0:
            avg_loss = paddle.mean(loss) / (paddle.mean(mask * coef) + self.EPS)
        else:
            # Hard pixel mining: keep only the k percent of pixels with the
            # largest loss and normalize by their class-weight coefficients.
            loss = loss.reshape((-1, ))
            top_k_pixels = int(self.top_k_percent_pixels * loss.numel())
            loss, indices = paddle.topk(loss, top_k_pixels)
            coef = coef.reshape((-1, ))
            coef = paddle.gather(coef, indices)
            coef.stop_gradient = True
            coef = coef.astype('float32')
            avg_loss = loss.mean() / (paddle.mean(coef) + self.EPS)

        return avg_loss
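

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the upstream module): runs the loss on
# random NCHW inputs. The class count, shapes, and the ignored patch below
# are arbitrary assumptions chosen for illustration only.
# ---------------------------------------------------------------------------
def _demo_cross_entropy():
    num_classes = 19  # assumed class count, e.g. Cityscapes
    logit = paddle.randn([2, num_classes, 64, 64])  # N, C, H, W
    label = paddle.randint(0, num_classes, [2, 64, 64], dtype='int64')
    label[:, :8, :8] = 255  # this patch is skipped via ignore_index
    # top_k_percent_pixels=0.2 averages only the hardest 20% of pixels.
    loss_fn = CrossEntropyLoss(top_k_percent_pixels=0.2)
    return loss_fn(logit, label)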


@manager.LOSSES.add_component
class DistillCrossEntropyLoss(CrossEntropyLoss):
""" | |
The implementation of distill cross entropy loss. | |
Args: | |
weight (tuple|list|ndarray|Tensor, optional): A manual rescaling weight | |
given to each class. Its length must be equal to the number of classes. | |
Default ``None``. | |
ignore_index (int64, optional): Specifies a target value that is ignored | |
and does not contribute to the input gradient. Default ``255``. | |
top_k_percent_pixels (float, optional): the value lies in [0.0, 1.0]. | |
When its value < 1.0, only compute the loss for the top k percent pixels | |
(e.g., the top 20% pixels). This is useful for hard pixel mining. | |
Default ``1.0``. | |
data_format (str, optional): The tensor format to use, 'NCHW' or 'NHWC'. | |
Default ``'NCHW'``. | |
""" | |

    def __init__(self,
                 weight=None,
                 ignore_index=255,
                 top_k_percent_pixels=1.0,
                 data_format='NCHW'):
        super().__init__(weight, ignore_index, top_k_percent_pixels,
                         data_format)

    def forward(self,
                student_logit,
                teacher_logit,
                label,
                semantic_weights=None):
        """
        Forward computation.

        Args:
            student_logit (Tensor): Logit tensor, the data type is float32 or
                float64. Shape is (N, C), where C is the number of classes; if
                the shape is more than 2D, it is (N, C, D1, D2, ..., Dk), k >= 1.
            teacher_logit (Tensor): Logit tensor, the data type is float32 or
                float64. The shape is the same as that of student_logit.
            label (Tensor): Label tensor, the data type is int64. Shape is (N),
                where each value satisfies 0 <= label[i] <= C-1; if the shape is
                more than 2D, it is (N, D1, D2, ..., Dk), k >= 1.
            semantic_weights (Tensor, optional): Per-pixel loss weights, with
                the same shape as label. Default: None.

        Returns:
            (Tensor): The average loss.
        """
        if student_logit.shape != teacher_logit.shape:
            raise ValueError(
                'The shape of student_logit = {} must be the same as the shape of teacher_logit = {}.'
                .format(student_logit.shape, teacher_logit.shape))

        channel_axis = 1 if self.data_format == 'NCHW' else -1
        if self.weight is not None and student_logit.shape[channel_axis] != len(
                self.weight):
            raise ValueError(
                'The number of weights = {} must be the same as the number of classes = {}.'
                .format(len(self.weight), student_logit.shape[channel_axis]))

        if channel_axis == 1:  # Convert NCHW to NHWC so the class axis is last.
            student_logit = paddle.transpose(student_logit, [0, 2, 3, 1])
            teacher_logit = paddle.transpose(teacher_logit, [0, 2, 3, 1])

        # The teacher's softmax output serves as the soft label that the
        # student is trained to match.
        teacher_logit = F.softmax(teacher_logit)

        loss = F.cross_entropy(
            student_logit,
            teacher_logit,
            weight=self.weight,
            reduction='none',
            soft_label=True)

        return self._post_process_loss(student_logit, label, semantic_weights,
                                       loss)
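

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the upstream module): distills random
# teacher logits into a student. Shapes and the class count are arbitrary
# assumptions chosen for illustration only.
# ---------------------------------------------------------------------------
def _demo_distill_cross_entropy():
    num_classes = 19  # assumed class count, e.g. Cityscapes
    student_logit = paddle.randn([2, num_classes, 32, 32])
    teacher_logit = paddle.randn([2, num_classes, 32, 32])
    label = paddle.randint(0, num_classes, [2, 32, 32], dtype='int64')
    loss_fn = DistillCrossEntropyLoss()
    return loss_fn(student_logit, teacher_logit, label)


if __name__ == '__main__':
    print('cross entropy loss:', _demo_cross_entropy())
    print('distill cross entropy loss:', _demo_distill_cross_entropy())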