File size: 2,025 Bytes
6e73cd3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
import gc
import torch
from composer.core import Callback, State
from composer.loggers import Logger
def gc_cuda():
    """Garbage collect Torch (CUDA) memory."""
    # Reclaim Python-level garbage first so freed tensors become eligible
    # for release from the CUDA caching allocator.
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
class ScheduledGarbageCollector(Callback):
    """Disable automatic garbage collection and collect garbage at interval.

    Automatic GC is switched off at ``fit_start`` and restored to its prior
    state at ``fit_end``; in between, garbage is collected manually every
    ``batch_interval`` batches so collection pauses happen at predictable
    points rather than mid-step.

    Args:
        batch_interval (int): Number of batches between checkpoints call to gc.collect()
        eval_keep_disabled (bool): keep gc disabled during eval (default: False)
    """

    def __init__(
        self,
        batch_interval: int,
        eval_keep_disabled: bool = False,
    ):
        self.batch_interval = batch_interval
        self.eval_keep_disabled = eval_keep_disabled
        # Whether automatic GC was enabled before training; captured at
        # fit_start and used to restore the original state at fit_end.
        self.gc_init_state = None

    def fit_start(self, state: State, logger: Logger) -> None:
        del state, logger  # unused
        # Remember the current automatic-GC state so fit_end can restore it.
        self.gc_init_state = gc.isenabled()
        # Take over: disable automatic collection for the duration of fit.
        gc.disable()
        gc_cuda()

    def fit_end(self, state: State, logger: Logger) -> None:
        del state, logger  # unused
        gc_cuda()
        # Restore automatic GC to whatever it was before fit_start.
        restore = gc.enable if self.gc_init_state else gc.disable
        restore()

    def before_dataloader(self, state: State, logger: Logger) -> None:
        del logger  # unused
        # Collect on the configured cadence (fires on batch 0 as well).
        if state.timestamp.batch.value % self.batch_interval != 0:
            return
        gc_cuda()

    def eval_start(self, state: State, logger: Logger) -> None:
        del state, logger  # unused
        gc_cuda()
        # Hand control back to automatic GC during eval unless configured
        # to keep it disabled.
        if not self.eval_keep_disabled:
            gc.enable()

    def eval_end(self, state: State, logger: Logger) -> None:
        del state, logger  # unused
        # Re-disable automatic GC for the remainder of training, mirroring
        # eval_start.
        if not self.eval_keep_disabled:
            gc.disable()
        gc_cuda()
|