Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*- | |
# @Time : 10/1/21 | |
# @Author : GXYM | |
import torch | |
from torch import nn | |
import numpy as np | |
class SegmentLoss(nn.Module):
    """PSE loss: OHEM-masked dice on the text map plus dice on the kernel maps.

    The last channel of ``outputs`` / ``labels`` is the text map; all
    preceding channels are kernel maps (see ``forward``).
    """

    def __init__(self, Lambda, ratio=3, reduction='mean'):
        """Implement PSE Loss.

        Args:
            Lambda: weight of the text loss; the kernel loss is weighted by
                ``1 - Lambda``.
            ratio: maximum number of hard negatives kept per positive pixel
                during online hard example mining (``ohem_single``).
            reduction: ``'mean'`` or ``'sum'`` over the per-sample losses.
        """
        super().__init__()
        assert reduction in ['mean', 'sum'], " reduction must in ['mean','sum']"
        self.Lambda = Lambda
        self.ratio = ratio
        self.reduction = reduction

    def forward(self, outputs, labels, training_masks, th=0.5):
        """Compute the text, kernel and combined losses.

        Args:
            outputs: 4-D prediction tensor; channel ``-1`` (dim 1) is the text
                map, the other channels are kernel maps. Used as-is (no sigmoid
                is applied), so presumably already probabilities - confirm.
            labels: ground truth laid out like ``outputs``.
            training_masks: per-pixel mask, indexed as ``[i, :, :]`` by
                ``ohem_batch``; pixels ``<= th`` are ignored.
            th: threshold applied to the predicted text map and to
                ``training_masks`` when selecting pixels for the kernel loss.

        Returns:
            Tuple ``(loss_text, loss_kernels, loss)``, each reduced according
            to ``self.reduction``. ``loss = Lambda * loss_text +
            (1 - Lambda) * loss_kernels``.
        """
        texts = outputs[:, -1, :, :]
        kernels = outputs[:, :-1, :, :]
        gt_texts = labels[:, -1, :, :]
        gt_kernels = labels[:, :-1, :, :]

        # Text branch: dice restricted to the OHEM-selected pixels.
        text_masks = self.ohem_batch(texts, gt_texts, training_masks).to(outputs.device)
        loss_text = self.dice_loss(texts, gt_texts, text_masks)

        # Kernel branch: score only pixels predicted as text and not ignored.
        # Tensor comparisons on the output device avoid a host round-trip; the
        # resulting mask carries no gradient, exactly like the former
        # numpy-built mask.
        kernel_masks = (
            (texts.detach() > th) & (training_masks.to(texts.device) > th)
        ).float()
        loss_kernels = torch.stack([
            self.dice_loss(kernels[:, i, :, :], gt_kernels[:, i, :, :], kernel_masks)
            for i in range(gt_kernels.size(1))
        ]).mean(0)  # average over kernels -> one value per sample

        if self.reduction == 'mean':
            loss_text = loss_text.mean()
            loss_kernels = loss_kernels.mean()
        elif self.reduction == 'sum':
            loss_text = loss_text.sum()
            loss_kernels = loss_kernels.sum()

        loss = self.Lambda * loss_text + (1 - self.Lambda) * loss_kernels
        return loss_text, loss_kernels, loss

    def dice_loss(self, input, target, mask):
        """Per-sample dice loss ``1 - 2*sum(x*y) / (sum(x*x) + sum(y*y))``.

        Each tensor is flattened per sample (dim 0) and multiplied by ``mask``
        before the sums. ``input`` is used as-is; no sigmoid is applied.

        Returns:
            1-D tensor with one loss value per sample.
        """
        input = input.contiguous().view(input.size(0), -1)
        target = target.contiguous().view(target.size(0), -1)
        mask = mask.contiguous().view(mask.size(0), -1)

        input = input * mask
        target = target.float() * mask

        a = torch.sum(input * target, 1)
        # The 0.001 terms keep the denominator non-zero; with an empty
        # prediction and target the loss evaluates to 1.
        b = torch.sum(input * input, 1) + 0.001
        c = torch.sum(target * target, 1) + 0.001
        d = (2 * a) / (b + c)
        return 1 - d

    @staticmethod
    def _as_float_mask(mask):
        """Reshape a 2-D array to ``(1, H, W)`` float32 for batch concatenation."""
        return mask.reshape(1, mask.shape[0], mask.shape[1]).astype('float32')

    def ohem_single(self, score, gt_text, training_mask, th=0.5):
        """Online hard example mining mask for a single 2-D map.

        Keeps every positive pixel (``gt_text > th``) plus the highest-scoring
        negatives, at most ``self.ratio`` negatives per positive. The result is
        restricted to pixels where ``training_mask > th``. If there are no
        positives inside the training mask, or no negatives at all,
        ``training_mask`` itself is returned.

        Args:
            score: predicted text map (2-D numpy array).
            gt_text: ground-truth text map, same shape as ``score``.
            training_mask: valid-pixel mask, same shape as ``score``.
            th: binarisation threshold for ``gt_text`` and ``training_mask``.

        Returns:
            float32 numpy array of shape ``(1, H, W)``.
        """
        positive = gt_text > th
        negative = gt_text <= th

        # Positives that fall inside the training mask.
        pos_num = int(np.sum(positive)) - int(np.sum(positive & (training_mask <= th)))
        if pos_num == 0:
            return self._as_float_mask(training_mask)

        # Honour the configured hard-negative ratio (was hard-coded to 3).
        neg_num = int(min(pos_num * self.ratio, int(np.sum(negative))))
        if neg_num == 0:
            return self._as_float_mask(training_mask)

        neg_score = score[negative]
        # Score of the neg_num-th highest-scoring negative. np.partition finds
        # this order statistic in linear time, giving the same value as sorting
        # in descending order and indexing [neg_num - 1].
        kth = neg_score.size - neg_num
        threshold = np.partition(neg_score, kth)[kth]

        # Keep the hard negatives and all positives, inside the training mask.
        selected_mask = ((score >= threshold) | positive) & (training_mask > th)
        return self._as_float_mask(selected_mask)

    def ohem_batch(self, scores, gt_texts, training_masks):
        """Apply ``ohem_single`` to every item of a batch.

        Args:
            scores, gt_texts, training_masks: tensors indexed as ``[i, :, :]``
                along dim 0. They are detached and copied to host numpy arrays.

        Returns:
            float32 CPU tensor stacking the per-item masks along dim 0.
        """
        scores = scores.detach().cpu().numpy()
        gt_texts = gt_texts.detach().cpu().numpy()
        training_masks = training_masks.detach().cpu().numpy()

        selected_masks = np.concatenate(
            [
                self.ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[i, :, :])
                for i in range(scores.shape[0])
            ],
            0,
        )
        return torch.from_numpy(selected_masks).float()