import torch |
import torch.nn as nn |
import torch.nn.functional as F |
from collections import defaultdict |
class MattingCriterion(nn.Module):
    """Bundle of alpha-matting losses that are selected by method name.

    ``losses`` is an iterable of method names of this class (for example
    ``'unknown_l1_loss'`` or ``'loss_pha_laplacian'``). ``forward`` evaluates
    each of them and returns the results in a dict keyed by those names.

    The masked losses rescale a mean over all pixels into a mean over the
    selected pixels using ``batch * image_size ** 2 / mask.sum()``.
    NOTE(review): this assumes the maps are ``(B, 1, image_size, image_size)``
    -- confirm against the caller.
    """

    # Losses whose method signature is (sample_map, preds, targets); every
    # other loss is called as (preds, targets).
    _MASKED_LOSSES = ('unknown_l1_loss', 'known_l1_loss', 'loss_gradient_penalty')

    def __init__(
        self,
        *,
        losses,
        image_size=1024,
    ):
        super().__init__()
        self.losses = losses
        self.image_size = image_size

    def _mask_scale(self, mask):
        """Return the factor turning an all-pixel mean into a masked-pixel mean.

        Returns the plain integer ``0`` when the mask selects nothing, so the
        corresponding loss term vanishes instead of dividing by zero.
        """
        selected = torch.sum(mask)
        if selected == 0:
            return 0
        return mask.shape[0] * (self.image_size ** 2) / selected

    def loss_gradient_penalty(self, sample_map, preds, targets):
        """L1 distance between Sobel gradients inside ``sample_map`` plus a
        small (0.01) penalty on the magnitude of the predicted gradients.

        The Sobel kernels have a single input channel, so ``preds`` and
        ``targets`` must be single-channel ``(B, 1, H, W)`` tensors.
        """
        scale = self._mask_scale(sample_map)

        # Build the kernels directly with the dtype and device of ``preds``.
        # A legacy type string such as 'torch.cuda.FloatTensor' does not carry
        # a device index and can place the kernel on a different GPU.
        sobel_x_kernel = torch.tensor(
            [[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]],
            dtype=preds.dtype,
            device=preds.device,
        )
        sobel_y_kernel = torch.tensor(
            [[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]],
            dtype=preds.dtype,
            device=preds.device,
        )

        delta_pred_x = F.conv2d(preds, weight=sobel_x_kernel, padding=1)
        delta_gt_x = F.conv2d(targets, weight=sobel_x_kernel, padding=1)
        delta_pred_y = F.conv2d(preds, weight=sobel_y_kernel, padding=1)
        delta_gt_y = F.conv2d(targets, weight=sobel_y_kernel, padding=1)

        loss = (
            F.l1_loss(delta_pred_x * sample_map, delta_gt_x * sample_map) * scale
            + F.l1_loss(delta_pred_y * sample_map, delta_gt_y * sample_map) * scale
            + 0.01 * torch.mean(torch.abs(delta_pred_x * sample_map)) * scale
            + 0.01 * torch.mean(torch.abs(delta_pred_y * sample_map)) * scale
        )
        return dict(loss_gradient_penalty=loss)

    def loss_pha_laplacian(self, preds, targets):
        """Laplacian-pyramid loss between ``preds`` and ``targets``."""
        loss = laplacian_loss(preds, targets)
        return dict(loss_pha_laplacian=loss)

    def unknown_l1_loss(self, sample_map, preds, targets):
        """L1 loss restricted to the pixels where ``sample_map`` is non-zero."""
        scale = self._mask_scale(sample_map)
        loss = F.l1_loss(preds * sample_map, targets * sample_map) * scale
        return dict(unknown_l1_loss=loss)

    def known_l1_loss(self, sample_map, preds, targets):
        """L1 loss restricted to the pixels where ``sample_map`` is exactly 0."""
        # Complement of the sample map: 1 where sample_map == 0, else 0.
        new_sample_map = torch.zeros_like(sample_map)
        new_sample_map[sample_map == 0] = 1
        scale = self._mask_scale(new_sample_map)
        loss = F.l1_loss(preds * new_sample_map, targets * new_sample_map) * scale
        return dict(known_l1_loss=loss)

    def get_loss(self, k, sample_map, preds, targets):
        """Evaluate the loss method named ``k`` and return its single value."""
        loss_fn = getattr(self, k)
        if k in self._MASKED_LOSSES:
            losses = loss_fn(sample_map, preds, targets)
        else:
            losses = loss_fn(preds, targets)
        # Every loss method returns a one-entry dict {name: value}.
        assert len(losses) == 1
        return next(iter(losses.values()))

    def forward(self, sample_map, preds, targets, batch_weight=None):
        """Compute every configured loss.

        Args:
            sample_map: mask passed to the masked losses.
            preds: predicted maps.
            targets: ground-truth maps.
            batch_weight: optional per-sample weights. When given, each loss
                is evaluated sample by sample and scaled by ``abs(weight)``.
                A weight of exactly ``-1.0`` skips every loss except
                ``'known_l1_loss'`` for that sample.

        Returns:
            Dict mapping each name in ``self.losses`` to a scalar tensor.
        """
        losses = {name: torch.tensor(0.0, device=sample_map.device) for name in self.losses}
        for k in self.losses:
            if batch_weight is None:
                losses[k] += self.get_loss(k, sample_map, preds, targets)
                continue
            for i, loss_weight in enumerate(batch_weight):
                if loss_weight == -1.0 and k != 'known_l1_loss':
                    continue
                # Slice with i:i+1 to keep the batch dimension of size 1.
                sample_loss = self.get_loss(
                    k, sample_map[i: i + 1], preds[i: i + 1], targets[i: i + 1]
                )
                losses[k] += sample_loss * abs(loss_weight)
        return losses
