File size: 542 Bytes
3eb682b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""Loss functions."""
import torch.nn as nn
# Registry of supported losses, keyed by the name accepted by
# ``get_loss_func``. Values are the ``torch.nn`` loss classes themselves
# (they are not called here), so a looked-up entry must still be
# instantiated before use.
_LOSSES = dict(
    cross_entropy=nn.CrossEntropyLoss,
    bce=nn.BCELoss,
    bce_logit=nn.BCEWithLogitsLoss,
)
def get_loss_func(loss_name):
    """
    Retrieve the loss given the loss name.

    Args:
        loss_name (str): the name of the loss to use; must be one of the
            keys of ``_LOSSES``.

    Returns:
        The loss class registered under ``loss_name``. This is the class
        itself, not an instance, so it must be instantiated before use.

    Raises:
        NotImplementedError: if ``loss_name`` is not a registered loss.
    """
    # Membership on the dict itself; ``.keys()`` is redundant.
    if loss_name not in _LOSSES:
        raise NotImplementedError("Loss {} is not supported".format(loss_name))
    return _LOSSES[loss_name]
|