import logging
from typing import List

from composer.core import Callback, State
from composer.loggers import Logger

__all__ = [
    'GlobalLRScaling',
    'LayerFreezing',
]

log = logging.getLogger(__name__)


class GlobalLRScaling(Callback):
    """GlobalLRScaling.

    This callback can be applied upon resuming a model checkpoint. Upon
    fit_start it multiplies the base LR by `lr_scale` and sets the weight
    decay to `wd_pct * lr`.

    Args:
        lr_scale (float): Multiplicative factor to scale the LR by.
        wd_pct (float): Percentage of the LR to set the weight decay to.
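
    Example:
        A minimal sketch of attaching the callback when resuming from a
        checkpoint. The ``model``, ``train_dataloader``, and checkpoint path
        below are placeholders, not part of this module::

            from composer import Trainer

            trainer = Trainer(
                model=model,
                train_dataloader=train_dataloader,
                max_duration='1ep',
                load_path='path/to/checkpoint.pt',
                callbacks=[GlobalLRScaling(lr_scale=0.5, wd_pct=0.1)],
            )
            trainer.fit()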
    """

    def __init__(self, lr_scale: float, wd_pct: float = 0.0):
        self.lr_scale = lr_scale
        self.wd_pct = wd_pct

    def fit_start(self, state: State, logger: Logger) -> None:
        del logger  # unused

        if state.optimizers is None:
            raise Exception('No optimizers defined')
        for optimizer in state.optimizers:
            for group in optimizer.param_groups:
                # Scale the current LR and recompute the weight decay as a
                # percentage of the new LR.
                group['lr'] *= self.lr_scale
                group['weight_decay'] = group['lr'] * self.wd_pct
                if 'initial_lr' in group:
                    group['initial_lr'] *= self.lr_scale
                log.info(
                    f"Set LR and WD to {group['lr']}, {group['weight_decay']}")

        for scheduler in state.schedulers:
            # Rescale the scheduler's base LRs so future scheduling stays
            # consistent with the scaled optimizer LRs.
            scheduler.base_lrs = [
                self.lr_scale * lr for lr in scheduler.base_lrs
            ]


class LayerFreezing(Callback):
    """LayerFreezing.

    This callback can be applied upon resuming a model checkpoint. Upon
    fit_start it freezes the layers specified in `layer_names`. If using
    activation checkpointing, please set the
    `activation_checkpointing_reentrant` flag in `fsdp_config` to false.

    Args:
        layer_names (List[str]): Names of the layers to freeze.
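
    Example:
        A minimal sketch; the layer name below is a placeholder and should be
        replaced with a real name from ``model.named_parameters()``::

            from composer import Trainer

            trainer = Trainer(
                model=model,
                train_dataloader=train_dataloader,
                max_duration='1ep',
                load_path='path/to/checkpoint.pt',
                callbacks=[
                    LayerFreezing(layer_names=['blocks.0.attn.Wqkv.weight']),
                ],
            )
            trainer.fit()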
    """

    def __init__(self, layer_names: List[str]):
        self.layer_names = set(layer_names)

    def fit_start(self, state: State, logger: Logger) -> None:
        del logger  # unused

        # Validate that every requested layer name exists in the model.
        model_layers = set(name for name, _ in state.model.named_parameters())
        for layer in self.layer_names:
            if layer not in model_layers:
                raise Exception(
                    f'Attempted to freeze layer not found in model: {layer}\nAvailable layers: {model_layers}'
                )

        # Disable gradients for every matching parameter.
        successful_freeze = False
        for name, p in state.model.named_parameters():
            if p.requires_grad and name in self.layer_names:
                p.requires_grad = False
                log.debug(f'Froze layer: {name}\nParam: {p}')
                successful_freeze = True

        if not successful_freeze:
            raise Exception(
                "Tried to run LayerFreezing but didn't freeze any layers")