# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddleseg.cvlibs import manager


@manager.LOSSES.add_component
class KLLoss(nn.Layer):
    """
    The implementation of the Kullback-Leibler divergence loss.
    Refer to https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence.

    Args:
        ignore_index (int64): Specifies a target value that is ignored
            and does not contribute to the input gradient. Default ``255``.
        temperature (float): The temperature used to soften both logits
            before computing the KL divergence. Default ``1``.
    """

    def __init__(self, ignore_index=255, temperature=1):
        super().__init__()
        self.ignore_index = ignore_index
        self.temperature = temperature

        self.kl_loss = nn.KLDivLoss(reduction="none")
        self.EPS = 1e-8

    def forward(self, logit_1, logit_2, label=None):
        """
        Calculate the KL loss. If the label is not None, it considers the
        ignore_index in the label and calculates the masked loss.

        Args:
            logit_1 (Tensor): Logit tensor, the data type is float32 or float64.
                The shape is (N, C), where C is the number of classes, and if the
                shape is more than 2D, it is (N, C, D1, D2, ..., Dk), k >= 1.
            logit_2 (Tensor): Logit tensor, the data type is float32 or float64.
                The shapes of logit_2 and logit_1 are the same.
            label (Tensor, optional): Label tensor, the data type is int64.
                The shape is (N), where each value satisfies 0 <= label[i] <= C-1,
                and if the shape is more than 2D, it is (N, D1, D2, ..., Dk), k >= 1.

        Returns:
            (Tensor): The average loss.
        """
        if logit_1.shape != logit_2.shape:
            raise ValueError(
                'The shape of logit_1 = {} must be the same as the shape of logit_2 = {}.'
                .format(logit_1.shape, logit_2.shape))

        # KLDivLoss expects log-probabilities as input and probabilities as
        # target, so logit_1 gets log_softmax and logit_2 gets softmax.
        logit_1 = F.log_softmax(logit_1 / self.temperature, axis=1)
        logit_2 = F.softmax(logit_2 / self.temperature, axis=1)

        loss = self.kl_loss(logit_1, logit_2)
        # Scale by T^2 to keep gradient magnitudes comparable across temperatures.
        loss = loss * self.temperature * self.temperature

        if label is None:
            avg_loss = paddle.mean(loss)
        else:
            # Zero out positions labeled with ignore_index, then normalize by
            # the fraction of valid positions so ignored pixels do not dilute
            # the average.
            mask = label != self.ignore_index
            mask = paddle.cast(mask, 'float32')
            mask = paddle.unsqueeze(mask, axis=1)
            label.stop_gradient = True
            mask.stop_gradient = True

            loss = loss * mask
            avg_loss = paddle.mean(loss) / (paddle.mean(mask) + self.EPS)
        return avg_loss
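

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the PaddleSeg API).
# KLLoss is commonly used for knowledge distillation, where logit_1 comes
# from a student network and logit_2 from a teacher. The shapes below
# (batch of 2, 19 classes, 4x4 spatial) and the temperature value are
# arbitrary assumptions, not values prescribed by this module.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    kl = KLLoss(ignore_index=255, temperature=4)
    student_logits = paddle.randn([2, 19, 4, 4])
    teacher_logits = paddle.randn([2, 19, 4, 4])
    # Labels are only used to mask out ignore_index positions; the KL term
    # itself is computed between the two logit tensors.
    label = paddle.randint(0, 19, shape=[2, 4, 4], dtype='int64')
    print(kl(student_logits, teacher_logits, label))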