# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor, Size
from torch.cuda.amp import autocast

__all__ = ["LayerNorm"]

try:
    from apex.normalization import FusedLayerNorm as _FusedLayerNorm

    has_fused_layernorm = True

    class FusedLayerNorm(_FusedLayerNorm):
        """Apex fused layer norm; pins the kernel launch to the input's device."""

        @torch.jit.unused
        def forward(self, x):
            if not x.is_cuda:
                return super().forward(x)
            else:
                with torch.cuda.device(x.device):
                    return super().forward(x)

except ImportError:
    has_fused_layernorm = False


def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
    """Factory: returns the fused Apex layer norm when available and not exporting,
    otherwise falls back to ``torch.nn.LayerNorm``."""
    if torch.jit.is_scripting():
        export = True
    if not export and torch.cuda.is_available() and has_fused_layernorm:
        return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)


class FP16LayerNorm(torch.nn.LayerNorm):
    """Layer norm computed in fp16 with autocast disabled, casting the input,
    weight, and bias to half precision explicitly."""

    def forward(self, input):
        with autocast(enabled=False):
            return F.layer_norm(
                input.half(),
                self.normalized_shape,
                self.weight.half(),
                self.bias.half(),
                self.eps,
            )
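

# Minimal usage sketch (an assumption for illustration, not part of the original
# module): the LayerNorm factory returns the fused Apex kernel when apex is
# installed and CUDA is available, and a plain torch.nn.LayerNorm otherwise.
if __name__ == "__main__":
    ln = LayerNorm(512)
    x = torch.randn(2, 16, 512)
    if torch.cuda.is_available():
        ln, x = ln.cuda(), x.cuda()
    y = ln(x)
    print(type(ln).__name__, y.shape)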