maxall4 commited on
Commit
2ec9f03
1 Parent(s): d849f5b

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +5 -1
model.py CHANGED
@@ -381,7 +381,11 @@ class StripedHyena(nn.Module):
381
  x = x * padding_mask[..., None]
382
 
383
  for _, block in enumerate(self.blocks):
384
- x, _ = block(x, inference_params=None, padding_mask=padding_mask)
 
 
 
 
385
  return x, None
386
 
387
  def initialize_inference_params(self):
 
381
  x = x * padding_mask[..., None]
382
 
383
  for _, block in enumerate(self.blocks):
384
+ if self.gradient_checkpointing and self.training:
385
+ x, _ = self._gradient_checkpointing_func(block.__call__, x, None, padding_mask)
386
+ else:
387
+ x, _ = block(x, inference_params=None, padding_mask=padding_mask)
388
+
389
  return x, None
390
 
391
  def initialize_inference_params(self):