# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import List

from composer.core import Callback, State
from composer.loggers import Logger

__all__ = [
    'GlobalLRScaling',
    'LayerFreezing',
]

log = logging.getLogger(__name__)


class GlobalLRScaling(Callback):
    """GlobalLRScaling.

    This callback can be applied upon resuming a model checkpoint. Upon
    fit_start it will multiply the base LR by `lr_scale` and set the WD to
    `wd_pct * lr`.

    Args:
        lr_scale (float): Multiplicative factor to scale the LR by.
        wd_pct (float): Percentage of the LR to set the weight decay to.
    """

    def __init__(self, lr_scale: float, wd_pct: float = 0.0):
        self.lr_scale = lr_scale
        self.wd_pct = wd_pct

    def fit_start(self, state: State, logger: Logger) -> None:
        del logger  # unused

        # Guard against running without any optimizers attached to the state.
        if not state.optimizers:
            raise Exception('No optimizers defined')
        for optimizer in state.optimizers:
            for group in optimizer.param_groups:
                group['lr'] *= self.lr_scale
                group['weight_decay'] = group['lr'] * self.wd_pct
                if 'initial_lr' in group:
                    group['initial_lr'] *= self.lr_scale
                log.info(
                    f"Set LR and WD to {group['lr']}, {group['weight_decay']}")

        for scheduler in state.schedulers:
            scheduler.base_lrs = [
                self.lr_scale * lr for lr in scheduler.base_lrs
            ]
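
# Hedged usage sketch (not part of the original module): GlobalLRScaling is a
# plain Composer callback, so it would typically be passed to a
# composer.Trainer via its `callbacks` argument when resuming from a
# checkpoint. The model, dataloader, and checkpoint path below are
# placeholders for illustration only.
#
#   from composer import Trainer
#
#   trainer = Trainer(
#       model=model,  # placeholder ComposerModel
#       train_dataloader=train_dataloader,  # placeholder dataloader
#       load_path='checkpoints/latest.pt',  # placeholder checkpoint to resume
#       callbacks=[GlobalLRScaling(lr_scale=0.1, wd_pct=0.01)],
#   )
#   trainer.fit()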


class LayerFreezing(Callback):
    """LayerFreezing.

    This callback can be applied upon resuming a model checkpoint. Upon
    fit_start it freezes the layers specified in `layer_names`. If using
    activation checkpointing, please set the
    `activation_checkpointing_reentrant` flag in `fsdp_config` to false.

    Args:
        layer_names (List[str]): Names of the layers to freeze.
    """

    def __init__(self, layer_names: List[str]):
        self.layer_names = set(layer_names)

    def fit_start(self, state: State, logger: Logger) -> None:
        del logger  # unused

        model_layers = set(name for name, _ in state.model.named_parameters())
        for layer in self.layer_names:
            if layer not in model_layers:
                raise Exception(
                    f'Attempted to freeze layer not found in model: {layer}\nAvailable layers: {model_layers}'
                )

        successful_freeze = False
        for name, p in state.model.named_parameters():
            if p.requires_grad and name in self.layer_names:
                p.requires_grad = False
                log.debug(f'Froze layer: {name}\nParam: {p}')
                successful_freeze = True

        if not successful_freeze:
            raise Exception(
                "Tried to run LayerFreezing but didn't freeze any layers")