import torch

from metrics.registry import LOSSFUNC

from .abstract_loss_func import AbstractLossClass
|
def mahalanobis_distance(values: torch.Tensor, mean: torch.Tensor, inv_covariance: torch.Tensor) -> torch.Tensor:
    """Compute the batched Mahalanobis distance.

    values is a batch of feature vectors.
    mean is either the mean of the distribution to compare, or a second
    batch of feature vectors.
    inv_covariance is the inverse covariance of the target distribution.
    """
    assert values.dim() == 2
    assert 1 <= mean.dim() <= 2
    assert inv_covariance.dim() == 2
    assert values.shape[1] == mean.shape[-1]
    assert mean.shape[-1] == inv_covariance.shape[0]
    assert inv_covariance.shape[0] == inv_covariance.shape[1]

    if mean.dim() == 1:
        mean = mean.unsqueeze(0)
    x_mu = values - mean

    # Batched quadratic form: dist[i] = (x_i - mu)^T * Sigma^-1 * (x_i - mu).
    dist = torch.einsum("im,mn,in->i", x_mu, inv_covariance, x_mu)

    return dist.sqrt()
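

# Minimal usage sketch (illustrative only; the tensors below are made up).
# It computes distances from a batch of features to a Gaussian fitted on
# reference features; shapes follow the asserts above.
#
#     feats = torch.randn(8, 4)
#     ref = torch.randn(100, 4)
#     mean = ref.mean(dim=0)                           # (4,)
#     inv_cov = torch.linalg.inv(torch.cov(ref.T))     # (4, 4)
#     d = mahalanobis_distance(feats, mean, inv_cov)   # (8,)
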
@LOSSFUNC.register_module(module_name="patch_consistency_loss")
class PatchConsistencyLoss(AbstractLossClass):
    def __init__(self, c_real, c_fake, c_cross):
        super().__init__()
        self.c_real = c_real
        self.c_fake = c_fake
        self.c_cross = c_cross

    def forward(self, attention_map_real, attention_map_fake, feature_patch, real_feature_mean, real_inv_covariance,
                fake_feature_mean, fake_inv_covariance, labels):
        B, H, W, C = feature_patch.size()
        # Mahalanobis distance of every patch feature to the real and fake
        # feature distributions.
        dist_real = mahalanobis_distance(feature_patch.reshape(B * H * W, C), real_feature_mean.cuda(),
                                         real_inv_covariance.cuda())
        dist_fake = mahalanobis_distance(feature_patch.reshape(B * H * W, C), fake_feature_mean.cuda(),
                                         fake_inv_covariance.cuda())
        fake_indices = torch.where(labels == 1.0)[0]
        # For fake samples, a positive entry marks a patch that lies closer to
        # the fake distribution than to the real one.
        index_map = torch.relu(dist_real - dist_fake).reshape((B, H, W))[fake_indices, :]

        if attention_map_real.shape[0] == 0:
            loss_real = 0
        else:
            B, PP, _ = attention_map_real.shape
            # Target matrix: 1 on the diagonal, c_real everywhere else.
            c_matrix = (1 - self.c_real) * torch.eye(PP).cuda() + self.c_real * torch.ones(PP, PP).cuda()
            c_matrix = c_matrix.expand(B, -1, -1)
            # L1 deviation from the target, normalized by the number of
            # off-diagonal entries.
            loss_real = torch.sum(torch.abs(attention_map_real - c_matrix)) / (B * (PP * PP - PP))
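            # E.g. with PP = 3 and c_real = 0.8 the target matrix is
            #     [[1.0, 0.8, 0.8],
            #      [0.8, 1.0, 0.8],
            #      [0.8, 0.8, 1.0]],
            # pulling attention between any two real patches toward c_real.
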
        if attention_map_fake.shape[0] == 0:
            loss_fake = 0
        else:
            B, PP, _ = attention_map_fake.shape
            c_matrix = []
            for b in range(B):
                # Split the b-th fake sample's patches by the distribution
                # they are closer to.
                fake_patch_indices = torch.where(index_map[b].reshape(-1) > 0)[0]
                real_patch_indices = torch.where(index_map[b].reshape(-1) <= 0)[0]
                # Start from c_cross everywhere, then overwrite the
                # within-group blocks.
                tmp = torch.zeros((PP, PP)).cuda() + self.c_cross
                for i in fake_patch_indices:
                    tmp[i, fake_patch_indices] = self.c_fake
                for i in real_patch_indices:
                    tmp[i, real_patch_indices] = self.c_real
                c_matrix.append(tmp)

            c_matrix = torch.stack(c_matrix).cuda()
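            # E.g. with PP = 3, patch 0 closer to the fake distribution,
            # patches 1-2 closer to the real one, c_fake = 0.8, c_real = 0.8
            # and c_cross = 0.2, each target matrix looks like
            #     [[0.8, 0.2, 0.2],
            #      [0.2, 0.8, 0.8],
            #      [0.2, 0.8, 0.8]]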
            loss_fake = torch.sum(torch.abs(attention_map_fake - c_matrix)) / (B * (PP * PP - PP))

        return loss_real + loss_fake
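

if __name__ == "__main__":
    # Illustrative smoke test, not part of the training pipeline. All shapes
    # and statistics below are made up, and a CUDA device is assumed because
    # the loss moves its inputs to the GPU internally.
    B, H, W, C = 4, 3, 3, 16
    PP = H * W
    labels = torch.tensor([0.0, 1.0, 0.0, 1.0]).cuda()
    n_fake = int((labels == 1.0).sum())
    n_real = B - n_fake

    attention_map_real = torch.rand(n_real, PP, PP).cuda()
    attention_map_fake = torch.rand(n_fake, PP, PP).cuda()
    feature_patch = torch.randn(B, H, W, C).cuda()

    # Per-class feature statistics, e.g. estimated offline on a reference set.
    real_feature_mean = torch.zeros(C)
    fake_feature_mean = torch.ones(C)
    real_inv_covariance = torch.eye(C)
    fake_inv_covariance = torch.eye(C)

    loss_fn = PatchConsistencyLoss(c_real=0.8, c_fake=0.8, c_cross=0.2)
    # forward is called directly so the sketch does not assume anything about
    # AbstractLossClass beyond what this file shows.
    loss = loss_fn.forward(attention_map_real, attention_map_fake, feature_patch,
                           real_feature_mean, real_inv_covariance,
                           fake_feature_mean, fake_inv_covariance, labels)
    print("patch consistency loss:", float(loss))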