|
import os |
|
import sys |
|
|
|
sys.path.append(os.getcwd()) |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import numpy as np |
|
|
|
class KeypointLoss(nn.Module):
    """MSE loss over keypoint sequences, optionally masked by confidence.

    When ``gt_conf`` is given, only elements whose ground-truth
    confidence is at least 0.01 contribute to the loss.
    """

    def __init__(self):
        super(KeypointLoss, self).__init__()

    def forward(self, pred_seq, gt_seq, gt_conf=None):
        """Compute the (optionally confidence-masked) MSE loss.

        Args:
            pred_seq: predicted keypoints, same shape as ``gt_seq``.
            gt_seq: ground-truth keypoints.
            gt_conf: optional per-element confidence, broadcastable to the
                sequence shape; elements with confidence < 0.01 are ignored.

        Returns:
            Scalar MSE tensor.
        """
        if gt_conf is not None:
            mask = gt_conf >= 0.01
            # Guard the degenerate all-low-confidence batch: F.mse_loss on
            # empty tensors returns NaN, which would poison training. A
            # zero that stays attached to the graph is returned instead.
            if not mask.any():
                return pred_seq.sum() * 0.0
            return F.mse_loss(pred_seq[mask], gt_seq[mask], reduction='mean')
        else:
            return F.mse_loss(pred_seq, gt_seq)
|
|
|
|
|
class KLLoss(nn.Module):
    """KL divergence of a diagonal Gaussian against N(0, I), with a floor.

    ``var`` is the log-variance (standard VAE parameterization:
    ``1 + log_var - mu^2 - exp(log_var)``). The per-sample KL is clamped
    from below at a tolerance ("free bits") scaled by the latent size.
    """

    def __init__(self, kl_tolerance):
        super(KLLoss, self).__init__()
        # Minimum per-sample KL; None disables the floor entirely.
        self.kl_tolerance = kl_tolerance

    def forward(self, mu, var, mul=1):
        """Return the mean (floored) KL divergence over the batch.

        Args:
            mu: means, shape (batch, latent_dim).
            var: log-variances, same shape as ``mu``.
            mul: extra multiplier applied to the tolerance.
        """
        # Per-sample KL divergence, summed over the latent dimension.
        kld_loss = -0.5 * torch.sum(1 + var - mu**2 - var.exp(), dim=1)

        if self.kl_tolerance is not None:
            # Only compute the tolerance inside the guard: the original
            # multiplied a possibly-None self.kl_tolerance unconditionally
            # and raised TypeError when the floor was disabled.
            # 64 is the reference latent dimensionality the tolerance was
            # tuned for — assumed from the code; confirm against configs.
            kl_tolerance = self.kl_tolerance * mul * var.shape[1] / 64
            # clamp is equivalent to the original torch.where floor but
            # stays on the inputs' device (the original hard-coded 'cuda'
            # and crashed on CPU-only machines).
            kld_loss = kld_loss.clamp(min=kl_tolerance)

        return torch.mean(kld_loss)
|
|
|
|
|
class L2KLLoss(nn.Module):
    """Squared-L2 penalty on latent codes, gated by a tolerance.

    Each sample's penalty is its sum of squares over dim 1. With a
    tolerance set, the loss is 0 until at least one sample exceeds it;
    once any sample crosses the line, the mean over *all* samples (not
    only the offenders) is returned.
    """

    def __init__(self, kl_tolerance):
        super(L2KLLoss, self).__init__()
        # Threshold below which the penalty is suppressed; None disables gating.
        self.kl_tolerance = kl_tolerance

    def forward(self, x):
        per_sample = torch.sum(x ** 2, dim=1)
        if self.kl_tolerance is None:
            return torch.mean(per_sample)
        # Gate: penalize only once some sample crosses the tolerance.
        if (per_sample > self.kl_tolerance).any():
            return torch.mean(per_sample)
        return 0
|
|
|
class L2RegLoss(nn.Module):
    """Sum-of-squares regularizer over every element of the input."""

    def __init__(self):
        super(L2RegLoss, self).__init__()

    def forward(self, x):
        # Total squared magnitude; no mean/batch reduction.
        return (x * x).sum()
|
|
|
|
|
class L2Loss(nn.Module):
    """Unreduced squared-L2 loss: the sum of squared elements of ``x``.

    Functionally identical to L2RegLoss; kept as a separate name.
    """

    def __init__(self):
        super(L2Loss, self).__init__()

    def forward(self, x):
        return torch.square(x).sum()
|
|
|
|
|
class AudioLoss(nn.Module):
    """MSE between predicted dynamics and mean-centered ground-truth poses.

    The ground truth has its mean along the last axis removed, so the
    prediction is compared against the deviation around the mean only.
    """

    def __init__(self):
        super(AudioLoss, self).__init__()

    def forward(self, dynamics, gt_poses):
        centered = gt_poses - gt_poses.mean(dim=-1, keepdim=True)
        return F.mse_loss(dynamics, centered)
|
|
|
# Re-export torch's built-in L1 loss under a project-local name so all
# losses used by the project are importable from this module.
L1Loss = nn.L1Loss