|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
class BCELoss(nn.Module):
    """Binary cross-entropy computed directly on unnormalised logits.

    Returns the loss together with an (always empty) logging dict, which
    matches the ``(loss, log_dict)`` return shape of the sibling loss class.
    """

    def forward(self, prediction, target):
        """Return ``(bce_loss, {})``.

        Args:
            prediction: Raw logits. The sigmoid is applied inside
                ``F.binary_cross_entropy_with_logits``.
            target: Tensor of targets with the same shape as ``prediction``.
                Presumably the values are in [0, 1] (TODO confirm with callers).

        Returns:
            A tuple of the mean-reduced BCE loss tensor and an empty dict.
        """
        no_logs = {}
        return F.binary_cross_entropy_with_logits(prediction, target), no_logs
|
|
|
|
|
class BCELossWithQuant(nn.Module):
    """BCE-on-logits loss plus a weighted, externally supplied quantization loss.

    Args:
        codebook_weight: Multiplier applied to ``qloss`` before it is added
            to the BCE term. Defaults to 1.0.
    """

    def __init__(self, codebook_weight=1.):
        super().__init__()
        self.codebook_weight = codebook_weight

    def forward(self, qloss, target, prediction, split):
        """Combine BCE on ``prediction`` with ``codebook_weight * qloss``.

        Args:
            qloss: Quantization/codebook loss tensor produced elsewhere.
                It is presumably a scalar (TODO confirm); it is broadcast-added
                to the BCE term.
            target: Targets for the BCE term, with the same shape as ``prediction``.
            prediction: Raw logits. The sigmoid is applied inside
                ``F.binary_cross_entropy_with_logits``.
            split: Prefix for the log keys, for example a dataset split name.

        Returns:
            ``(total_loss, logs)``. ``logs`` maps ``"<split>/total_loss"``,
            ``"<split>/bce_loss"`` and ``"<split>/quant_loss"`` to detached,
            mean-reduced tensors, so logging them keeps no autograd graph alive.
        """
        bce_term = F.binary_cross_entropy_with_logits(prediction, target)
        total = bce_term + self.codebook_weight * qloss

        # Insertion order (total, bce, quant) is kept for downstream loggers.
        components = {
            "total_loss": total.clone(),
            "bce_loss": bce_term,
            "quant_loss": qloss,
        }
        logs = {
            f"{split}/{name}": tensor.detach().mean()
            for name, tensor in components.items()
        }
        return total, logs
|
|