# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
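
"""Cross-entropy loss with an optional fused CUDA kernel from NVIDIA apex.

When apex's xentropy extension is available, GPU inputs are dispatched to the
fused softmax-cross-entropy kernel; otherwise (and always for CPU inputs) a
standard log_softmax + nll_loss implementation is used.
"""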

import logging

import torch
import torch.nn.functional as F


logger = logging.getLogger(__name__)


def _cross_entropy_pytorch(logits, target, ignore_index=-100, reduction="mean"):
    # Standard cross entropy: compute log-probabilities in fp32 for numerical
    # stability, then apply NLL loss. F.nll_loss requires an integer
    # ignore_index, so default to -100 (matching F.nll_loss and the fused
    # path below) rather than None, which would raise at call time.
    lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    return F.nll_loss(
        lprobs,
        target,
        ignore_index=ignore_index,
        reduction=reduction,
    )


try:
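    # Importing the compiled extension directly verifies that apex's fused
    # kernel was actually built for this environment; if either import fails,
    # we fall through to the pure-PyTorch fallback below.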
    import xentropy_cuda
    from apex.contrib import xentropy

    def cross_entropy(logits, target, ignore_index=-100, reduction="mean"):
        if logits.device == torch.device("cpu"):
            # The fused kernel is CUDA-only; fall back for CPU tensors.
            return _cross_entropy_pytorch(logits, target, ignore_index, reduction)
        else:
            if not getattr(cross_entropy, "_has_logged_once", False):
                logger.info("using fused cross entropy")
                cross_entropy._has_logged_once = True

            # The fused kernel can upcast fp16 logits to fp32 internally.
            half_to_float = logits.dtype == torch.half
            losses = xentropy.SoftmaxCrossEntropyLoss.apply(
                logits,
                target,
                0.0,  # no label smoothing
                ignore_index,
                half_to_float,
            )
            if reduction == "sum":
                return losses.sum()
            elif reduction == "mean":
                if ignore_index >= 0:
                    # Average over non-ignored targets only; the fused kernel
                    # emits zero loss at ignored positions.
                    return losses.sum() / target.ne(ignore_index).sum()
                else:
                    return losses.mean()
            elif reduction == "none":
                return losses
            else:
                raise NotImplementedError(f"unsupported reduction: {reduction}")


except ImportError:

    def cross_entropy(logits, target, ignore_index=-100, reduction="mean"):
        # apex is not installed; use the pure-PyTorch implementation.
        return _cross_entropy_pytorch(logits, target, ignore_index, reduction)
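

# Minimal usage sketch (illustrative only, not part of the original module).
# Both code paths accept the same arguments, so for these parameters
# ``cross_entropy`` behaves as a drop-in replacement for F.cross_entropy:
#
#   logits = torch.randn(8, 100)            # (batch, vocab) scores
#   target = torch.randint(0, 100, (8,))    # (batch,) class indices
#   loss = cross_entropy(logits, target, ignore_index=-100, reduction="mean")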