File size: 628 Bytes
6931c7b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
#!/usr/bin/env python
import torch.nn.functional as F
import torch.nn as nn
def calc_vq_loss(pred, target, quant_loss, quant_loss_weight=1.0, alpha=1.0):
    """Combine an L1 reconstruction loss with a pre-computed quantization loss.

    Args:
        pred: Predicted tensor.
        target: Reference tensor that ``pred`` is compared against.
        quant_loss: Pre-computed quantization loss tensor; it is reduced to
            its mean before being weighted.
        quant_loss_weight: Multiplier applied to the mean quantization loss.
        alpha: Accepted for interface compatibility; not used by this function.

    Returns:
        A pair ``(total, [rec_loss, quant_loss])`` where ``total`` is
        ``quant_loss * quant_loss_weight + rec_loss``, ``rec_loss`` is the
        mean absolute error between ``pred`` and ``target``, and
        ``quant_loss`` is the mean of the quantization loss passed in.
    """
    # Reconstruction term: mean absolute error (default "mean" reduction).
    reconstruction = F.l1_loss(pred, target)
    # The quantization term arrives pre-computed; only average it here.
    quantization = quant_loss.mean()
    total = quantization * quant_loss_weight + reconstruction
    return total, [reconstruction, quantization]
def calc_logit_loss(pred, target):
    """Return the cross-entropy between flattened logits and flattened targets.

    Args:
        pred: Logits tensor; its last dimension is kept as the class axis and
            all leading dimensions are collapsed into one.
        target: Target tensor; flattened to one dimension so that it lines up
            with the collapsed rows of ``pred``.

    Returns:
        The loss tensor produced by ``F.cross_entropy`` on the flattened inputs.
    """
    num_classes = pred.size(-1)
    flat_logits = pred.reshape(-1, num_classes)
    flat_targets = target.reshape(-1)
    return F.cross_entropy(flat_logits, flat_targets)
|