# -*- coding: utf-8 -*-
import torch
from typing import Callable, Iterable, Sequence, Union


def checkpoint(
    func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
    inputs: Sequence[torch.Tensor],
    params: Iterable[torch.Tensor],
    flag: bool,
    use_deepspeed: bool = False,
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
    """Evaluate ``func`` without caching intermediate activations.

    This trades extra compute in the backward pass (``func`` is re-run) for
    reduced activation memory in the forward pass.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to ``func``. Entries may be
        non-tensors (e.g. ``None``) or non-differentiable tensors; they are
        passed through unchanged during recomputation.
    :param params: a sequence of parameters ``func`` depends on but does not
        explicitly take as arguments. Parameters that do not require grad
        (e.g. frozen weights) are allowed and receive no gradient.
    :param flag: if False, disable gradient checkpointing and call ``func``
        directly.
    :param use_deepspeed: accepted for API compatibility; currently ignored
        (no DeepSpeed code path exists in this function).
    :return: whatever ``func`` returns for ``inputs``.

    Note: RNG state is not saved/restored, so any randomness inside ``func``
    (e.g. dropout) may differ between the forward pass and the recomputation.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    return func(*inputs)


class CheckpointFunction(torch.autograd.Function):
    """Autograd function that recomputes ``run_function`` during backward.

    Forward arguments are ``(run_function, length, *args)``; the first
    ``length`` entries of ``args`` are the call inputs, the rest are extra
    parameters that only need gradients.
    """

    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        # Run without building a graph: no intermediate activations are kept.
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, *output_grads):
        # One flag per forward argument after (run_function, length), i.e.
        # aligned with ctx.input_tensors + ctx.input_params.
        needs_grad = ctx.needs_input_grad[2:]
        num_inputs = len(ctx.input_tensors)

        # Detach the saved inputs so recomputation builds a fresh graph rooted
        # at them. Only tensors whose originals needed grad get
        # requires_grad=True: integer tensors cannot require grad, and
        # non-tensor arguments (e.g. None) have no .detach(), so they are
        # passed through unchanged.
        ctx.input_tensors = [
            x.detach().requires_grad_(needs) if isinstance(x, torch.Tensor) else x
            for x, needs in zip(ctx.input_tensors, needs_grad[:num_inputs])
        ]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [
                x.view_as(x) if isinstance(x, torch.Tensor) else x
                for x in ctx.input_tensors
            ]
            output_tensors = ctx.run_function(*shallow_copies)

        if isinstance(output_tensors, torch.Tensor):
            output_tensors = (output_tensors,)

        # torch.autograd.grad rejects outputs without a grad path, so keep only
        # the outputs that actually require grad (with their incoming grads).
        grad_pairs = [
            (out, grad)
            for out, grad in zip(output_tensors, output_grads)
            if isinstance(out, torch.Tensor) and out.requires_grad
        ]

        # torch.autograd.grad also rejects inputs that do not require grad
        # (e.g. frozen params), so differentiate only the ones that need it.
        candidates = ctx.input_tensors + ctx.input_params
        wanted = [
            i
            for i, (tensor, needs) in enumerate(zip(candidates, needs_grad))
            if needs and isinstance(tensor, torch.Tensor)
        ]

        input_grads = [None] * len(candidates)
        if grad_pairs and wanted:
            outs, grads = zip(*grad_pairs)
            computed = torch.autograd.grad(
                outs,
                [candidates[i] for i in wanted],
                grads,
                allow_unused=True,
            )
            for i, grad in zip(wanted, computed):
                input_grads[i] = grad

        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # No gradients for run_function and length.
        return (None, None) + tuple(input_grads)