import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import defaultdict


class MattingCriterion(nn.Module):
    """Matting losses selected by name; each loss method returns a one-entry dict."""

    def __init__(
        self,
        *,
        losses,
        image_size=1024,
    ):
        super().__init__()
        self.losses = losses
        self.image_size = image_size

    def loss_gradient_penalty(self, sample_map, preds, targets):
        # Rescale by the fraction of unknown-region pixels so the loss magnitude
        # does not depend on how much of the image the trimap marks as unknown.
        if torch.sum(sample_map) == 0:
            scale = 0
        else:
            scale = sample_map.shape[0] * (self.image_size ** 2) / torch.sum(sample_map)

        # Horizontal image gradients via a Sobel filter.
        sobel_x_kernel = torch.tensor(
            [[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]],
            device=preds.device, dtype=preds.dtype,
        )
        delta_pred_x = F.conv2d(preds, weight=sobel_x_kernel, padding=1)
        delta_gt_x = F.conv2d(targets, weight=sobel_x_kernel, padding=1)

        # Vertical image gradients via a Sobel filter.
        sobel_y_kernel = torch.tensor(
            [[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]],
            device=preds.device, dtype=preds.dtype,
        )
        delta_pred_y = F.conv2d(preds, weight=sobel_y_kernel, padding=1)
        delta_gt_y = F.conv2d(targets, weight=sobel_y_kernel, padding=1)

        # L1 gradient matching plus a small penalty on the predicted gradient magnitude.
        loss = (F.l1_loss(delta_pred_x * sample_map, delta_gt_x * sample_map) * scale
                + F.l1_loss(delta_pred_y * sample_map, delta_gt_y * sample_map) * scale
                + 0.01 * torch.mean(torch.abs(delta_pred_x * sample_map)) * scale
                + 0.01 * torch.mean(torch.abs(delta_pred_y * sample_map)) * scale)

        return dict(loss_gradient_penalty=loss)
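
    # Worked example of the scale factor (a sketch, assuming sample_map has the same
    # spatial size as `image_size`): with batch size 1, image_size = 1024 and a trimap
    # marking half of the 1024*1024 pixels as unknown, sum(sample_map) = 524288 and
    # scale = 1 * 1024**2 / 524288 = 2, i.e. the mean over all pixels of the masked
    # maps is rescaled into a mean over the unknown region only.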

    def loss_pha_laplacian(self, preds, targets):
        # Multi-scale L1 loss on Laplacian pyramids of the alpha mattes.
        loss = laplacian_loss(preds, targets)
        return dict(loss_pha_laplacian=loss)

    def unknown_l1_loss(self, sample_map, preds, targets):
        # L1 loss restricted to the unknown (transition) region of the trimap.
        if torch.sum(sample_map) == 0:
            scale = 0
        else:
            scale = sample_map.shape[0] * (self.image_size ** 2) / torch.sum(sample_map)

        loss = F.l1_loss(preds * sample_map, targets * sample_map) * scale
        return dict(unknown_l1_loss=loss)

    def known_l1_loss(self, sample_map, preds, targets):
        # L1 loss on the known (pure foreground/background) region of the trimap.
        new_sample_map = torch.zeros_like(sample_map)
        new_sample_map[sample_map == 0] = 1

        if torch.sum(new_sample_map) == 0:
            scale = 0
        else:
            scale = new_sample_map.shape[0] * (self.image_size ** 2) / torch.sum(new_sample_map)

        loss = F.l1_loss(preds * new_sample_map, targets * new_sample_map) * scale
        return dict(known_l1_loss=loss)

    def get_loss(self, k, sample_map, preds, targets):
        # Region-restricted losses need the sample map; the others only need preds/targets.
        if k in ('unknown_l1_loss', 'known_l1_loss', 'loss_gradient_penalty'):
            losses = getattr(self, k)(sample_map, preds, targets)
        else:
            losses = getattr(self, k)(preds, targets)
        assert len(losses) == 1
        return list(losses.values())[0]

    def forward(self, sample_map, preds, targets, batch_weight=None):
        losses = {k: torch.tensor(0.0, device=sample_map.device) for k in self.losses}
        for k in self.losses:
            if batch_weight is None:
                losses[k] += self.get_loss(k, sample_map, preds, targets)
            else:
                # Per-sample weighting: a weight of -1.0 marks samples that should
                # only contribute to the known-region L1 loss.
                for i, loss_weight in enumerate(batch_weight):
                    if loss_weight == -1.0 and k != 'known_l1_loss':
                        continue
                    losses[k] += self.get_loss(
                        k, sample_map[i: i + 1], preds[i: i + 1], targets[i: i + 1]
                    ) * abs(loss_weight)
        return losses
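
    # Illustrative usage sketch (not part of the training pipeline): the loss names
    # passed in `losses` are assumed to match the method names defined above, and the
    # tensor shapes below are only an example.
    #
    #   criterion = MattingCriterion(
    #       losses=['unknown_l1_loss', 'known_l1_loss',
    #               'loss_gradient_penalty', 'loss_pha_laplacian'],
    #       image_size=1024,
    #   )
    #   sample_map = (torch.rand(2, 1, 1024, 1024) > 0.5).float()
    #   preds = torch.rand(2, 1, 1024, 1024)
    #   targets = torch.rand(2, 1, 1024, 1024)
    #   loss_dict = criterion(sample_map, preds, targets)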


def laplacian_loss(pred, true, max_levels=5):
    # L1 loss between Laplacian pyramids, with coarser levels weighted more heavily.
    kernel = gauss_kernel(device=pred.device, dtype=pred.dtype)
    pred_pyramid = laplacian_pyramid(pred, kernel, max_levels)
    true_pyramid = laplacian_pyramid(true, kernel, max_levels)
    loss = 0
    for level in range(max_levels):
        loss += (2 ** level) * F.l1_loss(pred_pyramid[level], true_pyramid[level])
    return loss / max_levels


def laplacian_pyramid(img, kernel, max_levels):
    current = img
    pyramid = []
    for _ in range(max_levels):
        current = crop_to_even_size(current)
        down = downsample(current, kernel)
        up = upsample(down, kernel)
        diff = current - up
        pyramid.append(diff)
        current = down
    return pyramid


def gauss_kernel(device='cpu', dtype=torch.float32):
    # 5x5 binomial approximation of a Gaussian, normalized to sum to 1.
    kernel = torch.tensor([[1, 4, 6, 4, 1],
                           [4, 16, 24, 16, 4],
                           [6, 24, 36, 24, 6],
                           [4, 16, 24, 16, 4],
                           [1, 4, 6, 4, 1]], device=device, dtype=dtype)
    kernel /= 256
    kernel = kernel[None, None, :, :]
    return kernel


def gauss_convolution(img, kernel):
    # Blur each channel independently with reflect padding; output keeps the input shape.
    B, C, H, W = img.shape
    img = img.reshape(B * C, 1, H, W)
    img = F.pad(img, (2, 2, 2, 2), mode='reflect')
    img = F.conv2d(img, kernel)
    img = img.reshape(B, C, H, W)
    return img


def downsample(img, kernel):
    # Blur, then drop every other row and column.
    img = gauss_convolution(img, kernel)
    img = img[:, :, ::2, ::2]
    return img


def upsample(img, kernel):
    # Zero-insertion upsampling followed by a blur; the factor 4 compensates for
    # the inserted zeros so brightness is preserved.
    B, C, H, W = img.shape
    out = torch.zeros((B, C, H * 2, W * 2), device=img.device, dtype=img.dtype)
    out[:, :, ::2, ::2] = img * 4
    out = gauss_convolution(out, kernel)
    return out


def crop_to_even_size(img):
    # Crop height and width to even sizes so downsample/upsample shapes match.
    H, W = img.shape[2:]
    H = H - H % 2
    W = W - W % 2
    return img[:, :, :H, :W]
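
# Shape sketch for the pyramid helpers (illustrative only; the small tensors here are
# assumptions, not part of the training code):
#
#   x = torch.rand(1, 1, 64, 64)
#   k = gauss_kernel(device=x.device, dtype=x.dtype)
#   assert downsample(x, k).shape == (1, 1, 32, 32)
#   assert upsample(downsample(x, k), k).shape == x.shape
#   print(laplacian_loss(torch.rand(1, 1, 64, 64), torch.rand(1, 1, 64, 64)))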


def normalized_focal_loss(pred, gt, gamma=2, class_num=3, norm=True, beta_detach=False, beta_sum_detach=False):
    # pred: (B, class_num, H, W) logits; gt: (B, H, W) integer class labels.
    pred_prob = F.softmax(pred, dim=1)
    gt_one_hot = F.one_hot(gt, class_num).permute(0, 3, 1, 2)
    # Probability assigned to the ground-truth class at each pixel.
    p = (pred_prob * gt_one_hot).sum(dim=1)
    beta = (1 - p) ** gamma
    # Mean focal weight per image, used to normalize the loss.
    beta_sum = torch.sum(beta, dim=(-2, -1), keepdim=True) / (pred.shape[-1] * pred.shape[-2])

    if beta_detach:
        beta = beta.detach()
    if beta_sum_detach:
        beta_sum = beta_sum.detach()

    if norm:
        loss = 1 / beta_sum * beta * (-torch.log(p))
    else:
        loss = beta * (-torch.log(p))
    return torch.mean(loss)


class GHMC(nn.Module):
    """Gradient Harmonizing Mechanism for classification (GHM-C)."""

    def __init__(self, bins=10, momentum=0.75, loss_weight=1.0, device='cuda', norm=False):
        super().__init__()
        self.bins = bins
        self.momentum = momentum
        # Bin edges over the gradient norm g in [0, 1]; nudge the last edge so that
        # g == 1 falls into the final bin.
        self.edges = torch.arange(bins + 1, dtype=torch.float32, device=device) / bins
        self.edges[-1] += 1e-6
        if momentum > 0:
            # Exponential moving average of the per-bin sample counts.
            self.acc_sum = torch.zeros(bins, device=device)
        self.loss_weight = loss_weight
        self.device = device
        self.norm = norm

    def forward(self, pred, target, *args, **kwargs):
        """Calculate the GHM-C loss.

        Args:
            pred (float tensor of size [batch_num, class_num, H, W]):
                Raw (unnormalized) classification logits.
            target (long tensor of size [batch_num, H, W]):
                Integer class index for each pixel.
        Returns:
            The gradient harmonized loss.
        """
        # Flatten to per-pixel predictions: (N, class_num) and (N,).
        pred = pred.permute(0, 2, 3, 1).reshape(-1, 3)
        target = target.reshape(-1)

        edges = self.edges
        mmt = self.momentum
        weights = torch.zeros(target.shape, dtype=pred.dtype, device=self.device)

        # Gradient norm proxy: 1 - p(true class), detached from the computation graph.
        g = 1 - torch.gather(F.softmax(pred, dim=1).detach(), dim=1, index=target.unsqueeze(1))

        tot = 1.0
        n = 0  # number of non-empty bins
        for i in range(self.bins):
            inds = (g >= edges[i]) & (g < edges[i + 1])
            num_in_bin = inds.sum().item()
            if num_in_bin > 0:
                idx = torch.nonzero(inds)[:, 0]
                if mmt > 0:
                    # Track the bin population with a moving average and weight
                    # its samples inversely to that population.
                    self.acc_sum[i] = mmt * self.acc_sum[i] + (1 - mmt) * num_in_bin
                    _weight_idx = tot / self.acc_sum[i]
                    weights = weights.to(dtype=_weight_idx.dtype)
                    weights[idx] = _weight_idx
                else:
                    weights[idx] = tot / num_in_bin
                n += 1
        if n > 0:
            weights = weights / n

        if self.norm:
            weights = weights / torch.sum(weights).detach()

        # Weighted negative log-likelihood of the target class.
        loss = -(weights.unsqueeze(1)
                 * torch.gather(F.log_softmax(pred, dim=1), dim=1, index=target.unsqueeze(1))).sum()

        return loss


if __name__ == '__main__':
    pred = torch.randn(2, 3, 1024, 1024)
    gt = torch.argmax(torch.randn(2, 3, 1024, 1024), dim=1)
    loss = normalized_focal_loss(pred, gt)
    print(loss)
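
    # Additional sketch: GHMC on the same kind of logits/labels, run on CPU by
    # passing device='cpu' explicitly (the default buffers live on CUDA). The small
    # spatial size here is an assumption made only for a quick check.
    small_pred = torch.randn(2, 3, 64, 64)
    small_gt = torch.argmax(torch.randn(2, 3, 64, 64), dim=1)
    ghmc = GHMC(device='cpu')
    print(ghmc(small_pred, small_gt))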