""" |
|
@Author : Peike Li |
|
@Contact : [email protected] |
|
@File : kl_loss.py |
|
@Time : 7/23/19 4:02 PM |
|
@Desc : |
|
@License : This source code is licensed under the license found in the |
|
LICENSE file in the root directory of this source tree. |
|
""" |
|
import torch
import torch.nn.functional as F
from torch import nn

from datasets.target_generation import generate_edge_tensor


class ConsistencyLoss(nn.Module):
    """Consistency loss between the parsing and edge branches.

    An edge map regenerated from the predicted parsing result is compared with
    the predicted edge map on valid pixels where both agree an edge is present.
    """

    def __init__(self, ignore_index=255):
        super(ConsistencyLoss, self).__init__()
        self.ignore_index = ignore_index

    def forward(self, parsing, edge, label):
        # Hard parsing prediction; ignored ground-truth pixels are masked out
        # so that they do not produce spurious edges.
        parsing_pre = torch.argmax(parsing, dim=1)
        parsing_pre[label == self.ignore_index] = self.ignore_index
        # Edge map regenerated from the predicted parsing result.
        generated_edge = generate_edge_tensor(parsing_pre)
        # Hard prediction from the edge branch.
        edge_pre = torch.argmax(edge, dim=1)
        # Keep only valid (non-ignored) pixels of both edge maps.
        v_generate_edge = generated_edge[label != self.ignore_index]
        v_edge_pre = edge_pre[label != self.ignore_index].float()
        # Penalize disagreement only where both maps predict an edge.
        positive_union = (v_generate_edge == 1) & (v_edge_pre == 1)
        return F.smooth_l1_loss(v_generate_edge[positive_union].squeeze(0),
                                v_edge_pre[positive_union].squeeze(0))
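

# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes the parsing branch outputs (N, C, H, W) logits, the edge branch
# outputs (N, 2, H, W) logits, and the label is an (N, H, W) tensor of class
# indices. generate_edge_tensor from datasets.target_generation is assumed to
# accept such a label-shaped tensor and may require CUDA tensors, so the
# sketch prefers the GPU when one is available.
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    num_classes, height, width = 20, 64, 64

    parsing_logits = torch.randn(2, num_classes, height, width, device=device)
    edge_logits = torch.randn(2, 2, height, width, device=device)
    label = torch.randint(0, num_classes, (2, height, width), device=device)

    criterion = ConsistencyLoss(ignore_index=255)
    loss = criterion(parsing_logits, edge_logits, label)
    print('consistency loss:', loss.item())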