# crystal-technologies's picture
# Upload 2711 files
# 6e73cd3
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
import gc
import torch
from composer.core import Callback, State
from composer.loggers import Logger
def gc_cuda() -> None:
    """Garbage collect Torch (CUDA) memory.

    Runs a full Python garbage collection first, then, when CUDA is
    available, asks PyTorch to empty its CUDA cache.
    """
    # Collect Python garbage before touching the CUDA cache, so anything the
    # collector frees is already gone when the cache is emptied.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
class ScheduledGarbageCollector(Callback):
    """Disable automatic garbage collection and collect garbage at interval.

    Automatic garbage collection is switched off when fit starts and its
    previous setting is restored when fit ends. In between, garbage is
    collected explicitly (via ``gc_cuda``) every ``batch_interval`` batches
    and around evaluation.

    Args:
        batch_interval (int): Number of batches between calls to gc.collect().
            Must be positive.
        eval_keep_disabled (bool): keep gc disabled during eval (default: False)

    Raises:
        ValueError: If ``batch_interval`` is not positive.
    """

    def __init__(
        self,
        batch_interval: int,
        eval_keep_disabled: bool = False,
    ):
        # A zero interval would otherwise surface as a ZeroDivisionError on
        # the first ``before_dataloader`` event; fail early and clearly.
        if batch_interval <= 0:
            raise ValueError(
                f'batch_interval must be a positive integer, got {batch_interval}.'
            )
        self.batch_interval = batch_interval
        self.eval_keep_disabled = eval_keep_disabled
        # Whether automatic gc was enabled before fit started (None until then).
        self.gc_init_state = None
        # Whether automatic gc was enabled right before eval_start turned it
        # on (None outside of an evaluation).
        self._gc_state_before_eval = None

    def fit_start(self, state: State, logger: Logger) -> None:
        del state, logger  # unused

        # cache if automatic garbage collection is enabled; reset at fit_end
        self.gc_init_state = gc.isenabled()

        # disable automatic garbage collection
        gc.disable()
        gc_cuda()

    def fit_end(self, state: State, logger: Logger) -> None:
        del state, logger  # unused

        gc_cuda()

        # reset automatic garbage collection at fit_end
        if self.gc_init_state:
            gc.enable()
        else:
            gc.disable()

    def before_dataloader(self, state: State, logger: Logger) -> None:
        del logger  # unused

        # Collect whenever the batch counter is a multiple of the interval.
        if state.timestamp.batch.value % self.batch_interval == 0:
            gc_cuda()

    def eval_start(self, state: State, logger: Logger) -> None:
        del state, logger  # unused

        gc_cuda()
        if not self.eval_keep_disabled:
            # Remember the current setting so eval_end can put it back rather
            # than unconditionally disabling gc.
            self._gc_state_before_eval = gc.isenabled()
            gc.enable()

    def eval_end(self, state: State, logger: Logger) -> None:
        del state, logger  # unused

        if not self.eval_keep_disabled:
            # Inside fit, gc was disabled before eval, so it is disabled again
            # here. If gc was already enabled before eval (e.g. evaluation run
            # outside of fit), leave it enabled instead of turning it off for
            # the rest of the process.
            if not self._gc_state_before_eval:
                gc.disable()
            self._gc_state_before_eval = None

        gc_cuda()