"""Custom gradient-checkpointing utilities."""

from axolotl.utils.gradient_checkpointing.unsloth import (
    Unsloth_Offloaded_Gradient_Checkpointer,
)


def hf_grad_checkpoint_unsloth_wrapper(
    decoder_layer, *args, use_reentrant=None
):  # pylint: disable=unused-argument
    """Route a checkpointed layer call through the Unsloth offloaded checkpointer.

    Args:
        decoder_layer: A bound method. Its ``__self__`` (the object it is
            bound to) is handed to the checkpointer as the callable to run.
            NOTE(review): presumably this is a layer's ``__call__`` as passed
            by the Hugging Face gradient-checkpointing hook; confirm at the
            call site.
        *args: Positional arguments forwarded unchanged to the checkpointer
            after the unwrapped layer object.
        use_reentrant: Accepted but ignored. Presumably it exists so this
            function can stand in wherever a ``use_reentrant`` keyword is
            supplied.

    Returns:
        Whatever ``Unsloth_Offloaded_Gradient_Checkpointer.apply`` returns.
    """
    # Look up the autograd entry point first, matching the original
    # evaluation order, then unwrap the bound method to its owning object.
    checkpoint_apply = Unsloth_Offloaded_Gradient_Checkpointer.apply
    layer_owner = decoder_layer.__self__
    return checkpoint_apply(layer_owner, *args)