File size: 1,873 Bytes
d5175d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import torch
import torch.nn.functional as F
logger = logging.getLogger(__name__)
def _cross_entropy_pytorch(logits, target, ignore_index=-100, reduction="mean"):
    """Compute cross entropy with plain PyTorch ops (no fused kernel).

    The log-softmax is taken over the last axis of ``logits`` and computed in
    float32 regardless of the input dtype, then passed to ``F.nll_loss``.

    Args:
        logits: unnormalized class scores, normalized over the last axis.
            Presumably ``(N, C)`` — ``F.nll_loss`` reads the class axis from
            dim 1, so higher-rank input should be flattened by the caller.
        target: integer class indices accepted by ``F.nll_loss``.
        ignore_index: target value whose positions contribute no loss and are
            excluded from the ``"mean"`` denominator. ``None`` is accepted for
            backward compatibility and treated as ``-100`` (``F.nll_loss``'s
            own default).
        reduction: ``"mean"``, ``"sum"`` or ``"none"``, forwarded to
            ``F.nll_loss``.

    Returns:
        The loss tensor produced by ``F.nll_loss`` on float32 log-probabilities.
    """
    if ignore_index is None:
        # F.nll_loss requires an int; passing None through raised a TypeError,
        # which made the previous default of this helper unusable.
        ignore_index = -100
    lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    return F.nll_loss(
        lprobs,
        target,
        ignore_index=ignore_index,
        reduction=reduction,
    )
try:
    # ``xentropy_cuda`` is not referenced below; importing it makes the whole
    # fused variant fall back to the PyTorch one (via ImportError) when the
    # compiled extension is missing, even if ``apex.contrib`` imports fine.
    import xentropy_cuda  # noqa: F401
    from apex.contrib import xentropy

    def cross_entropy(logits, target, ignore_index=-100, reduction="mean"):
        """Cross entropy using apex's fused softmax-xentropy kernel on CUDA.

        Tensors not on a CUDA device are routed to ``_cross_entropy_pytorch``.

        Args:
            logits: unnormalized class scores; presumably ``(N, C)`` as the
                fused kernel expects — TODO confirm against apex.
            target: class indices, one per row of ``logits``.
            ignore_index: target value to ignore; it is passed to the kernel,
                which is expected to zero those losses, and such positions are
                excluded from the ``"mean"`` denominator.
            reduction: ``"mean"``, ``"sum"`` or ``"none"``.

        Returns:
            A scalar loss for ``"mean"``/``"sum"``, or the unreduced per-row
            losses from the kernel for ``"none"``.

        Raises:
            NotImplementedError: if ``reduction`` is not one of the above.
        """
        if not logits.is_cuda:
            # The fused kernel is CUDA-only. Checking only for "cpu" let other
            # devices (e.g. mps) reach the CUDA kernel and fail.
            return _cross_entropy_pytorch(logits, target, ignore_index, reduction)

        if not getattr(cross_entropy, "_has_logged_once", False):
            logger.info("using fused cross entropy")
            cross_entropy._has_logged_once = True

        # Presumably asks the kernel for float32 losses when given fp16 logits.
        half_to_float = logits.dtype == torch.half
        losses = xentropy.SoftmaxCrossEntropyLoss.apply(
            logits,
            target,
            0.0,  # presumably the label-smoothing amount (disabled)
            ignore_index,
            half_to_float,
        )
        if reduction == "sum":
            return losses.sum()
        if reduction == "mean":
            # Average over non-ignored targets only, matching F.nll_loss and
            # therefore the non-CUDA path. This used to happen only for
            # ignore_index >= 0; negative values such as the default -100 fell
            # back to losses.mean(), which counted ignored positions in the
            # denominator. When nothing is ignored both forms are equal.
            return losses.sum() / target.ne(ignore_index).sum()
        if reduction == "none":
            return losses
        raise NotImplementedError(f"unsupported reduction: {reduction!r}")

except ImportError:

    def cross_entropy(logits, target, ignore_index=-100, reduction="mean"):
        """Cross entropy via plain PyTorch; used when apex's fused xentropy is unavailable."""
        return _cross_entropy_pytorch(logits, target, ignore_index, reduction)
|