def laplacian_loss(pred, true, max_levels=5):
    """Laplacian-pyramid L1 loss.

    Both inputs are decomposed into ``max_levels`` Laplacian levels. The L1
    distance at level ``i`` is weighted by ``2 ** i``, and the weighted sum is
    divided by ``max_levels``.
    """
    blur = gauss_kernel(device=pred.device, dtype=pred.dtype)
    levels_pred = laplacian_pyramid(pred, blur, max_levels)
    levels_true = laplacian_pyramid(true, blur, max_levels)
    weighted = sum(
        (2 ** idx) * F.l1_loss(lp, lt)
        for idx, (lp, lt) in enumerate(zip(levels_pred, levels_true))
    )
    return weighted / max_levels
def laplacian_pyramid(img, kernel, max_levels):
    """Return the list of ``max_levels`` Laplacian (band-pass) levels of ``img``.

    At each level the image is cropped to even height and width, blurred and
    halved, then expanded back. The difference between the cropped image and
    its expanded version is stored, and the halved image feeds the next level.
    """
    levels = []
    image = img
    for _level in range(max_levels):
        image = crop_to_even_size(image)
        reduced = downsample(image, kernel)
        levels.append(image - upsample(reduced, kernel))
        image = reduced
    return levels
def gauss_kernel(device='cpu', dtype=torch.float32):
    """Return the 5x5 binomial blur kernel, shaped ``(1, 1, 5, 5)``.

    The kernel is the outer product of ``[1, 4, 6, 4, 1]`` with itself,
    divided by 256 so that its entries sum to one.
    """
    taps = torch.tensor([1, 4, 6, 4, 1], device=device, dtype=dtype)
    # Outer product via broadcasting: column vector times row vector.
    kernel = taps[:, None] * taps[None, :]
    kernel /= 256
    return kernel[None, None, :, :]
