maxall4 commited on
Commit
e87428b
Parent: 567369e

Support gradient checkpointing

Browse files
Files changed (1):
  1. modeling_hyena.py (+24 −0)
modeling_hyena.py CHANGED
@@ -50,8 +50,32 @@ class StripedHyenaModelForCausalLM(StripedHyenaPreTrainedModel):
50
def force_dtype(self):
    """Cast the backbone to bfloat16, leaving poles/residues in their original dtype.

    Delegates entirely to the backbone's own conversion routine.
    """
    backbone = self.backbone
    backbone.to_bfloat16_except_poles_residues()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  def _set_gradient_checkpointing(self, enable, gradient_checkpointing_func):
54
  self.backbone.gradient_checkpointing = enable
 
55
 
56
def get_input_embeddings(self):
    """Return the backbone's embedding layer."""
    embeddings = self.backbone.embedding_layer
    return embeddings
 
50
def force_dtype(self):
    """Ensure backbone weights are bfloat16 (poles/residues excluded)."""
    cast = self.backbone.to_bfloat16_except_poles_residues
    cast()
52
 
53
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
    """Activate gradient checkpointing on the backbone.

    Args:
        gradient_checkpointing_kwargs: Keyword arguments forwarded to
            ``torch.utils.checkpoint.checkpoint``. When ``None``, defaults
            to ``{"use_reentrant": True}``.

    Raises:
        ValueError: If this model does not support gradient checkpointing.
    """
    if not self.supports_gradient_checkpointing:
        raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")

    kwargs = {"use_reentrant": True} if gradient_checkpointing_kwargs is None else gradient_checkpointing_kwargs

    # TODO support deepspeed checkpoint
    checkpoint_fn = functools.partial(torch.utils.checkpoint.checkpoint, **kwargs)

    self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=checkpoint_fn)

    if getattr(self, "_hf_peft_config_loaded", False):
        # With PEFT only adapter (e.g. LoRA) weights require grad, so the input
        # must require grad for gradients to flow through the frozen layers.
        # Mirrors: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
        self.enable_input_require_grads()
75
+
76
  def _set_gradient_checkpointing(self, enable, gradient_checkpointing_func):
77
  self.backbone.gradient_checkpointing = enable
78
+ self.backbone._gradient_checkpointing_func = gradient_checkpointing_func
79
 
80
def get_input_embeddings(self):
    """Return the embedding layer held by the backbone."""
    layer = self.backbone.embedding_layer
    return layer