# Matte-Anything / modeling / criterion / matting_criterion.py
# (uploaded via huggingface_hub, commit 6a89c74)
import torch
import torch.nn as nn
import torch.nn.functional as F
class MattingCriterion(nn.Module):
    """Alpha-matting training losses: unknown/known-region L1, Sobel-gradient
    penalty, and a Laplacian-pyramid loss.

    Every loss reads the predicted and ground-truth alpha mattes from
    ``preds['phas']`` / ``targets['phas']``; ``sample_map`` is a 0/1 mask of
    the unknown (trimap) region.  The convolutional losses use a
    single-input-channel kernel, so they assume a matte of shape
    (B, 1, H, W) — TODO confirm against the model head.

    Parameters
    ----------
    losses : iterable of str
        Names of the loss methods to evaluate in :meth:`forward`, e.g.
        ``['unknown_l1_loss', 'known_l1_loss', 'loss_pha_laplacian',
        'loss_gradient_penalty']``.
    """

    # 262144 == 512 * 512: losses are normalized as if every mask covered a
    # full 512x512 crop.  NOTE(review): assumes 512x512 training resolution —
    # confirm against the data pipeline.
    _REF_AREA = 512 * 512

    def __init__(self, *, losses):
        super().__init__()
        self.losses = losses
        # The Sobel kernels are constants: build them once here instead of on
        # every call.  persistent=False keeps them out of the state_dict, so
        # checkpoints saved by the per-call version still load unchanged.
        sobel_x = torch.tensor([[[[-1.0, 0.0, 1.0],
                                  [-2.0, 0.0, 2.0],
                                  [-1.0, 0.0, 1.0]]]])
        sobel_y = torch.tensor([[[[-1.0, -2.0, -1.0],
                                  [0.0, 0.0, 0.0],
                                  [1.0, 2.0, 1.0]]]])
        self.register_buffer('_sobel_x', sobel_x, persistent=False)
        self.register_buffer('_sobel_y', sobel_y, persistent=False)

    def _scale(self, mask):
        """Normalization factor ``B * 512*512 / mask.sum()``.

        Returns 0 when the mask has no active pixels, so an all-known (or
        all-unknown) batch yields a zero loss instead of the inf/nan the
        unguarded division would produce.
        """
        active = torch.sum(mask)
        if active == 0:
            return 0
        return mask.shape[0] * self._REF_AREA / active

    def loss_gradient_penalty(self, sample_map, preds, targets):
        """L1 between Sobel gradients of prediction and target inside the
        unknown region, plus a small (0.01) sparsity penalty on the
        predicted gradients.  Returns ``{'loss_gradient_penalty': loss}``."""
        preds = preds['phas']
        targets = targets['phas']
        # sample_map marks the unknown area
        scale = self._scale(sample_map)
        # Match the input's device/dtype (e.g. under AMP) without the
        # deprecated Tensor.type(...) call.
        kx = self._sobel_x.to(device=preds.device, dtype=preds.dtype)
        ky = self._sobel_y.to(device=preds.device, dtype=preds.dtype)
        # gradient in x
        delta_pred_x = F.conv2d(preds, weight=kx, padding=1)
        delta_gt_x = F.conv2d(targets, weight=kx, padding=1)
        # gradient in y
        delta_pred_y = F.conv2d(preds, weight=ky, padding=1)
        delta_gt_y = F.conv2d(targets, weight=ky, padding=1)
        loss = (F.l1_loss(delta_pred_x * sample_map, delta_gt_x * sample_map) * scale +
                F.l1_loss(delta_pred_y * sample_map, delta_gt_y * sample_map) * scale +
                0.01 * torch.mean(torch.abs(delta_pred_x * sample_map)) * scale +
                0.01 * torch.mean(torch.abs(delta_pred_y * sample_map)) * scale)
        return dict(loss_gradient_penalty=loss)

    def loss_pha_laplacian(self, preds, targets):
        """Laplacian-pyramid loss over the full matte (module-level helper)."""
        assert 'phas' in preds and 'phas' in targets
        loss = laplacian_loss(preds['phas'], targets['phas'])
        return dict(loss_pha_laplacian=loss)

    def unknown_l1_loss(self, sample_map, preds, targets):
        """L1 loss restricted to the unknown region of the trimap."""
        scale = self._scale(sample_map)
        loss = F.l1_loss(preds['phas'] * sample_map,
                         targets['phas'] * sample_map) * scale
        return dict(unknown_l1_loss=loss)

    def known_l1_loss(self, sample_map, preds, targets):
        """L1 loss restricted to the known region (complement of sample_map)."""
        known_map = (sample_map == 0).to(sample_map.dtype)
        scale = self._scale(known_map)
        loss = F.l1_loss(preds['phas'] * known_map,
                         targets['phas'] * known_map) * scale
        return dict(known_l1_loss=loss)

    def forward(self, sample_map, preds, targets):
        """Evaluate every configured loss; returns a {name: tensor} dict."""
        losses = dict()
        for name in self.losses:
            loss_fn = getattr(self, name)
            # Region-restricted losses also need the unknown-area mask.
            if name in ('unknown_l1_loss', 'known_l1_loss', 'loss_gradient_penalty'):
                losses.update(loss_fn(sample_map, preds, targets))
            else:
                losses.update(loss_fn(preds, targets))
        return losses
