Spaces:
Configuration error
Configuration error
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import numpy as np | |
import paddle | |
from paddle import nn | |
import paddle.nn.functional as F | |
from scipy.ndimage import shift | |
from paddleseg.cvlibs import manager | |
class RelaxBoundaryLoss(nn.Layer):
    """
    Implements the relaxed boundary loss (label relaxation near class borders).

    Every pixel's target is relaxed to the set of all classes that occur in its
    (2 * border + 1) x (2 * border + 1) neighbourhood (see `relax_onehot`). The
    per-pixel loss is the negative log of the total predicted probability of
    that class set, shared equally among the classes in the set (see
    `custom_nll`). Pixels whose whole neighbourhood is ignored do not
    contribute.

    Args:
        border (int, optional): The value of border to relax. Default: 1.
        calculate_weights (bool, optional): Whether to calculate weights for every classes. Default: False.
        upper_bound (float, optional): The upper bound of weights if calculating weights for every classes. Default: 1.0.
        ignore_index (int64): Specifies a target value that is ignored
            and does not contribute to the input gradient. Default: 255.
    """

    def __init__(self,
                 border=1,
                 calculate_weights=False,
                 upper_bound=1.0,
                 ignore_index=255):
        super(RelaxBoundaryLoss, self).__init__()
        self.border = border
        # NOTE: this boolean flag shadows the `calculate_weights` method on
        # instances; forward() therefore calls the method through the class.
        self.calculate_weights = calculate_weights
        self.upper_bound = upper_bound
        self.ignore_index = ignore_index
        # Kept for backward compatibility; not used by the computation below.
        self.EPS = 1e-5

    def relax_onehot(self, label, num_classes):
        """
        Build the relaxed (multi-hot) target from an integer label map.

        Args:
            label (Tensor): Integer label of shape (N, H, W) or (N, 1, H, W).
            num_classes (int): Number of real classes C.

        Returns:
            Tensor: int64 tensor of shape (N, C + 1, H, W). Channel c is 1 where
            class c appears within `self.border` pixels (Chebyshev distance) of
            the pixel. The extra last channel marks neighbours that are outside
            the image or equal to `ignore_index`.
        """
        if len(label.shape) == 3:
            label = label.unsqueeze(1)
        h, w = label.shape[-2], label.shape[-1]
        # Pad with `num_classes` so out-of-image neighbours fall into the extra
        # channel, and map ignore_index to that same channel.
        label = F.pad(label, [self.border] * 4, value=num_classes)
        label = label.squeeze(1)
        ignore_mask = (label == self.ignore_index).astype('int64')
        label = label * (1 - ignore_mask) + num_classes * ignore_mask

        onehot = 0
        for i in range(-self.border, self.border + 1):
            for j in range(-self.border, self.border + 1):
                # The unshifted window starts at `self.border` in the padded
                # map, so every shift (i, j) stays inside [0, size + 2*border].
                h_start, h_end = self.border + i, h + self.border + i
                w_start, w_end = self.border + j, w + self.border + j
                label_ = label[:, h_start:h_end, w_start:w_end]
                onehot = onehot + F.one_hot(label_, num_classes + 1)
        onehot = (onehot > 0).astype('int64')
        return paddle.transpose(onehot, (0, 3, 1, 2))

    def calculate_weights(self, label):
        """
        Compute per-class weights from one sample's relaxed target.

        Args:
            label (Tensor): Relaxed target of shape (C + 1, H, W), as produced
                by `relax_onehot` for a single sample.

        Returns:
            Tensor: float32 tensor of shape (C,). Classes that are present get
            weight `1 + upper_bound * (1 - frequency)`, absent classes get 1.
            The trailing pad/ignore channel is dropped because it has no logit.
        """
        label = label.astype('float32')
        hist = paddle.sum(label, axis=(1, 2)) / label.sum()
        hist = (hist != 0).astype('float32') * self.upper_bound * (1 - hist) + 1
        return hist[:-1]

    def custom_nll(self,
                   logit,
                   label,
                   class_weights=None,
                   border_weights=None,
                   ignore_mask=None):
        """
        Relaxed negative log-likelihood for a single sample.

        Args:
            logit (Tensor): Logits of shape (1, C, H, W).
            label (Tensor): Relaxed target of shape (1, C + 1, H, W).
            class_weights (Tensor, optional): Per-class weights of shape (C,).
            border_weights (Tensor, optional): Number of relaxed classes per
                pixel (1 where there are none), broadcastable to (1, H, W).
                When None, no division is applied.
            ignore_mask (Tensor, optional): 1 for pixels to ignore, 0 otherwise,
                broadcastable to (1, H, W). When None, no pixel is ignored.

        Returns:
            Tensor: Scalar loss, averaged over non-ignored pixels (+1 to avoid
            division by zero).
        """
        soft = F.softmax(logit, axis=1)
        # Cast the mask once so every product stays in the logit's float dtype
        # instead of depending on implicit int/float casting.
        target = label[:, :-1, :, :].astype(soft.dtype)
        # Total probability of the relaxed class set at each pixel.
        relaxed_prob = (soft * target).sum(1, keepdim=True)
        # log(relaxed_prob) on relaxed channels and log(1) = 0 elsewhere, so
        # channels that are masked out can never yield 0 * -inf = NaN.
        logsoft = paddle.log(relaxed_prob * target + (1. - target))
        if class_weights is not None:
            logsoft = logsoft * class_weights.astype(soft.dtype).unsqueeze(
                (0, 2, 3))
        logsoft = (target * logsoft).sum(1)
        # border loss is divided equally among the relaxed classes
        if border_weights is not None:
            logsoft = logsoft / border_weights
        n, _, h, w = label.shape
        if ignore_mask is None:
            return -logsoft.sum() / (n * h * w + 1)
        logsoft = -logsoft * (1. - ignore_mask)
        return logsoft.sum() / (n * h * w - ignore_mask.sum() + 1)

    def forward(self, logit, label):
        """
        Forward computation.

        Args:
            logit (Tensor): Logit tensor, the data type is float32, float64.
                Shape is (N, C, H, W).
            label (Tensor): Label tensor, the data type is int64. Shape is
                (N, H, W) or (N, 1, H, W), where each value is
                0 <= label <= C-1 or equal to `ignore_index`.

        Returns:
            Tensor: Scalar loss, the sum of the per-sample losses.
        """
        n, c = logit.shape[0], logit.shape[1]
        label.stop_gradient = True
        label = self.relax_onehot(label, c)
        weights = label[:, :-1, :, :].sum(1).astype('float32')
        ignore_mask = (weights == 0).astype('float32')
        # border is greater than 1, other is 1
        border_weights = weights + ignore_mask
        loss = 0
        class_weights = None
        for i in range(n):
            if self.calculate_weights:
                # `self.calculate_weights` is the boolean flag; call the
                # method through the class to reach it.
                class_weights = type(self).calculate_weights(self, label[i])
            loss = loss + self.custom_nll(
                logit[i].unsqueeze(0),
                label[i].unsqueeze(0),
                class_weights=class_weights,
                border_weights=border_weights[i],
                ignore_mask=ignore_mask[i])
        return loss