import torch
import torch.nn as nn
import torch.nn.functional as F


class MattingCriterion(nn.Module):
    """Training criterion for alpha matting: a configurable set of losses
    computed over predicted and ground-truth alpha mattes ('phas')."""

    def __init__(self, *, losses):
        super(MattingCriterion, self).__init__()
        self.losses = losses

    def loss_gradient_penalty(self, sample_map, preds, targets):
        preds = preds['phas']
        targets = targets['phas']

        # sample_map masks the unknown (trimap) region; the scale factor
        # normalizes by the fraction of unknown pixels per image
        # (262144 = 512 * 512, presumably the training crop area).
        scale = sample_map.shape[0] * 262144 / torch.sum(sample_map)

        # Gradient in x via a Sobel filter.
        sobel_x_kernel = torch.tensor(
            [[[[-1, 0, 1],
               [-2, 0, 2],
               [-1, 0, 1]]]], dtype=preds.dtype, device=preds.device)
        delta_pred_x = F.conv2d(preds, weight=sobel_x_kernel, padding=1)
        delta_gt_x = F.conv2d(targets, weight=sobel_x_kernel, padding=1)

        # Gradient in y via a Sobel filter.
        sobel_y_kernel = torch.tensor(
            [[[[-1, -2, -1],
               [ 0,  0,  0],
               [ 1,  2,  1]]]], dtype=preds.dtype, device=preds.device)
        delta_pred_y = F.conv2d(preds, weight=sobel_y_kernel, padding=1)
        delta_gt_y = F.conv2d(targets, weight=sobel_y_kernel, padding=1)

        # L1 between predicted and ground-truth gradients in the unknown
        # region, plus a small penalty on the predicted gradient magnitude.
        loss = (F.l1_loss(delta_pred_x * sample_map, delta_gt_x * sample_map) * scale
                + F.l1_loss(delta_pred_y * sample_map, delta_gt_y * sample_map) * scale
                + 0.01 * torch.mean(torch.abs(delta_pred_x * sample_map)) * scale
                + 0.01 * torch.mean(torch.abs(delta_pred_y * sample_map)) * scale)

        return dict(loss_gradient_penalty=loss)

    def loss_pha_laplacian(self, preds, targets):
        assert 'phas' in preds and 'phas' in targets
        loss = laplacian_loss(preds['phas'], targets['phas'])
        return dict(loss_pha_laplacian=loss)

    def unknown_l1_loss(self, sample_map, preds, targets):
        # L1 over the unknown region only, rescaled by its area.
        scale = sample_map.shape[0] * 262144 / torch.sum(sample_map)
        loss = F.l1_loss(preds['phas'] * sample_map, targets['phas'] * sample_map) * scale
        return dict(unknown_l1_loss=loss)

    def known_l1_loss(self, sample_map, preds, targets):
        # Complement of sample_map: the known foreground/background region.
        new_sample_map = torch.zeros_like(sample_map)
        new_sample_map[sample_map == 0] = 1

        # Guard against division by zero when the whole image is unknown.
        if torch.sum(new_sample_map) == 0:
            scale = 0
        else:
            scale = new_sample_map.shape[0] * 262144 / torch.sum(new_sample_map)

        loss = F.l1_loss(preds['phas'] * new_sample_map, targets['phas'] * new_sample_map) * scale
        return dict(known_l1_loss=loss)

    def forward(self, sample_map, preds, targets):
        losses = dict()
        for k in self.losses:
            # Region-restricted losses also need the unknown-region mask.
            if k in ('unknown_l1_loss', 'known_l1_loss', 'loss_gradient_penalty'):
                losses.update(getattr(self, k)(sample_map, preds, targets))
            else:
                losses.update(getattr(self, k)(preds, targets))
        return losses


#-----------------Laplacian Loss-------------------------#

def laplacian_loss(pred, true, max_levels=5):
    kernel = gauss_kernel(device=pred.device, dtype=pred.dtype)
    pred_pyramid = laplacian_pyramid(pred, kernel, max_levels)
    true_pyramid = laplacian_pyramid(true, kernel, max_levels)
    loss = 0
    # Weight the coarser pyramid levels by increasing powers of two.
    for level in range(max_levels):
        loss += (2 ** level) * F.l1_loss(pred_pyramid[level], true_pyramid[level])
    return loss / max_levels


def laplacian_pyramid(img, kernel, max_levels):
    current = img
    pyramid = []
    for _ in range(max_levels):
        current = crop_to_even_size(current)
        down = downsample(current, kernel)
        up = upsample(down, kernel)
        diff = current - up
        pyramid.append(diff)
        current = down
    return pyramid


def gauss_kernel(device='cpu', dtype=torch.float32):
    # 5x5 binomial (Gaussian) kernel, normalized to sum to 1.
    kernel = torch.tensor([[1,  4,  6,  4, 1],
                           [4, 16, 24, 16, 4],
                           [6, 24, 36, 24, 6],
                           [4, 16, 24, 16, 4],
                           [1,  4,  6,  4, 1]], device=device, dtype=dtype)
    kernel /= 256
    kernel = kernel[None, None, :, :]
    return kernel


def gauss_convolution(img, kernel):
    # Apply the kernel depthwise by folding channels into the batch dim.
    B, C, H, W = img.shape
    img = img.reshape(B * C, 1, H, W)
    img = F.pad(img, (2, 2, 2, 2), mode='reflect')
    img = F.conv2d(img, kernel)
    img = img.reshape(B, C, H, W)
    return img


def downsample(img, kernel):
    # Blur, then drop every other row and column.
    img = gauss_convolution(img, kernel)
    img = img[:, :, ::2, ::2]
    return img


def upsample(img, kernel):
    # Zero-stuff to double the resolution, then blur; the factor of 4
    # compensates for the energy lost to the inserted zeros.
    B, C, H, W = img.shape
    out = torch.zeros((B, C, H * 2, W * 2), device=img.device, dtype=img.dtype)
    out[:, :, ::2, ::2] = img * 4
    out = gauss_convolution(out, kernel)
    return out


def crop_to_even_size(img):
    # Trim H and W to even sizes so downsample/upsample shapes match.
    H, W = img.shape[2:]
    H = H - H % 2
    W = W - W % 2
    return img[:, :, :H, :W]
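
# A minimal usage sketch, not part of the original file: the 512x512 shapes,
# the batch size, and the loss selection below are illustrative assumptions.
if __name__ == '__main__':
    criterion = MattingCriterion(
        losses=['unknown_l1_loss', 'known_l1_loss',
                'loss_gradient_penalty', 'loss_pha_laplacian'])
    preds = {'phas': torch.rand(2, 1, 512, 512)}
    targets = {'phas': torch.rand(2, 1, 512, 512)}
    # Binary mask marking the unknown (trimap) region.
    sample_map = (torch.rand(2, 1, 512, 512) > 0.5).float()
    print(criterion(sample_map, preds, targets))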