def gauss_convolution(img, kernel):
    """Filter every channel of ``img`` independently with ``kernel``.

    Channels are folded into the batch dimension so that one single-channel
    kernel is applied to each of them. The input is reflect-padded by 2 pixels
    on every side, which keeps the ``(B, C, H, W)`` output size for a 5x5 kernel.
    """
    batch, channels, height, width = img.shape
    planes = img.reshape(batch * channels, 1, height, width)
    padded = F.pad(planes, (2, 2, 2, 2), mode='reflect')
    filtered = F.conv2d(padded, kernel)
    return filtered.reshape(batch, channels, height, width)
def downsample(img, kernel):
    """Blur ``img`` with ``kernel`` and keep every second row and column."""
    return gauss_convolution(img, kernel)[:, :, ::2, ::2]
def upsample(img, kernel):
    """Double the spatial size of ``img`` by zero insertion followed by a blur.

    The samples are multiplied by 4 before blurring because only one pixel in
    four of the enlarged grid is non-zero.
    """
    batch, channels, height, width = img.shape
    enlarged = torch.zeros(
        (batch, channels, 2 * height, 2 * width),
        dtype=img.dtype,
        device=img.device,
    )
    # Place the original samples on the even rows and columns.
    enlarged[:, :, ::2, ::2] = img * 4
    return gauss_convolution(enlarged, kernel)
