|
|
|
|
|
|
|
import gc |
|
|
|
import torch |
|
from composer.core import Callback, State |
|
from composer.loggers import Logger |
|
|
|
|
|
def gc_cuda():
    """Run a full Python garbage-collection pass, then release cached CUDA memory.

    ``gc.collect()`` always runs; the CUDA cache flush is skipped entirely
    when no CUDA device is available, so this is safe to call on CPU-only
    hosts.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
|
|
|
|
|
class ScheduledGarbageCollector(Callback):
    """Disable automatic garbage collection and collect garbage at interval.

    At ``fit_start`` the current automatic-gc state is recorded and automatic
    collection is disabled; collection is then triggered manually every
    ``batch_interval`` batches via :func:`gc_cuda`. The recorded state is
    restored at ``fit_end``.

    Args:
        batch_interval (int): Number of batches between checkpoints call to gc.collect()
        eval_keep_disabled (bool): keep gc disabled during eval (default: False)

    Raises:
        ValueError: If ``batch_interval`` is less than 1. A value of 0 would
            otherwise surface as a ``ZeroDivisionError`` deep inside
            ``before_dataloader`` rather than at construction time.
    """

    def __init__(
        self,
        batch_interval: int,
        eval_keep_disabled: bool = False,
    ):
        # Fail fast: batch_interval is used as a modulus in before_dataloader,
        # so 0 would raise ZeroDivisionError mid-training and negatives are
        # meaningless as a schedule.
        if batch_interval < 1:
            raise ValueError(f'batch_interval must be >= 1, got {batch_interval}')
        self.batch_interval = batch_interval
        self.eval_keep_disabled = eval_keep_disabled
        # Whether automatic gc was enabled before training; captured in
        # fit_start and restored in fit_end. None until fit_start runs.
        self.gc_init_state = None

    def fit_start(self, state: State, logger: Logger) -> None:
        """Record the automatic-gc state, then disable gc for training."""
        del state, logger  # unused

        # Remember the caller's gc setting so fit_end can restore it exactly.
        self.gc_init_state = gc.isenabled()

        # From here on, collection happens only on the manual schedule.
        gc.disable()
        gc_cuda()

    def fit_end(self, state: State, logger: Logger) -> None:
        """Collect once more and restore the pre-training gc state."""
        del state, logger  # unused

        gc_cuda()

        # Restore whatever automatic-gc state was captured in fit_start.
        if self.gc_init_state:
            gc.enable()
        else:
            gc.disable()

    def before_dataloader(self, state: State, logger: Logger) -> None:
        """Collect garbage every ``batch_interval`` batches (including batch 0)."""
        del logger  # unused

        if state.timestamp.batch.value % self.batch_interval == 0:
            gc_cuda()

    def eval_start(self, state: State, logger: Logger) -> None:
        """Collect before eval; optionally re-enable automatic gc for eval."""
        del state, logger  # unused

        gc_cuda()
        if not self.eval_keep_disabled:
            gc.enable()

    def eval_end(self, state: State, logger: Logger) -> None:
        """Re-disable automatic gc (if it was re-enabled) and collect."""
        del state, logger  # unused

        # Mirror of eval_start: return to the manual-collection regime.
        if not self.eval_keep_disabled:
            gc.disable()

        gc_cuda()
|
|