Merged mosaicml/mpt-7b@a78c1f/norm.py into main
Browse files
norm.py
CHANGED
@@ -25,7 +25,7 @@ class LPLayerNorm(torch.nn.LayerNorm):
|
|
25 |
return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
|
26 |
|
27 |
def rms_norm(x, weight=None, eps=1e-05):
|
28 |
-
output = x
|
29 |
if weight is not None:
|
30 |
return output * weight
|
31 |
return output
|
|
|
25 |
return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
|
26 |
|
27 |
def rms_norm(x, weight=None, eps=1e-05):
|
28 |
+
output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
|
29 |
if weight is not None:
|
30 |
return output * weight
|
31 |
return output
|