davidhd committed on
Commit
d918a9e
·
verified ·
1 Parent(s): e27bb0d

Update rmsnorm.py

Browse files

Modifies the forward pass of RMSNorm to avoid mixed precision issues as described in https://github.com/chandar-lab/AMPLIFY/issues/19

Files changed (1) hide show
  1. rmsnorm.py +5 -7
rmsnorm.py CHANGED
@@ -6,29 +6,27 @@ class RMSNorm(nn.Module):
6
  def __init__(self, dim: int, eps: float = 1e-6):
7
  """
8
  Initialize the RMSNorm normalization layer.
9
-
10
  Args:
11
  dim (int): The dimension of the input tensor.
12
  eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
13
-
14
  Attributes:
15
  eps (float): A small value added to the denominator for numerical stability.
16
  weight (nn.Parameter): Learnable scaling parameter.
17
-
18
  """
19
  super().__init__()
20
  self.eps = eps
21
  self.weight = nn.Parameter(torch.ones(dim))
22
 
 
 
 
23
def forward(self, x):
    """Forward pass through the RMSNorm layer.

    Args:
        x (torch.Tensor): The input tensor.

    Returns:
        torch.Tensor: The output tensor after applying RMSNorm, in x's dtype.
    """
    # Normalize in float32 so x.pow(2) cannot overflow/underflow when x is
    # fp16/bf16 (e.g. under autocast), then cast back to the input dtype.
    # Fix for https://github.com/chandar-lab/AMPLIFY/issues/19
    x_fp32 = x.float()
    normed = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
    return normed.type_as(x) * self.weight
 
 
6
def __init__(self, dim: int, eps: float = 1e-6):
    """Set up RMS normalization over the last dimension.

    Args:
        dim (int): Size of the normalized (last) dimension.
        eps (float, optional): Stability constant added before the
            reciprocal square root. Defaults to 1e-6.

    Attributes:
        eps (float): Stability constant used in the denominator.
        weight (nn.Parameter): Learnable scaling parameter (ones at init).
    """
    super().__init__()
    self.eps = eps
    self.weight = nn.Parameter(torch.ones(dim))
19
 
20
+ def _norm(self, x):
21
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
22
+
23
def forward(self, x):
    """Normalize x with RMSNorm and apply the learned scale.

    The normalization itself runs in float32 and the result is cast back
    to the input dtype, which avoids the mixed-precision overflow described
    in https://github.com/chandar-lab/AMPLIFY/issues/19.

    Args:
        x (torch.Tensor): The input tensor.

    Returns:
        torch.Tensor: The RMS-normalized, scaled tensor in x's dtype.
    """
    normed = self._norm(x.float())
    return self.weight * normed.type_as(x)