|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import torch.nn.functional as F |
|
|
|
|
|
def discriminator_adv_loss(disc_real_outputs, disc_generated_outputs):
    """Compute the discriminator adversarial loss over paired outputs.

    Both arguments are iterated in lockstep. Every element must unpack into
    exactly two values, called ``fun`` and ``dir`` here (presumably two
    separate discriminator heads -- TODO confirm against the caller).

    Args:
        disc_real_outputs (list): Per-discriminator ``(fun, dir)`` pairs
            computed from real samples.
        disc_generated_outputs (list): Per-discriminator ``(fun, dir)`` pairs
            computed from generated samples.

    Returns:
        Tensor: Sum of the per-discriminator losses divided by
        ``len(disc_generated_outputs)``.
    """

    def _mean_sq_softplus(x):
        # Mean of the squared softplus, the building block of every term.
        return torch.mean(F.softplus(x) ** 2)

    per_discriminator = []
    pairs = zip(disc_real_outputs, disc_generated_outputs)
    for (real_fun, real_dir), (fake_fun, fake_dir) in pairs:
        real_term = _mean_sq_softplus(1 - real_fun) + _mean_sq_softplus(1 - real_dir)
        # The "dir" term for generated samples enters with a negative sign:
        # the negation applies to the already squared softplus.
        fake_dir_term = torch.mean(-(F.softplus(1 - fake_dir) ** 2))
        fake_term = _mean_sq_softplus(fake_fun) + fake_dir_term
        per_discriminator.append(real_term + fake_term)

    return sum(per_discriminator) / len(disc_generated_outputs)
|
|
|
|
|
def generator_adv_loss(disc_outputs):
    """Compute the generator adversarial loss.

    Args:
        disc_outputs (list): Discriminator outputs, one tensor per
            discriminator, computed from generated samples.

    Returns:
        Tensor: Mean over discriminators of ``mean(softplus(1 - d) ** 2)``.
    """
    per_discriminator = [
        torch.mean(F.softplus(1 - score) ** 2) for score in disc_outputs
    ]
    return sum(per_discriminator) / len(disc_outputs)
|
|
|
|
|
class GeneratorAdversarialLoss(torch.nn.Module):
    """Generator adversarial loss module.

    Scores discriminator outputs with either a least-squares ("mse") or a
    hinge criterion.
    """

    def __init__(
        self,
        average_by_discriminators=True,
        loss_type="mse",
    ):
        """Initialize GeneratorAdversarialLoss module.

        Args:
            average_by_discriminators (bool): Whether to divide the summed
                loss by the number of discriminator outputs when a list or
                tuple is passed to ``forward``.
            loss_type (str): Criterion to use, either "mse" or "hinge".

        Raises:
            AssertionError: If ``loss_type`` is not supported.
        """
        super().__init__()
        self.average_by_discriminators = average_by_discriminators
        assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
        if loss_type == "mse":
            self.criterion = self._mse_loss
        else:
            self.criterion = self._hinge_loss

    def forward(self, outputs):
        """Calculate generator adversarial loss.

        Args:
            outputs (Tensor or list): Discriminator outputs or list of
                discriminator outputs. When an element of the list is itself
                a list or tuple, only its last entry is scored.

        Returns:
            Tensor: Generator adversarial loss value. An empty list or tuple
            yields the float ``0.0``.
        """
        if not isinstance(outputs, (tuple, list)):
            return self.criterion(outputs)

        adv_loss = 0.0
        for outputs_ in outputs:
            if isinstance(outputs_, (tuple, list)):
                outputs_ = outputs_[-1]
            adv_loss = adv_loss + self.criterion(outputs_)

        # Only average when there is something to average: an empty input
        # must return 0.0 just like the non-averaging path, not crash.
        if self.average_by_discriminators and outputs:
            adv_loss = adv_loss / len(outputs)
        return adv_loss

    def _mse_loss(self, x):
        """Return the mean squared error between ``x`` and an all-ones target."""
        return F.mse_loss(x, x.new_ones(x.size()))

    def _hinge_loss(self, x):
        """Return the negated mean of ``x``."""
        return -x.mean()
|
|
|
|
|
class DiscriminatorAdversarialLoss(torch.nn.Module):
    """Discriminator adversarial loss module.

    Combines a loss on discriminator outputs for ground-truth samples with a
    loss on discriminator outputs for generated samples, using either a
    least-squares ("mse") or a hinge criterion.
    """

    def __init__(
        self,
        average_by_discriminators=True,
        loss_type="mse",
    ):
        """Initialize DiscriminatorAdversarialLoss module.

        Args:
            average_by_discriminators (bool): Whether to divide the summed
                real and fake losses by the number of discriminator outputs
                when lists or tuples are passed to ``forward``.
            loss_type (str): Criterion to use, either "mse" or "hinge".

        Raises:
            AssertionError: If ``loss_type`` is not supported.
        """
        super().__init__()
        self.average_by_discriminators = average_by_discriminators
        assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
        if loss_type == "mse":
            self.fake_criterion = self._mse_fake_loss
            self.real_criterion = self._mse_real_loss
        else:
            self.fake_criterion = self._hinge_fake_loss
            self.real_criterion = self._hinge_real_loss

    def forward(self, outputs_hat, outputs):
        """Calculate discriminator adversarial loss.

        Args:
            outputs_hat (Tensor or list): Discriminator outputs or list of
                discriminator outputs calculated from generator outputs.
            outputs (Tensor or list): Discriminator outputs or list of
                discriminator outputs calculated from groundtruth.

        Returns:
            Tensor: Sum of the discriminator real loss and fake loss. Empty
            lists or tuples yield the float ``0.0``.
        """
        if not isinstance(outputs, (tuple, list)):
            return self.real_criterion(outputs) + self.fake_criterion(outputs_hat)

        real_loss = 0.0
        fake_loss = 0.0
        # Count the pairs actually consumed (zip stops at the shorter input)
        # so the average uses the same denominator as the accumulation.
        num_pairs = 0
        for num_pairs, (outputs_hat_, outputs_) in enumerate(
            zip(outputs_hat, outputs), start=1
        ):
            if isinstance(outputs_hat_, (tuple, list)):
                # Nested outputs: only the last entry of each is scored.
                outputs_hat_ = outputs_hat_[-1]
                outputs_ = outputs_[-1]
            real_loss = real_loss + self.real_criterion(outputs_)
            fake_loss = fake_loss + self.fake_criterion(outputs_hat_)

        # Skip the division for empty input instead of crashing.
        if self.average_by_discriminators and num_pairs > 0:
            fake_loss = fake_loss / num_pairs
            real_loss = real_loss / num_pairs

        return real_loss + fake_loss

    def _mse_real_loss(self, x):
        """Return the mean squared error between ``x`` and an all-ones target."""
        return F.mse_loss(x, x.new_ones(x.size()))

    def _mse_fake_loss(self, x):
        """Return the mean squared error between ``x`` and an all-zeros target."""
        return F.mse_loss(x, x.new_zeros(x.size()))

    def _hinge_real_loss(self, x):
        """Return ``-mean(min(x - 1, 0))``."""
        return -torch.mean(torch.min(x - 1, x.new_zeros(x.size())))

    def _hinge_fake_loss(self, x):
        """Return ``-mean(min(-x - 1, 0))``."""
        return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size())))
|
|