def crop_to_even_size(img):
    """Drop the last row and/or column so that height and width are even."""
    height, width = img.shape[2:]
    even_h = (height // 2) * 2
    even_w = (width // 2) * 2
    return img[:, :, :even_h, :even_w]
def normalized_focal_loss(pred, gt, gamma=2, class_num=3, norm=True, beta_detach=False, beta_sum_detach=False):
    """Focal loss, optionally normalized by the mean focal weight per image.

    Args:
        pred: raw class scores with the class dimension at index 1 and two
            trailing spatial dimensions (softmax is taken over ``dim=1``).
        gt: integer class indices, one per pixel. It is one-hot encoded with
            ``class_num`` classes, so it must be an integer (long) tensor.
        gamma: exponent of the focal weight ``(1 - p) ** gamma``.
        class_num: number of classes used for the one-hot encoding.
        norm: if true, divide each pixel's focal weight by the mean focal
            weight over the two trailing (spatial) dimensions.
        beta_detach: if true, no gradient flows through the focal weight.
        beta_sum_detach: if true, no gradient flows through the normalizer.

    Returns:
        Scalar tensor: the mean of the weighted negative log-likelihood.
    """
    # Work in log space: log(softmax(x)) underflows to -inf for confidently
    # wrong pixels, which would make the loss infinite.
    log_probs = F.log_softmax(pred, dim=1)
    gt_one_hot = F.one_hot(gt, class_num).permute(0, 3, 1, 2)
    # Keep only the log-probability of the ground-truth class. masked_fill is
    # used rather than a product so a -inf entry of another class cannot
    # turn into NaN through 0 * -inf.
    log_p = log_probs.masked_fill(gt_one_hot == 0, 0.0).sum(dim=1)
    p = log_p.exp()

    beta = (1 - p) ** gamma
    beta_sum = torch.sum(beta, dim=(-2, -1), keepdim=True) / (pred.shape[-1] * pred.shape[-2])
    if beta_detach:
        beta = beta.detach()
    if beta_sum_detach:
        beta_sum = beta_sum.detach()

    if not norm:
        return torch.mean(beta * (-log_p))

    # A perfectly predicted image has beta_sum == 0. Floor it at the smallest
    # normal number so 1 / beta_sum stays finite and the loss is 0, not NaN.
    beta_sum = beta_sum.clamp_min(torch.finfo(beta_sum.dtype).tiny)
    return torch.mean(1 / beta_sum * beta * (-log_p))
class GHMC(nn.Module):
    """Gradient-harmonized classification loss over dense per-pixel logits.

    Samples are grouped into ``bins`` equal-width bins of the gradient norm
    ``g = 1 - p_target``. Each sample is weighted by the inverse of its bin's
    population, so heavily populated bins do not dominate the loss.

    Args:
        bins: number of bins that split ``[0, 1]``.
        momentum: if greater than 0, bin populations are tracked with an
            exponential moving average (kept in ``self.acc_sum``) across calls.
        loss_weight: stored on the instance; ``forward`` does not apply it.
        device: device on which the bin edges and the running statistics are
            created. ``forward`` moves them to the device of ``pred`` if needed.
        norm: if true, the per-sample weights are rescaled to sum to one.
    """

    def __init__(self, bins=10, momentum=0.75, loss_weight=1.0, device='cuda', norm=False):
        super().__init__()
        self.bins = bins
        self.momentum = momentum
        # Honour the ``device`` argument instead of hard-coding CUDA so the
        # module can also be built on CPU-only hosts.
        edges = torch.arange(bins + 1).float() / bins
        # Nudge the last edge so that g == 1 falls into the last bin.
        edges[-1] += 1e-6
        self.edges = edges.to(device)
        if momentum > 0:
            self.acc_sum = torch.zeros(bins, device=device)
        self.loss_weight = loss_weight
        self.device = device
        self.norm = norm

    def forward(self, pred, target, *args, **kwargs):
        """Calculate the GHM-C loss.

        Args:
            pred: raw class scores of shape ``[batch, class_num, H, W]``.
            target: integer class index for every pixel, ``batch * H * W``
                elements in total (it is flattened before use).
            *args, **kwargs: accepted and ignored.

        Returns:
            Scalar tensor: the sum over all pixels of the weighted negative
            log-likelihood of the target class.
        """
        # Flatten to one row of logits per pixel. The class count comes from
        # the input rather than being fixed at 3.
        num_classes = pred.shape[1]
        logits = pred.permute(0, 2, 3, 1).reshape(-1, num_classes)
        labels = target.reshape(-1, 1)

        # Keep the bookkeeping tensors on the same device as the input.
        device = logits.device
        if self.edges.device != device:
            self.edges = self.edges.to(device)
        edges = self.edges
        mmt = self.momentum
        if mmt > 0 and self.acc_sum.device != device:
            self.acc_sum = self.acc_sum.to(device)

        weights = torch.zeros(labels.shape[0], dtype=logits.dtype, device=device)
        # Gradient norm of softmax cross-entropy w.r.t. the target logit.
        g = 1 - torch.gather(F.softmax(logits, dim=1).detach(), dim=1, index=labels).squeeze(1)

        tot = 1.0
        n = 0  # number of non-empty bins
        for i in range(self.bins):
            in_bin = (g >= edges[i]) & (g < edges[i + 1])
            num_in_bin = in_bin.sum().item()
            if num_in_bin == 0:
                continue
            if mmt > 0:
                self.acc_sum[i] = mmt * self.acc_sum[i] + (1 - mmt) * num_in_bin
                bin_weight = tot / self.acc_sum[i]
                # The running statistics may have a different dtype from the
                # logits (e.g. half-precision input), so follow their dtype.
                weights = weights.to(dtype=bin_weight.dtype)
                weights[in_bin] = bin_weight
            else:
                weights[in_bin] = tot / num_in_bin
            n += 1

        if n > 0:
            weights = weights / n
        if self.norm:
            weights = weights / torch.sum(weights).detach()

        target_log_prob = torch.gather(F.log_softmax(logits, dim=1), dim=1, index=labels).squeeze(1)
        loss = -(weights * target_log_prob).sum()
        return loss
if __name__ == '__main__':
    # Smoke test: random 3-class logits against random labels, where each
    # label is the argmax over the class dimension of an independent random
    # tensor.
    logits = torch.randn(2, 3, 1024, 1024)
    labels = torch.randn(2, 3, 1024, 1024).argmax(dim=1)
    print(normalized_focal_loss(logits, labels))