# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import gc

import torch
from composer.core import Callback, State
from composer.loggers import Logger


def gc_cuda():
    """Garbage collect Torch (CUDA) memory."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


class ScheduledGarbageCollector(Callback):
    """Disable automatic garbage collection and collect garbage at interval.

    Args:
        batch_interval (int): Number of batches between checkpoints call to gc.collect()
        eval_keep_disabled (bool): keep gc disabled during eval (default: False)
    """

    def __init__(
        self,
        batch_interval: int,
        eval_keep_disabled: bool = False,
    ):
        self.batch_interval = batch_interval
        self.eval_keep_disabled = eval_keep_disabled
        self.gc_init_state = None

    def fit_start(self, state: State, logger: Logger) -> None:
        del state, logger  # unused

        # cache if automatic garbage collection is enabled; reset at fit_end
        self.gc_init_state = gc.isenabled()

        # disable automatic garbage collection
        gc.disable()
        gc_cuda()

    def fit_end(self, state: State, logger: Logger) -> None:
        del state, logger  # unused

        gc_cuda()

        # reset automatic garbage collection at fit_end
        if self.gc_init_state:
            gc.enable()
        else:
            gc.disable()

    def before_dataloader(self, state: State, logger: Logger) -> None:
        del logger  # unused

        if state.timestamp.batch.value % self.batch_interval == 0:
            gc_cuda()

    def eval_start(self, state: State, logger: Logger) -> None:
        del state, logger  # unused

        gc_cuda()
        if not self.eval_keep_disabled:
            gc.enable()

    def eval_end(self, state: State, logger: Logger) -> None:
        del state, logger  # unused

        if not self.eval_keep_disabled:
            gc.disable()

        gc_cuda()
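

# Usage sketch (illustrative only, not part of this module): the callback is
# passed to a composer `Trainer` like any other callback. `model` and
# `train_dataloader` below are assumed placeholders for your own objects.
#
#   from composer import Trainer
#
#   trainer = Trainer(
#       model=model,
#       train_dataloader=train_dataloader,
#       max_duration='1ep',
#       callbacks=[ScheduledGarbageCollector(batch_interval=1000)],
#   )
#   trainer.fit()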