yairschiff
committed on
Update modeling_rcps.py
Enable prenorm = False to prevent returning residual in RCPSAddNormWrapper
modeling_rcps.py  CHANGED  (+9, -3)
@@ -101,11 +101,12 @@ class RCPSAddNormWrapper(RCPSWrapper):
     def __init__(self, submodule: nn.Module):
         super().__init__(submodule)
 
-    def forward(self, x, residual=None):
+    def forward(self, x, residual=None, prenorm=False):
         """
         Args:
             x: Input tensor of shape (batch_size, seq_len, channels)
             residual: Residual tensor of shape (batch_size, seq_len, channels) or None.
+            prenorm: Whether to return residual.
         """
         n_channels = x.shape[-1]
         if residual is None:
@@ -123,7 +124,7 @@ class RCPSAddNormWrapper(RCPSWrapper):
         residual = torch.cat([residual_fwd, self.rc(residual_rc)], dim=-1)
         x = torch.cat([x_fwd, self.rc(x_rc)], dim=-1)
 
-        return x, residual
+        return x if not prenorm else (x, residual)
 
 
 class RCPSMambaBlock(nn.Module):
@@ -147,6 +148,11 @@ class RCPSMambaBlock(nn.Module):
         self.mixer = RCPSWrapper(mixer_cls(dim))
         norm_f = norm_cls(dim)
         self.norm = norm_f if fused_add_norm else RCPSAddNormWrapper(norm_f)
+        if self.fused_add_norm:
+            assert RMSNorm is not None, "RMSNorm import fails"
+            assert isinstance(
+                self.norm, (nn.LayerNorm, RMSNorm)
+            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
 
     def forward(
         self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
@@ -159,7 +165,7 @@ class RCPSMambaBlock(nn.Module):
         inference_params: inference parameters for mixer.
         """
         if not self.fused_add_norm:
-            hidden_states, residual = self.norm(hidden_states, residual=residual)
+            hidden_states, residual = self.norm(hidden_states, residual=residual, prenorm=True)
             if self.residual_in_fp32:
                 residual = residual.to(torch.float32)
         else:
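For illustration, a minimal sketch of the return contract introduced by this commit, using a plain LayerNorm stand-in. The AddNormSketch class, shapes, and values below are hypothetical and omit the RC-equivariant channel splitting that RCPSAddNormWrapper actually performs; only the prenorm-dependent return mirrors the change above.

import torch
import torch.nn as nn
from typing import Optional


class AddNormSketch(nn.Module):
    """Hypothetical stand-in mirroring RCPSAddNormWrapper's updated forward signature.

    The real wrapper also applies the norm separately to the forward and
    reverse-complement channel halves before concatenating them back together.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, residual: Optional[torch.Tensor] = None, prenorm: bool = False):
        # Add-then-norm, as in the wrapped submodule.
        residual = x if residual is None else residual + x
        x = self.norm(residual)
        # Key change from this commit: the residual is only returned when prenorm=True.
        return x if not prenorm else (x, residual)


block = AddNormSketch(16)
x = torch.randn(2, 8, 16)             # (batch_size, seq_len, channels)

out = block(x)                         # prenorm=False (default): a single tensor
out, res = block(x, prenorm=True)      # prenorm=True: tensor plus residual, matching
                                       # the RCPSMambaBlock.forward call site above

With prenorm=False the wrapper can now be used where callers expect only the normalized hidden states, while RCPSMambaBlock keeps passing prenorm=True to recover the residual stream.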