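"""LayerNorm helpers: a ``LayerNorm`` factory that prefers apex's fused CUDA
kernel when it is available (falling back to ``torch.nn.LayerNorm`` for
export / TorchScript), plus ``FP16LayerNorm``, which normalizes in half
precision with autocast disabled."""
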
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor, Size
from torch.cuda.amp import autocast

__all__ = ["LayerNorm"]


try:
    from apex.normalization import FusedLayerNorm as _FusedLayerNorm

    has_fused_layernorm = True

    class FusedLayerNorm(_FusedLayerNorm):
        # Excluded from TorchScript compilation; the LayerNorm factory below
        # returns torch.nn.LayerNorm when scripting.
        @torch.jit.unused
        def forward(self, x):
            if not x.is_cuda:
                return super().forward(x)
            else:
                # Make sure the fused kernel launches on the input's device.
                with torch.cuda.device(x.device):
                    return super().forward(x)

except ImportError:
    has_fused_layernorm = False


def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
    if torch.jit.is_scripting():
        # TorchScript cannot use the fused apex path, so force the export branch.
        export = True
    if not export and torch.cuda.is_available() and has_fused_layernorm:
        return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)


class FP16LayerNorm(torch.nn.LayerNorm):
    def forward(self, input):
        # Disable autocast so the normalization runs entirely in fp16.
        with autocast(enabled=False):
            return F.layer_norm(
                input.half(),
                self.normalized_shape,
                self.weight.half() if self.weight is not None else None,
                self.bias.half() if self.bias is not None else None,
                self.eps,
            )
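

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module's API): the
# tensor shapes and `embed_dim` below are assumptions chosen for the demo.
# It shows the factory picking the fused kernel when apex and CUDA are
# available, and FP16LayerNorm normalizing in half precision.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    embed_dim = 16  # hypothetical feature dimension
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(2, 4, embed_dim, device=device)

    # Factory: FusedLayerNorm if apex + CUDA are usable, else torch.nn.LayerNorm.
    ln = LayerNorm(embed_dim).to(device)
    print(type(ln).__name__, ln(x).shape)

    if device == "cuda":
        # FP16 variant: input and affine parameters are cast to half in forward.
        fp16_ln = FP16LayerNorm(embed_dim).to(device)
        print(fp16_ln(x).dtype)  # torch.float16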