herrius's picture
Upload 259 files
32b542e
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor, Size
from torch.cuda.amp import autocast
# Names exported by a star-import of this module: only the LayerNorm factory.
__all__ = ["LayerNorm"]
# apex is an optional dependency: when it cannot be imported the flag below is
# set to False and the FusedLayerNorm class is simply not defined.
try:
    from apex.normalization import FusedLayerNorm as _FusedLayerNorm

    has_fused_layernorm = True

    class FusedLayerNorm(_FusedLayerNorm):
        """apex fused layer norm that runs under the input tensor's CUDA device."""

        @torch.jit.unused
        def forward(self, x):
            """Apply the parent forward, selecting ``x``'s device for CUDA inputs."""
            if x.is_cuda:
                # Make the tensor's own device current for the duration of the call.
                with torch.cuda.device(x.device):
                    return super().forward(x)
            # Non-CUDA input: defer to the parent implementation directly.
            return super().forward(x)

except ImportError:
    has_fused_layernorm = False
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
    """Build a layer-normalization module, preferring the fused apex version.

    Args:
        normalized_shape: forwarded unchanged to the selected module.
        eps: forwarded unchanged to the selected module.
        elementwise_affine: forwarded unchanged to the selected module.
        export: when true, always return ``torch.nn.LayerNorm``. It is forced
            to true while TorchScript is scripting.

    Returns:
        A ``FusedLayerNorm`` when ``export`` is false, CUDA is available and
        ``has_fused_layernorm`` is true; otherwise a ``torch.nn.LayerNorm``.
    """
    if torch.jit.is_scripting():
        export = True
    if not export and torch.cuda.is_available() and has_fused_layernorm:
        return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
    # Single fallback return; the original carried an unreachable duplicate.
    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
class FP16LayerNorm(torch.nn.LayerNorm):
def forward(self, input):
with autocast(enabled=False):
return F.layer_norm(input.half(), self.normalized_shape, self.weight.half(), self.bias.half(), self.eps)