# Copyright (c) ByteDance, Inc. and its affiliates. | |
# Copyright (c) Chutong Meng | |
# | |
# This source code is licensed under the CC BY-NC license found in the | |
# LICENSE file in the root directory of this source tree. | |
# Based on AudioDec (https://github.com/facebookresearch/AudioDec) | |
import torch.nn as nn | |
class ReprReconstructLoss(nn.Module):
    """Reconstruction loss between predicted and target representations.

    Wraps a single PyTorch loss criterion that is chosen once at construction
    time and applied in ``forward``.

    Args:
        loss_type: Case-insensitive criterion name. ``"l1"`` selects
            ``torch.nn.L1Loss`` and ``"l2"`` selects ``torch.nn.MSELoss``;
            both are constructed with their default arguments.

    Raises:
        NotImplementedError: If ``loss_type`` (after lower-casing) is neither
            ``"l1"`` nor ``"l2"``.
    """

    def __init__(self, loss_type: str):
        super().__init__()
        # Normalized name -> criterion class; only the selected one is built.
        criteria = {
            "l1": nn.L1Loss,
            "l2": nn.MSELoss,
        }
        normalized = loss_type.lower()
        criterion_cls = criteria.get(normalized)
        if criterion_cls is None:
            # Report the name exactly as the caller passed it.
            raise NotImplementedError(f"Unsupported loss type: {loss_type}")
        # Assignment goes through nn.Module.__setattr__, so the criterion is
        # registered as a child module named ``loss_metric``.
        self.loss_metric = criterion_cls()

    def forward(self, pred, target):
        """Return the configured criterion evaluated on ``pred`` and ``target``."""
        loss = self.loss_metric(pred, target)
        return loss