# Provenance (scraped Hugging Face page header): uploaded by shreyasvaidya
# via "Upload folder using huggingface_hub", commit 01bb3bb (verified).
# -*- coding: utf-8 -*-
# @Time : 10/1/21
# @Author : GXYM
import torch
from torch import nn
import numpy as np
class SegmentLoss(nn.Module):
    """PSE loss: OHEM-masked dice on the text map plus dice on kernel maps.

    The combined loss is::

        loss = Lambda * loss_text + (1 - Lambda) * loss_kernels

    Args:
        Lambda: weight of the text loss; the kernel loss gets ``1 - Lambda``.
        ratio: OHEM negative-to-positive ratio. Per image, at most
            ``ratio * num_positives`` of the highest-scoring negative pixels
            are kept in the text-loss mask.
        reduction: how per-sample losses are reduced over the batch,
            ``'mean'`` or ``'sum'``.
    """

    def __init__(self, Lambda, ratio=3, reduction='mean'):
        super().__init__()
        assert reduction in ['mean', 'sum'], " reduction must in ['mean','sum']"
        self.Lambda = Lambda
        self.ratio = ratio
        self.reduction = reduction

    def forward(self, outputs, labels, training_masks, th=0.5):
        """Compute the text, kernel and combined losses.

        Args:
            outputs: prediction tensor indexed as ``(B, C, H, W)``; channel
                ``-1`` is the text map and channels ``:-1`` are kernel maps.
                No sigmoid is applied here and the text map is compared with
                ``th`` directly, so values are presumably already
                probabilities -- confirm against the model head.
            labels: ground truth with the same channel layout as ``outputs``.
            training_masks: ``(B, H, W)`` mask; pixels whose value is not
                above the threshold are ignored. Moved to ``outputs``' device
                as needed.
            th: threshold on the predicted text map (and on
                ``training_masks``) selecting the pixels that contribute to
                the kernel loss.

        Returns:
            Tuple ``(loss_text, loss_kernels, loss)``, each reduced over the
            batch according to ``self.reduction``.
        """
        texts = outputs[:, -1, :, :]
        kernels = outputs[:, :-1, :, :]
        gt_texts = labels[:, -1, :, :]
        gt_kernels = labels[:, :-1, :, :]

        # Text loss: all positives plus OHEM-selected hard negatives.
        text_masks = self.ohem_batch(texts, gt_texts, training_masks).to(outputs.device)
        loss_text = self.dice_loss(texts, gt_texts, text_masks)

        # Kernel loss: only pixels predicted as text and not ignored. Built
        # directly in torch (a comparison carries no gradient) rather than
        # round-tripping through NumPy on the CPU.
        kernel_masks = ((texts.detach() > th)
                        & (training_masks.to(outputs.device) > th)).float()
        # NOTE: torch.stack raises if there are no kernel channels (C == 1).
        loss_kernels = torch.stack([
            self.dice_loss(kernels[:, i, :, :], gt_kernels[:, i, :, :], kernel_masks)
            for i in range(gt_kernels.size(1))
        ]).mean(0)

        if self.reduction == 'mean':
            loss_text = loss_text.mean()
            loss_kernels = loss_kernels.mean()
        elif self.reduction == 'sum':
            loss_text = loss_text.sum()
            loss_kernels = loss_kernels.sum()

        loss = self.Lambda * loss_text + (1 - self.Lambda) * loss_kernels
        return loss_text, loss_kernels, loss

    def dice_loss(self, input, target, mask):
        """Per-sample masked dice loss.

        Each argument is flattened to ``(B, -1)``; ``input`` and ``target``
        are multiplied by ``mask`` and the loss is::

            1 - 2 * sum(p * g) / (sum(p * p) + sum(g * g) + 0.002)

        Returns:
            Tensor of shape ``(B,)``.
        """
        input = input.contiguous().view(input.size(0), -1)
        target = target.contiguous().view(target.size(0), -1)
        mask = mask.contiguous().view(mask.size(0), -1)

        input = input * mask
        target = target.float() * mask

        intersection = torch.sum(input * target, 1)
        # The 0.001 terms smooth the denominator so empty masks do not divide by zero.
        input_sq = torch.sum(input * input, 1) + 0.001
        target_sq = torch.sum(target * target, 1) + 0.001
        return 1 - (2 * intersection) / (input_sq + target_sq)

    def ohem_single(self, score, gt_text, training_mask, th=0.5):
        """Build the OHEM text-loss mask for a single image.

        Keeps every non-ignored positive pixel plus the
        ``self.ratio * num_positives`` highest-scoring negatives (all pixels
        tied at the cut-off score are kept). The negative count includes
        ignored pixels, which are then removed by ``training_mask``, so fewer
        negatives than that may survive.

        Args:
            score: ``(H, W)`` NumPy array of predicted text scores.
            gt_text: ``(H, W)`` ground truth; ``> th`` is positive.
            training_mask: ``(H, W)``; pixels ``<= th`` are ignored.
            th: binarization threshold for ``gt_text`` and ``training_mask``.

        Returns:
            ``float32`` array of shape ``(1, H, W)``. When there are no usable
            positives or no negatives to mine, the training mask itself is
            returned.
        """
        is_pos = gt_text > th
        pos_num = int(np.sum(is_pos)) - int(np.sum(is_pos & (training_mask <= th)))
        if pos_num == 0:
            return self._as_batch_mask(training_mask)

        is_neg = gt_text <= th
        neg_num = int(min(pos_num * self.ratio, int(np.sum(is_neg))))
        if neg_num == 0:
            return self._as_batch_mask(training_mask)

        # Sort negative scores from high to low (by negating) and take the
        # score of the neg_num-th hardest negative as the cut-off.
        neg_score = score[is_neg]
        threshold = -np.sort(-neg_score)[neg_num - 1]

        # Select high-scoring negatives and all positives, minus ignored pixels.
        selected_mask = ((score >= threshold) | is_pos) & (training_mask > th)
        return self._as_batch_mask(selected_mask)

    @staticmethod
    def _as_batch_mask(mask):
        """Reshape an ``(H, W)`` mask to a ``(1, H, W)`` float32 array."""
        return mask.reshape(1, mask.shape[0], mask.shape[1]).astype('float32')

    def ohem_batch(self, scores, gt_texts, training_masks):
        """Apply ``ohem_single`` to every image in the batch.

        Args:
            scores, gt_texts, training_masks: ``(B, H, W)`` tensors; they are
                detached and copied to the CPU as NumPy arrays.

        Returns:
            ``(B, H, W)`` float32 tensor on the CPU.
        """
        scores = scores.detach().cpu().numpy()
        gt_texts = gt_texts.detach().cpu().numpy()
        training_masks = training_masks.detach().cpu().numpy()
        selected_masks = [
            self.ohem_single(score, gt_text, training_mask)
            for score, gt_text, training_mask in zip(scores, gt_texts, training_masks)
        ]
        return torch.from_numpy(np.concatenate(selected_masks, 0)).float()