Fix precision error
modeling_chatglm.py  CHANGED  (+4 -20)

@@ -3,9 +3,7 @@
 import math
 import copy
 import warnings
-import re
 import sys
-import functools
 import torch
 import torch.utils.checkpoint
 import torch.nn.functional as F
@@ -177,14 +175,13 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
 
 
 class RMSNorm(torch.nn.Module):
-    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None,
+    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
         super().__init__()
         self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
         self.eps = eps
-        self.quantized = quantized
 
     def forward(self, hidden_states: torch.Tensor):
-        if
+        if hidden_states == torch.bfloat16:
             norm_x = torch.mean(hidden_states * hidden_states, dim=-1, keepdim=True)
             x_normed = hidden_states * torch.rsqrt(norm_x + self.eps)
             return self.weight * x_normed
@@ -521,14 +518,7 @@ class GLMBlock(torch.nn.Module):
 
         self.fp32_residual_connection = config.fp32_residual_connection
 
-        if config.rmsnorm:
-            if config.quantization_bit != 0:
-                LayerNormFunc = functools.partial(RMSNorm, quantized=True)
-            else:
-                LayerNormFunc = RMSNorm
-        else:
-            LayerNormFunc = LayerNorm
-
+        LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
         # Layernorm on the input data.
         self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                              dtype=config.torch_dtype)
@@ -606,13 +596,7 @@ class GLMTransformer(torch.nn.Module):
         self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
 
         if self.post_layer_norm:
-            if config.rmsnorm:
-                if config.quantization_bit != 0:
-                    LayerNormFunc = functools.partial(RMSNorm, quantized=True)
-                else:
-                    LayerNormFunc = RMSNorm
-            else:
-                LayerNormFunc = LayerNorm
+            LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
             # Final layer norm before output.
             self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                  dtype=config.torch_dtype)
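
For readers skimming the diff, here is a minimal, self-contained sketch of the two pieces this commit touches: the simplified RMSNorm (no quantized argument, no self.quantized attribute) and the one-line LayerNormFunc selection that replaces the removed functools.partial(RMSNorm, quantized=True) branching. It deliberately omits the bfloat16 check added inside forward and uses a hypothetical SimpleConfig stub in place of the real model config, so treat it as an illustration of the pattern rather than a copy of the file.

import torch
from torch.nn import LayerNorm

class RMSNorm(torch.nn.Module):
    # Same shape as the '+' lines above: no `quantized` argument, no quantization state.
    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
        # RMS normalization: x / sqrt(mean(x^2) + eps), scaled by a learned weight.
        norm_x = torch.mean(hidden_states * hidden_states, dim=-1, keepdim=True)
        x_normed = hidden_states * torch.rsqrt(norm_x + self.eps)
        return self.weight * x_normed

class SimpleConfig:
    # Hypothetical stand-in for the model config; only the fields used below.
    rmsnorm = True
    hidden_size = 8
    layernorm_epsilon = 1e-5
    torch_dtype = torch.float32

config = SimpleConfig()

# After this commit the norm class is picked with a single ternary,
# independent of any quantization setting.
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
norm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, dtype=config.torch_dtype)

# The weight is allocated with torch.empty, so give it a value before use.
with torch.no_grad():
    norm.weight.fill_(1.0)

x = torch.randn(2, 4, config.hidden_size)
print(norm(x).shape)  # torch.Size([2, 4, 8])

Because the partial is gone, RMSNorm no longer carries quantization state, which is what lets its constructor shrink to the signature shown on the added line.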