Spaces:
Runtime error
Runtime error
import torch.nn as nn | |
import torch | |
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Divides the input by the root mean square (RMS) computed over its last
    dimension, then multiplies by a learnable per-feature ``scale`` and, when
    ``bias`` is enabled, adds a learnable per-feature ``offset``.
    """

    def __init__(self, d, p=-1., eps=1e-8, bias=False):
        """
        Root Mean Square Layer Normalization

        :param d: model size (expected size of the input's last dimension)
        :param p: partial RMSNorm; a value in [0, 1] estimates the RMS from
            only the first ``int(d * p)`` features, which must be at least
            one feature. Any value outside [0, 1] disables partial mode and
            uses all ``d`` features. Default -1.0 (disabled).
        :param eps: epsilon value added to the RMS before dividing,
            default 1e-8
        :param bias: whether use bias term for RMSNorm, disabled by
            default because RMSNorm doesn't enforce re-centering invariance.
        """
        super().__init__()
        self.eps = eps
        self.d = d
        self.p = p
        self.bias = bias

        # Assigning an nn.Parameter to a module attribute registers it, so no
        # explicit register_parameter() call is needed.
        self.scale = nn.Parameter(torch.ones(d))
        if self.bias:
            self.offset = nn.Parameter(torch.zeros(d))

    def forward(self, x):
        """
        Normalize ``x`` by its RMS over the last dimension.

        :param x: input tensor whose last dimension has size ``d``
        :return: ``scale * x / (rms + eps)``, plus ``offset`` when bias is
            enabled; same shape as ``x``
        :raises ValueError: if partial mode is active (``0 <= p <= 1``) but
            ``int(d * p)`` selects no features, leaving the RMS undefined
        """
        if self.p < 0. or self.p > 1.:
            # Partial mode disabled: use the L2 norm of every feature.
            norm_x = x.norm(2, dim=-1, keepdim=True)
            d_x = self.d
        else:
            partial_size = int(self.d * self.p)
            if partial_size < 1:
                # An empty slice would give 0 ** -0.5 below, which fails with
                # an opaque ZeroDivisionError; report the real cause instead.
                raise ValueError(
                    f"partial RMSNorm needs at least one feature, but "
                    f"int(d * p) = int({self.d} * {self.p}) = {partial_size}"
                )
            # Only the leading `partial_size` features contribute to the RMS.
            partial_x, _ = torch.split(
                x, [partial_size, self.d - partial_size], dim=-1
            )
            norm_x = partial_x.norm(2, dim=-1, keepdim=True)
            d_x = partial_size

        # ||x||_2 / sqrt(n) == sqrt(mean(x ** 2)), i.e. the RMS.
        rms_x = norm_x * d_x ** (-1. / 2)
        # eps is added to the RMS itself (outside the square root).
        x_normed = x / (rms_x + self.eps)

        if self.bias:
            return self.scale * x_normed + self.offset
        return self.scale * x_normed