import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Normalizes the last dimension of the input by its root-mean-square
    statistic (no mean subtraction, unlike LayerNorm). Optionally, only a
    leading fraction ``p`` of the features is used to estimate the RMS
    ("partial RMSNorm"), while the full input is still rescaled by it.
    """

    def __init__(self, d, p=-1., eps=1e-8, bias=False):
        """
        :param d: model size (number of features in the last dimension)
        :param p: partial RMSNorm fraction, valid value [0, 1],
            default -1.0 (disabled — use all features)
        :param eps: epsilon added to the RMS for numerical stability,
            default 1e-8
        :param bias: whether to use a learnable offset term; disabled by
            default because RMSNorm doesn't enforce re-centering invariance.
        """
        super().__init__()

        self.eps = eps
        self.d = d
        self.p = p
        self.bias = bias

        # Assigning an nn.Parameter attribute on an nn.Module registers it
        # automatically; the explicit register_parameter() calls in the
        # original were redundant. state_dict keys are unchanged.
        self.scale = nn.Parameter(torch.ones(d))
        if self.bias:
            self.offset = nn.Parameter(torch.zeros(d))

    def forward(self, x):
        """Apply RMS normalization over the last dimension of ``x``.

        :param x: tensor whose last dimension has size ``d``.
        :return: ``scale * x / (rms + eps)``, plus ``offset`` when bias
            is enabled.
        """
        if self.p < 0. or self.p > 1.:
            # Full RMSNorm: compute the statistic over every feature.
            norm_x = x.norm(2, dim=-1, keepdim=True)
            d_x = self.d
        else:
            # Partial RMSNorm: estimate the RMS from the first p*d features
            # only, but still normalize the whole input with it.
            partial_size = int(self.d * self.p)
            partial_x, _ = torch.split(
                x, [partial_size, self.d - partial_size], dim=-1
            )
            norm_x = partial_x.norm(2, dim=-1, keepdim=True)
            d_x = partial_size

        # rms = ||x||_2 / sqrt(d) = sqrt(mean(x^2)) over the chosen features.
        rms_x = norm_x * d_x ** (-1. / 2)
        x_normed = x / (rms_x + self.eps)

        if self.bias:
            return self.scale * x_normed + self.offset
        return self.scale * x_normed