#-----------------Laplacian Loss-------------------------#
def laplacian_loss(pred, true, max_levels=5):
    """Level-weighted L1 distance between the Laplacian pyramids of two
    mattes, averaged over the number of levels (weight 2**level)."""
    kernel = gauss_kernel(device=pred.device, dtype=pred.dtype)
    pairs = zip(laplacian_pyramid(pred, kernel, max_levels),
                laplacian_pyramid(true, kernel, max_levels))
    total = sum((2 ** level) * F.l1_loss(p, t)
                for level, (p, t) in enumerate(pairs))
    return total / max_levels
def laplacian_pyramid(img, kernel, max_levels):
    """Build a Laplacian pyramid of ``max_levels`` band-pass images.

    Each level stores the detail lost by one blur/downsample step; the
    residual low-pass image is not returned.
    """
    levels = []
    current = img
    for _ in range(max_levels):
        even = crop_to_even_size(current)
        coarse = downsample(even, kernel)
        levels.append(even - upsample(coarse, kernel))
        current = coarse
    return levels
def gauss_kernel(device='cpu', dtype=torch.float32):
    """Return the normalized 5x5 binomial (Gaussian) kernel, shaped
    (1, 1, 5, 5) for use as an F.conv2d weight."""
    weights = [[1, 4, 6, 4, 1],
               [4, 16, 24, 16, 4],
               [6, 24, 36, 24, 6],
               [4, 16, 24, 16, 4],
               [1, 4, 6, 4, 1]]
    k = torch.tensor(weights, device=device, dtype=dtype)
    k.div_(256)  # entries sum to 256, so the kernel integrates to 1
    return k.unsqueeze(0).unsqueeze(0)
def gauss_convolution(img, kernel):
    """Apply ``kernel`` channel-wise with reflect padding; output shape
    equals input shape (B, C, H, W)."""
    B, C, H, W = img.shape
    # Fold channels into the batch so one single-channel kernel serves all.
    flat = img.reshape(B * C, 1, H, W)
    padded = F.pad(flat, (2, 2, 2, 2), mode='reflect')
    blurred = F.conv2d(padded, kernel)
    return blurred.reshape(B, C, H, W)
def downsample(img, kernel):
    """Halve H and W: Gaussian blur, then keep every second row/column."""
    blurred = gauss_convolution(img, kernel)
    return blurred[..., ::2, ::2]
def upsample(img, kernel):
    """Double H and W: zero-stuff odd rows/columns, then blur.

    The x4 factor compensates for the 3/4 of samples that are zero, so
    mean brightness is preserved.
    """
    B, C, H, W = img.shape
    stuffed = torch.zeros((B, C, 2 * H, 2 * W), device=img.device, dtype=img.dtype)
    stuffed[..., ::2, ::2] = 4 * img
    return gauss_convolution(stuffed, kernel)
def crop_to_even_size(img):
    """Trim at most one trailing row/column so both H and W are even."""
    H, W = img.shape[2:]
    return img[:, :, :H - H % 2, :W - W % 2]