winglian's picture
Unsloth gradient checkpointing offload (#1528)
6319da1 unverified
raw
history blame contribute delete
375 Bytes
"""custom checkpointing utils"""
from axolotl.utils.gradient_checkpointing.unsloth import (
Unsloth_Offloaded_Gradient_Checkpointer,
)
def hf_grad_checkpoint_unsloth_wrapper(
    decoder_layer, *args, use_reentrant=None
):  # pylint: disable=unused-argument
    """Route a gradient-checkpointing call through Unsloth's offloaded checkpointer.

    The signature mirrors a checkpoint function: a callable first, then its
    positional arguments, plus a ``use_reentrant`` keyword.

    Args:
        decoder_layer: Expected to be a bound method, because the code reads
            its ``__self__``. The object it is bound to is what gets passed
            to the checkpointer. Presumably this is a decoder layer's
            ``__call__`` as passed by transformers; NOTE(review): confirm
            against the caller.
        *args: Positional arguments forwarded unchanged to
            ``Unsloth_Offloaded_Gradient_Checkpointer.apply``, after the
            unwrapped layer.
        use_reentrant: Accepted only so the signature matches what callers
            pass. It is ignored.

    Returns:
        Whatever ``Unsloth_Offloaded_Gradient_Checkpointer.apply`` returns.
    """
    # Unwrap the bound method to the object it belongs to; that object,
    # not the method itself, is handed to the checkpointer.
    layer = decoder_layer.__self__
    return Unsloth_Offloaded_Gradient_Checkpointer.apply(layer, *args)