Support gradient checkpointing

#3
by maxall4 - opened
Files changed (2)
  1. model.py +7 -1
  2. modeling_hyena.py +25 -0
model.py CHANGED
@@ -350,6 +350,8 @@ class StripedHyena(nn.Module):
         self.blocks = nn.ModuleList(
             get_block(config, layer_idx, flash_fft=self.flash_fft) for layer_idx in range(config.num_layers)
         )
+        self.gradient_checkpointing = False
+        self._gradient_checkpointing_func = None

     def forward(self, x, inference_params_dict=None, padding_mask=None):
         L = x.shape[1]
@@ -379,7 +381,11 @@ class StripedHyena(nn.Module):
         x = x * padding_mask[..., None]

         for _, block in enumerate(self.blocks):
-            x, _ = block(x, inference_params=None, padding_mask=padding_mask)
+            if self.gradient_checkpointing and self.training:
+                x, _ = self._gradient_checkpointing_func(block.__call__, x, None, padding_mask)
+            else:
+                x, _ = block(x, inference_params=None, padding_mask=padding_mask)
+
         return x, None

     def initialize_inference_params(self):
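
For reference, a minimal, self-contained sketch of what the checkpointed branch does (not part of the PR; ToyBlock is a made-up stand-in with the same call shape as a StripedHyena block). Once checkpointing is enabled, _gradient_checkpointing_func resolves to torch.utils.checkpoint.checkpoint, which runs the block without storing its intermediate activations and recomputes them during backward:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    # Hypothetical stand-in with the same (x, inference_params, padding_mask) -> (x, state) interface.
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, inference_params=None, padding_mask=None):
        return torch.relu(self.proj(x)), None


block = ToyBlock(16)
x = torch.randn(2, 8, 16, requires_grad=True)

# Mirrors the PR's call shape: checkpoint(block.__call__, x, None, padding_mask).
# The forward pass runs without saving activations; the block is re-executed
# during .backward() to rebuild them.
y, _ = checkpoint(block.__call__, x, None, None, use_reentrant=True)
y.sum().backward()

The memory saving scales with the number of checkpointed blocks, at the cost of one extra forward recomputation per block during the backward pass.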
modeling_hyena.py CHANGED
@@ -2,6 +2,7 @@
 """StripedHyena custom code port for the Hugging Face Hub"""

 import torch
+import functools
 from torch.nn import functional as F
 from .configuration_hyena import StripedHyenaConfig
 from transformers import PreTrainedModel
@@ -50,8 +51,32 @@ class StripedHyenaModelForCausalLM(StripedHyenaPreTrainedModel):
     def force_dtype(self):
         self.backbone.to_bfloat16_except_poles_residues()

+    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
+        if not self.supports_gradient_checkpointing:
+            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
+
+        if gradient_checkpointing_kwargs is None:
+            gradient_checkpointing_kwargs = {"use_reentrant": True}
+
+        # TODO: support the DeepSpeed activation checkpointing implementation
+        gradient_checkpointing_func = functools.partial(
+            torch.utils.checkpoint.checkpoint, **gradient_checkpointing_kwargs
+        )
+
+        self._set_gradient_checkpointing(
+            enable=True, gradient_checkpointing_func=gradient_checkpointing_func
+        )
+
+        if getattr(self, "_hf_peft_config_loaded", False):
+            # When using PEFT + gradient checkpointing + Trainer, the inputs must have requires_grad=True.
+            # PEFT does the same: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
+            # With PEFT only the LoRA layers have requires_grad=True, so the outputs of the frozen
+            # layers must still propagate gradients for gradient flow to reach the adapters.
+            self.enable_input_require_grads()
+
     def _set_gradient_checkpointing(self, enable, gradient_checkpointing_func):
         self.backbone.gradient_checkpointing = enable
+        self.backbone._gradient_checkpointing_func = gradient_checkpointing_func

     def get_input_embeddings(self):
         return self.backbone.embedding_layer
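
End to end, enabling the feature from user code would look roughly like the sketch below. This is an assumed usage example, not part of the PR: the checkpoint name is illustrative, and it presumes supports_gradient_checkpointing is set to True on the PreTrainedModel subclass so the guard in gradient_checkpointing_enable passes.

import torch
from transformers import AutoModelForCausalLM

# Illustrative checkpoint name; any StripedHyena repo carrying this custom modeling code would work.
model = AutoModelForCausalLM.from_pretrained(
    "togethercomputer/StripedHyena-Hessian-7B",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

# Wires torch.utils.checkpoint.checkpoint into the backbone's block loop (see model.py above).
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.train()  # checkpointing only takes effect while backbone.training is True

With use_reentrant=False the non-reentrant checkpoint implementation is used, which is the variant Transformers has been recommending; when no kwargs are passed, the PR defaults to use_reentrant=True.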