#!/usr/bin/env python
import torch.nn.functional as F | |
import torch.nn as nn | |
def calc_vq_loss(pred, target, quant_loss, quant_loss_weight=1.0, alpha=1.0):
    """Combine an L1 reconstruction loss with a pre-computed quantization loss.

    Args:
        pred: Predicted tensor passed to the L1 loss.
        target: Reference tensor passed to the L1 loss.
        quant_loss: Pre-computed quantization loss; it is reduced with
            ``.mean()`` before being weighted.
        quant_loss_weight: Multiplier applied to the averaged quantization
            loss.
        alpha: Accepted for interface compatibility; this function does not
            read it.

    Returns:
        A tuple ``(total, [rec_loss, quant_loss])`` where ``total`` is
        ``quant_loss * quant_loss_weight + rec_loss`` and the list holds the
        unweighted reconstruction and averaged quantization terms.
    """
    # Functional form of nn.L1Loss() (default mean reduction); avoids
    # constructing a throwaway module on every call.
    rec_loss = F.l1_loss(pred, target)

    # Collapse the pre-computed quantization loss to one value.
    quant_loss = quant_loss.mean()

    # Total loss is reconstruction plus the weighted quantization term.
    total_loss = quant_loss * quant_loss_weight + rec_loss
    return total_loss, [rec_loss, quant_loss]
def calc_logit_loss(pred, target):
    """Return the cross-entropy between flattened logits and flattened targets.

    ``pred`` is reshaped to two dimensions, keeping its last dimension, and
    ``target`` is reshaped to one dimension before both are handed to
    ``F.cross_entropy`` with its default settings.
    """
    # Keep the last axis of `pred` and merge every leading axis into one.
    flat_logits = pred.reshape(-1, pred.size(-1))
    flat_target = target.reshape(-1)
    return F.cross_entropy(flat_logits, flat_target)