File size: 749 Bytes
2cddd11 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the CC BY-NC license found in the
# LICENSE file in the root directory of this source tree.
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)
import torch.nn as nn
class ReprReconstructLoss(nn.Module):
    """Reconstruction loss between a prediction and its target.

    Wraps one built-in PyTorch criterion, selected by name when the module
    is constructed.

    Args:
        loss_type: Case-insensitive criterion name. ``"l1"`` selects
            ``nn.L1Loss`` and ``"l2"`` selects ``nn.MSELoss``.

    Raises:
        NotImplementedError: If ``loss_type`` names neither supported criterion.
    """

    def __init__(self, loss_type: str):
        super().__init__()
        # Accepted lower-cased names -> criterion class to instantiate.
        criteria = {
            "l1": nn.L1Loss,
            "l2": nn.MSELoss,
        }
        criterion_cls = criteria.get(loss_type.lower())
        if criterion_cls is None:
            raise NotImplementedError(f"Unsupported loss type: {loss_type}")
        # Assigning an nn.Module attribute registers it as a submodule.
        # It is built with no arguments, so the criterion's default settings apply.
        self.loss_metric = criterion_cls()

    def forward(self, pred, target):
        """Return the configured criterion evaluated on ``pred`` and ``target``.

        Both arguments are forwarded unchanged to the underlying criterion.
        They presumably need compatible shapes; confirm against callers.
        """
        criterion = self.loss_metric
        return criterion(pred, target)
|