# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import List

from composer.core import Callback, State
from composer.loggers import Logger

__all__ = [
    'GlobalLRScaling',
    'LayerFreezing',
]

log = logging.getLogger(__name__)


class GlobalLRScaling(Callback):
    """GlobalLRScaling.

    This callback can be applied upon resuming a model checkpoint. Upon
    fit_start it will multiply the base LR by `lr_scale` and set the weight
    decay to `wd_pct * lr`.

    Args:
        lr_scale (float): Multiplicative factor to scale the LR by.
        wd_pct (float): Fraction of the scaled LR to set the weight decay to.
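
    Example:
        A minimal sketch, assuming `model`, `optimizer`, and `train_dataloader`
        come from the user's existing setup; the hyperparameter values and
        checkpoint path below are illustrative, not recommendations:

        .. code-block:: python

            from composer import Trainer

            trainer = Trainer(
                model=model,
                train_dataloader=train_dataloader,
                optimizers=optimizer,
                max_duration='1ep',
                callbacks=[GlobalLRScaling(lr_scale=0.5, wd_pct=0.01)],
                load_path='path/to/checkpoint.pt',
            )
            trainer.fit()  # LR halved, WD set to 1% of the new LR at fit_start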
    """

    def __init__(self, lr_scale: float, wd_pct: float = 0.0):
        self.lr_scale = lr_scale
        self.wd_pct = wd_pct

    def fit_start(self, state: State, logger: Logger) -> None:
        del logger  # unused

        if getattr(state, 'optimizers', None) is None:
            raise ValueError('No optimizers defined')
        for optimizer in state.optimizers:
            for group in optimizer.param_groups:
                group['lr'] *= self.lr_scale
                group['weight_decay'] = group['lr'] * self.wd_pct
                if 'initial_lr' in group:
                    group['initial_lr'] *= self.lr_scale
                log.info(
                    f"Set LR and WD to {group['lr']}, {group['weight_decay']}")

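        # Also rescale the schedulers' stored base LRs so that any LR schedule
        # resumed from the checkpoint is computed from the scaled values.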
        for scheduler in state.schedulers:
            scheduler.base_lrs = [
                self.lr_scale * lr for lr in scheduler.base_lrs
            ]


class LayerFreezing(Callback):
    """LayerFreezing.

    This callback can be applied upon resuming a model checkpoint. Upon
    fit_start it freezes the layers specified in `layer_names`. If using
    activation checkpointing, please set the
    `activation_checkpointing_reentrant` flag in `fsdp_config` to false.

    Args:
        layer_names (List[str]): Names of the parameters to freeze, as
            reported by `state.model.named_parameters()`.
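
    Example:
        A minimal sketch, assuming `model`, `optimizer`, and `train_dataloader`
        come from the user's existing setup; the parameter names and checkpoint
        path below are illustrative and must match `model.named_parameters()`:

        .. code-block:: python

            from composer import Trainer

            trainer = Trainer(
                model=model,
                train_dataloader=train_dataloader,
                optimizers=optimizer,
                max_duration='1ep',
                callbacks=[LayerFreezing(layer_names=[
                    'model.transformer.wte.weight',
                    'model.transformer.blocks.0.attn.Wqkv.weight',
                ])],
                load_path='path/to/checkpoint.pt',
            )
            trainer.fit()  # the listed parameters train with requires_grad=False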
    """

    def __init__(self, layer_names: List[str]):
        self.layer_names = set(layer_names)

    def fit_start(self, state: State, logger: Logger) -> None:
        del logger  # unused

        model_layers = set(name for name, _ in state.model.named_parameters())
        for layer in self.layer_names:
            if layer not in model_layers:
                raise ValueError(
                    f'Attempted to freeze layer not found in model: {layer}\nAvailable layers: {model_layers}'
                )

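        # Disable gradients only for parameters that are currently trainable so
        # that `successful_freeze` reflects whether anything actually changed.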
        successful_freeze = False
        for name, p in state.model.named_parameters():
            if p.requires_grad and name in self.layer_names:
                p.requires_grad = False
                log.debug(f'Froze layer: {name}\nParam: {p}')
                successful_freeze = True

        if not successful_freeze:
            raise ValueError(
                "Tried to run LayerFreezing but didn't freeze any layers")