Update model.py
Browse files
model.py
CHANGED
@@ -381,7 +381,11 @@ class StripedHyena(nn.Module):
|
|
381 |
x = x * padding_mask[..., None]
|
382 |
|
383 |
for _, block in enumerate(self.blocks):
|
384 |
-
|
|
|
|
|
|
|
|
|
385 |
return x, None
|
386 |
|
387 |
def initialize_inference_params(self):
|
|
|
381 |
x = x * padding_mask[..., None]
|
382 |
|
383 |
for _, block in enumerate(self.blocks):
|
384 |
+
if self.gradient_checkpointing and self.training:
|
385 |
+
x, _ = self._gradient_checkpointing_func(block.__call__, x, None, padding_mask)
|
386 |
+
else:
|
387 |
+
x, _ = block(x, inference_params=None, padding_mask=padding_mask)
|
388 |
+
|
389 |
return x, None
|
390 |
|
391 |
def initialize_inference_params(self):
|