# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
import paddle.nn.functional as F

from paddleseg.cvlibs import manager


@manager.LOSSES.add_component
class PointCrossEntropyLoss(nn.Layer):
""" | |
Implements the point cross entropy loss function. | |
The original article refers to | |
Kirillov A, Wu Y, He K, et al. "PointRend: Image Segmentation As Rendering." | |
(https://arxiv.org/abs/1912.08193). | |
Args: | |
weight (tuple|list|ndarray|Tensor, optional): A manual rescaling weight | |
given to each class. Its length must be equal to the number of classes. | |
Default ``None``. | |
ignore_index (int64, optional): Specifies a target value that is ignored | |
and does not contribute to the input gradient. Default ``255``. | |
top_k_percent_pixels (float, optional): the value lies in [0.0, 1.0]. When its value < 1.0, only compute the loss for | |
the top k percent pixels (e.g., the top 20% pixels). This is useful for hard pixel mining. Default ``1.0``. | |
data_format (str, optional): The tensor format to use, 'NCHW' or 'NHWC'. Default ``'NCHW'``. | |
""" | |

    def __init__(self,
                 weight=None,
                 ignore_index=255,
                 top_k_percent_pixels=1.0,
                 data_format='NCHW',
                 align_corners=False):
        super(PointCrossEntropyLoss, self).__init__()
        if weight is not None:
            weight = paddle.to_tensor(weight, dtype='float32')
        self.weight = weight
        self.ignore_index = ignore_index
        self.top_k_percent_pixels = top_k_percent_pixels
        self.EPS = 1e-8
        self.data_format = data_format
        self.align_corners = align_corners

    def forward(self, logits, label, semantic_weights=None):
        """
        Forward computation.

        Args:
            logits (tuple): A tuple of (logit, points). logit is a Tensor with
                data type float32 or float64 and shape [N, C, point_num], where
                C is the number of classes; points is a Tensor of shape
                [N, point_num, 2] holding normalized point coordinates.
            label (Tensor): Label tensor with data type int64 and shape
                [N, H, W], where each value satisfies 0 <= label[i] <= C-1
                or equals ``ignore_index``.
            semantic_weights (Tensor, optional): Per-point loss weights with the
                same shape as the sampled label, i.e. [N, point_num].
                Default: None.
        """
        # Sample the ground-truth labels at the given point coordinates.
        logit, points = logits  # [N, C, point_num], [N, point_num, 2]
        label = label.unsqueeze(1)  # [N, 1, H, W]
        label = point_sample(
            label.astype('float32'),
            points,
            mode='nearest',
            align_corners=self.align_corners)  # [N, 1, point_num]
        label = paddle.squeeze(label, axis=1).astype('int64')  # [N, point_num]
        channel_axis = 1 if self.data_format == 'NCHW' else -1
        if self.weight is not None and logit.shape[channel_axis] != len(
                self.weight):
            raise ValueError(
                'The number of weights = {} must be the same as the number of '
                'classes = {}.'.format(
                    len(self.weight), logit.shape[channel_axis]))
        # Move the class dimension last so F.cross_entropy sees [N, point_num, C].
        logit = paddle.transpose(logit, [0, 2, 1])
        loss = F.cross_entropy(
            logit,
            label,
            ignore_index=self.ignore_index,
            reduction='none')  # [N, point_num]
        mask = label != self.ignore_index
        mask = paddle.cast(mask, 'float32')
        loss = loss * mask
        if semantic_weights is not None:
            loss = loss * semantic_weights

        if self.weight is not None:
            # Map ignored points to class 0 so one_hot stays in range; their
            # loss has already been zeroed out by `mask` above.
            valid_label = paddle.where(
                label == self.ignore_index, paddle.zeros_like(label), label)
            _one_hot = F.one_hot(valid_label, logit.shape[-1])
            # Per-point class weight; it also normalizes the mean below.
            coef = paddle.sum(_one_hot * self.weight, axis=-1)
            loss = loss * coef
        else:
            coef = paddle.ones_like(label).astype('float32')
        label.stop_gradient = True
        mask.stop_gradient = True

        if self.top_k_percent_pixels == 1.0:
            # Weighted mean over the valid (non-ignored) points.
            avg_loss = paddle.mean(loss) / (paddle.mean(mask * coef) + self.EPS)
            return avg_loss

        # Hard pixel mining: keep only the top-k largest per-point losses.
        loss = loss.reshape((-1, ))
        top_k_pixels = int(self.top_k_percent_pixels * loss.numel())
        loss, indices = paddle.topk(loss, top_k_pixels)
        coef = coef.reshape((-1, ))
        coef = paddle.gather(coef, indices)
        coef.stop_gradient = True
        return loss.mean() / (paddle.mean(coef) + self.EPS)


def point_sample(input, points, align_corners=False, **kwargs):
    """A wrapper around :func:`paddle.nn.functional.grid_sample` to support 3D
    point_coords tensors. Unlike :func:`grid_sample`, it assumes the point
    coordinates lie inside the ``[0, 1] x [0, 1]`` square.

    Args:
        input (Tensor): Feature map, shape (N, C, H, W).
        points (Tensor): Image-based absolute point coordinates (normalized),
            range [0, 1] x [0, 1], shape (N, P, 2) or (N, Hgrid, Wgrid, 2).
        align_corners (bool): Whether to use align_corners. Default: False.

    Returns:
        Tensor: Features of `points` on `input`, shape (N, C, P) or
            (N, C, Hgrid, Wgrid).
    """
    def denormalize(grid):
        """Denormalize input grid from range [0, 1] to [-1, 1].

        Args:
            grid (Tensor): The grid to be denormalized, range [0, 1].

        Returns:
            Tensor: Denormalized grid, range [-1, 1].
        """
        return grid * 2.0 - 1.0
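    # Note (added for clarity): denormalize maps the [0, 1] point coordinates to
    # the [-1, 1] range expected by F.grid_sample, e.g. 0.0 -> -1.0, 0.5 -> 0.0,
    # 1.0 -> 1.0.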
    add_dim = False
    if points.dim() == 3:
        add_dim = True
        points = paddle.unsqueeze(points, axis=2)  # [N, P, 1, 2]
    output = F.grid_sample(
        input, denormalize(points), align_corners=align_corners, **kwargs)
    if add_dim:
        output = paddle.squeeze(output, axis=3)  # back to [N, C, P]
    return output
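

# --- Minimal usage sketch (added; not part of the original implementation) ---
# It shows the expected input layout with assumed, illustrative shapes: `logits`
# is a (logit, points) tuple with logit of shape [N, C, point_num] and points
# of shape [N, point_num, 2] in [0, 1]; the label map is [N, H, W].
if __name__ == '__main__':
    N, C, H, W, P = 2, 19, 64, 64, 128
    feat = paddle.randn([N, C, H, W])                        # dense feature map
    points = paddle.rand([N, P, 2])                          # normalized coords
    logit = point_sample(feat, points, align_corners=False)  # [N, C, P]
    label = paddle.randint(0, C, [N, H, W], dtype='int64')   # dense label map

    loss_fn = PointCrossEntropyLoss(top_k_percent_pixels=0.25)
    loss = loss_fn((logit, points), label)
    print('point cross entropy loss:', float(loss))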