from collections.abc import Iterable

import torch.nn as nn
from torch.utils.checkpoint import checkpoint, checkpoint_sequential


def set_grad_checkpoint(model, use_fp32_attention=False, gc_step=1):
    """Flag every submodule of ``model`` for gradient checkpointing."""
    assert isinstance(model, nn.Module)

    def set_attr(module):
        # Flags read back by auto_grad_checkpoint (and by any module that
        # checks fp32_attention).
        module.grad_checkpointing = True
        module.fp32_attention = use_fp32_attention
        module.grad_checkpointing_step = gc_step

    model.apply(set_attr)


def auto_grad_checkpoint(module, *args, **kwargs):
    """Call ``module``, checkpointing it if set_grad_checkpoint flagged it."""
    if getattr(module, "grad_checkpointing", False):
        if not isinstance(module, Iterable):
            # Single module: recompute its activations during the backward pass.
            return checkpoint(module, *args, **kwargs)
        # Iterable container (e.g. nn.Sequential): checkpoint it in
        # grad_checkpointing_step segments.
        gc_step = module[0].grad_checkpointing_step
        return checkpoint_sequential(module, gc_step, *args, **kwargs)
    # Checkpointing not enabled: plain forward call.
    return module(*args, **kwargs)
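

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the original module):
    # flag a toy nn.Sequential stack, then run it through auto_grad_checkpoint
    # so activations are recomputed during backward via checkpoint_sequential.
    import torch

    blocks = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16))
    set_grad_checkpoint(blocks, gc_step=2)  # checkpoint the stack in 2 segments

    x = torch.randn(4, 16, requires_grad=True)
    y = auto_grad_checkpoint(blocks, x)  # iterable -> checkpoint_sequential path
    y.sum().backward()
    print("gradient w.r.t. input:", tuple(x.grad.shape))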