# Copyright (c) ByteDance, Inc. and its affiliates. # Copyright (c) Chutong Meng # # This source code is licensed under the CC BY-NC license found in the # LICENSE file in the root directory of this source tree. # Based on AudioDec (https://github.com/facebookresearch/AudioDec) import torch.nn as nn class ReprReconstructLoss(nn.Module): def __init__(self, loss_type: str): super().__init__() if loss_type.lower() == "l1": self.loss_metric = nn.L1Loss() elif loss_type.lower() == "l2": self.loss_metric = nn.MSELoss() else: raise NotImplementedError(f"Unsupported loss type: {loss_type}") def forward(self, pred, target): return self.loss_metric(pred, target)