File size: 749 Bytes
2cddd11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the CC BY-NC license found in the
# LICENSE file in the root directory of this source tree.
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)

import torch.nn as nn


class ReprReconstructLoss(nn.Module):
    """Reconstruction loss between a predicted and a target representation.

    Wraps one of two PyTorch criteria, selected by name at construction time
    and stored on ``self.loss_metric``.

    Args:
        loss_type: Case-insensitive name of the criterion. ``"l1"`` selects
            ``nn.L1Loss`` and ``"l2"`` selects ``nn.MSELoss``.

    Raises:
        NotImplementedError: If ``loss_type`` is neither ``"l1"`` nor ``"l2"``.
    """

    def __init__(self, loss_type: str):
        super().__init__()
        # Supported names (already lower-cased) mapped to their criterion class.
        criteria = {"l1": nn.L1Loss, "l2": nn.MSELoss}
        criterion_cls = criteria.get(loss_type.lower())
        if criterion_cls is None:
            raise NotImplementedError(f"Unsupported loss type: {loss_type}")
        self.loss_metric = criterion_cls()

    def forward(self, pred, target):
        """Evaluate the configured criterion on ``pred`` against ``target``."""
        criterion = self.loss_metric
        return criterion(pred, target)