import functools
import math
from typing import Any, Callable, Dict, List, Optional, Type, Union

import torch
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_optimizer_state_dict,
    set_optimizer_state_dict,
)
from torch.distributed.checkpoint.stateful import Stateful

from .parallel import ParallelBackendEnum
from .utils.import_utils import is_bitsandbytes_available


class OptimizerWrapper(Stateful):
    r"""
    Optimizer wrapper that:
        - allows step/zero_grad on multiple optimizers needed for virtual pipeline stages
        - saves/loads optimizer state_dicts when checkpointing
    """

    def __init__(
        self,
        model_parts: List[torch.nn.Module],
        optimizer_cls: Type[torch.optim.Optimizer],
        optimizer_kwargs: Dict[str, Any],
    ) -> None:
        self.optimizer_cls = optimizer_cls
        self.optimizer_kwargs = optimizer_kwargs

        self.optimizers = []
        self.model_parts = model_parts

        # One optimizer per model part (pipeline stage).
        for model in self.model_parts:
            optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs)
            self.optimizers.append(optimizer)

    def step(self) -> None:
        for optimizer in self.optimizers:
            optimizer.step()

    def zero_grad(self) -> None:
        for optimizer in self.optimizers:
            optimizer.zero_grad()

    def state_dict(self) -> Dict[str, Any]:
        # Merge the flattened optimizer state dicts of all model parts into a single mapping.
        func = functools.partial(
            get_optimizer_state_dict,
            options=StateDictOptions(flatten_optimizer_state_dict=True),
        )
        return {k: v for sd in map(func, self.model_parts, self.optimizers) for k, v in sd.items()}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        func = functools.partial(
            set_optimizer_state_dict,
            optim_state_dict=state_dict,
            options=StateDictOptions(flatten_optimizer_state_dict=True),
        )
        list(map(func, self.model_parts, self.optimizers))
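
# Illustrative sketch (not executed; the module names and hyperparameters below are assumptions,
# not part of this file): how OptimizerWrapper is typically driven when a model is split into
# pipeline stages.
#
#   model_parts = [stage_0, stage_1]                      # hypothetical pipeline-stage modules
#   optimizer = OptimizerWrapper(model_parts, torch.optim.AdamW, {"lr": 1e-4, "weight_decay": 1e-2})
#   loss.backward()
#   optimizer.step()
#   optimizer.zero_grad()
#   state = optimizer.state_dict()        # flattened optimizer state across all stages
#   optimizer.load_state_dict(state)      # restore from a checkpoint with the same sharding
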
class SchedulerWrapper:
    def __init__(
        self, optimizers, scheduler_lambda_fn: Callable[[int], float], last_epoch: int
    ) -> None:
        self.schedulers = []
        for optimizer in optimizers:
            self.schedulers.append(torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_lambda_fn, last_epoch))

    def step(self) -> None:
        for scheduler in self.schedulers:
            scheduler.step()

    def get_last_lr(self) -> Dict[str, List[float]]:
        # TODO(aryan): look into this later. Currently calling it leads to NCCL hang?????
        return {f"lr_{idx}": scheduler.get_last_lr() for idx, scheduler in enumerate(self.schedulers)}

    def get_lr_scheduler_state(self) -> Dict[str, Any]:
        state_dict = {}
        if len(self.schedulers) == 1:
            state_dict["lr_scheduler"] = self.schedulers[0]
        else:
            # For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler.
            # It should only support saving and loading a distributed checkpoint with the same number of pp ranks.
            for idx, lr_scheduler in enumerate(self.schedulers):
                state_dict[f"lr_scheduler_{idx}"] = lr_scheduler
        return state_dict


def get_optimizer(
    parallel_backend: ParallelBackendEnum,
    name: str,
    model_parts: List[torch.nn.Module],
    learning_rate: float = 1e-3,
    beta1: float = 0.9,
    beta2: float = 0.95,
    beta3: float = 0.999,  # not used by the optimizers configured below
    epsilon: float = 1e-8,
    weight_decay: float = 1e-4,
    fused: bool = False,
) -> Union[torch.optim.Optimizer, OptimizerWrapper]:
    name = name.lower()
    _raise_errors_if_packages_not_available(name)

    if name == "adam":
        optimizer_cls = torch.optim.Adam
        optimizer_kwargs = {
            "lr": learning_rate,
            "betas": (beta1, beta2),
            "eps": epsilon,
            "weight_decay": weight_decay,
            "fused": fused,
        }
    elif name == "adamw":
        optimizer_cls = torch.optim.AdamW
        optimizer_kwargs = {
            "lr": learning_rate,
            "betas": (beta1, beta2),
            "eps": epsilon,
            "weight_decay": weight_decay,
            "fused": fused,
        }
    elif name == "adam-bnb":
        from bitsandbytes.optim import Adam

        optimizer_cls = Adam
        optimizer_kwargs = {
            "lr": learning_rate,
            "betas": (beta1, beta2),
            "eps": epsilon,
            "weight_decay": weight_decay,
        }
    elif name == "adamw-bnb":
        from bitsandbytes.optim import AdamW

        optimizer_cls = AdamW
        optimizer_kwargs = {
            "lr": learning_rate,
            "betas": (beta1, beta2),
            "eps": epsilon,
            "weight_decay": weight_decay,
        }
    elif name == "adam-bnb-8bit":
        from bitsandbytes.optim import Adam8bit

        optimizer_cls = Adam8bit
        optimizer_kwargs = {
            "lr": learning_rate,
            "betas": (beta1, beta2),
            "eps": epsilon,
            "weight_decay": weight_decay,
        }
    elif name == "adamw-bnb-8bit":
        from bitsandbytes.optim import AdamW8bit

        optimizer_cls = AdamW8bit
        optimizer_kwargs = {
            "lr": learning_rate,
            "betas": (beta1, beta2),
            "eps": epsilon,
            "weight_decay": weight_decay,
        }
    # TODO(aryan): handle bitsandbytes and torchao
    else:
        raise ValueError(f"Unsupported optimizer: {name}")

    if parallel_backend == ParallelBackendEnum.ACCELERATE:
        return get_optimizer_accelerate(model_parts, optimizer_cls, optimizer_kwargs)
    elif parallel_backend == ParallelBackendEnum.PTD:
        return get_optimizer_ptd(model_parts, optimizer_cls, optimizer_kwargs)
    raise ValueError(f"Unsupported parallel backend: {parallel_backend}")


def get_optimizer_accelerate(
    model_parts: List[torch.nn.Module], optimizer_cls: Type[torch.optim.Optimizer], optimizer_kwargs: Dict[str, Any]
) -> torch.optim.Optimizer:
    params = [param for model in model_parts for param in model.parameters() if param.requires_grad]
    optimizer = optimizer_cls(params, **optimizer_kwargs)
    return optimizer


def get_optimizer_ptd(
    model_parts: List[torch.nn.Module], optimizer_cls: Type[torch.optim.Optimizer], optimizer_kwargs: Dict[str, Any]
) -> OptimizerWrapper:
    return OptimizerWrapper(model_parts, optimizer_cls, optimizer_kwargs)
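
# Illustrative sketch (not executed; the module list and hyperparameters are assumptions for the
# example): constructing an optimizer through the factory above for the PTD backend.
#
#   optimizer = get_optimizer(
#       parallel_backend=ParallelBackendEnum.PTD,
#       name="adamw",
#       model_parts=[transformer],          # hypothetical model being trained
#       learning_rate=1e-4,
#       weight_decay=1e-2,
#   )
#   # With the PTD backend this returns an OptimizerWrapper; with ACCELERATE, a plain torch optimizer.
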
def get_lr_scheduler(
    parallel_backend: ParallelBackendEnum,
    name: str,
    optimizer: Union[torch.optim.Optimizer, OptimizerWrapper],
    step_rules: Optional[str] = None,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    num_cycles: int = 1,
    power: float = 1.0,
    lr_init: float = 1e-3,
    lr_end: float = 1e-7,
    last_epoch: int = -1,
) -> Union[torch.optim.lr_scheduler.LambdaLR, SchedulerWrapper]:
    name = name.lower()

    if name == "constant":
        scheduler_lambda_fn = get_constant_schedule()
    elif name == "constant_with_warmup":
        scheduler_lambda_fn = get_constant_schedule_with_warmup(num_warmup_steps)
    elif name == "piecewise_constant":
        scheduler_lambda_fn = get_piecewise_constant_schedule(step_rules)
    elif name == "linear":
        scheduler_lambda_fn = get_linear_schedule_with_warmup(num_warmup_steps, num_training_steps)
    elif name == "cosine":
        scheduler_lambda_fn = get_cosine_schedule_with_warmup(num_warmup_steps, num_training_steps, num_cycles)
    elif name == "cosine_with_restarts":
        scheduler_lambda_fn = get_cosine_with_hard_restarts_schedule_with_warmup(
            num_warmup_steps, num_training_steps, num_cycles
        )
    elif name == "polynomial":
        scheduler_lambda_fn = get_polynomial_decay_schedule_with_warmup(
            num_warmup_steps, num_training_steps, lr_init, lr_end, power
        )
    else:
        raise ValueError(f"Unsupported scheduler: {name}")

    if parallel_backend == ParallelBackendEnum.ACCELERATE:
        return get_lr_scheduler_accelerate(optimizer, scheduler_lambda_fn, last_epoch)
    elif parallel_backend == ParallelBackendEnum.PTD:
        return get_lr_scheduler_ptd(optimizer, scheduler_lambda_fn, last_epoch)
    raise ValueError(f"Unsupported parallel backend: {parallel_backend}")


def get_lr_scheduler_accelerate(
    optimizer: torch.optim.Optimizer,
    scheduler_lambda_fn: Callable[[int], float],
    last_epoch: int = -1,
) -> torch.optim.lr_scheduler.LambdaLR:
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_lambda_fn, last_epoch)
    return scheduler


def get_lr_scheduler_ptd(
    optimizer: OptimizerWrapper, scheduler_lambda_fn: Callable[[int], float], last_epoch: int = -1
) -> SchedulerWrapper:
    return SchedulerWrapper(optimizer.optimizers, scheduler_lambda_fn, last_epoch)


# ==============================
# Adapted from https://github.com/huggingface/diffusers/blob/196aef5a6f76e1ad6ba889184860c3633d166910/src/diffusers/optimization.py
# ==============================


def get_constant_schedule() -> Callable[[int], float]:
    r"""
    Create a schedule with a constant learning rate, using the learning rate set in the optimizer.
    """

    def lr_lambda(current_step: int):
        return 1.0

    return lr_lambda


def get_constant_schedule_with_warmup(num_warmup_steps: int) -> Callable[[int], float]:
    r"""
    Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
    increases linearly between 0 and the initial lr set in the optimizer.

    Args:
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
    """

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1.0, num_warmup_steps))
        return 1.0

    return lr_lambda
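
# Illustrative note: the schedule helpers in this file return a multiplicative factor for LambdaLR,
# so the effective learning rate at step t is base_lr * lr_lambda(t). The numbers here are assumed
# example values, not defaults of this module.
#
#   lr_lambda = get_constant_schedule_with_warmup(num_warmup_steps=100)
#   lr_lambda(50)   # 0.5 -> half of the base lr midway through warmup
#   lr_lambda(200)  # 1.0 -> constant after warmup
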
""" rules_dict = {} rule_list = step_rules.split(",") for rule_str in rule_list[:-1]: value_str, steps_str = rule_str.split(":") steps = int(steps_str) value = float(value_str) rules_dict[steps] = value last_lr_multiple = float(rule_list[-1]) def create_rules_function(rules_dict, last_lr_multiple): def rule_func(steps: int) -> float: sorted_steps = sorted(rules_dict.keys()) for i, sorted_step in enumerate(sorted_steps): if steps < sorted_step: return rules_dict[sorted_steps[i]] return last_lr_multiple return rule_func rules_func = create_rules_function(rules_dict, last_lr_multiple) return rules_func def get_linear_schedule_with_warmup(num_warmup_steps: int, num_training_steps: int) -> Callable[[int], float]: r""" Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. Args: num_warmup_steps (`int`): The number of steps for the warmup phase. num_training_steps (`int`): The total number of training steps. """ def lr_lambda(current_step: int): if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) return max( 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) ) return lr_lambda def get_cosine_schedule_with_warmup( num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, ) -> Callable[[int], float]: r""" Create a schedule with a learning rate that decreases following the values of the cosine function between the initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the initial lr set in the optimizer. Args: num_warmup_steps (`int`): The number of steps for the warmup phase. num_training_steps (`int`): The total number of training steps. num_periods (`float`, *optional*, defaults to 0.5): The number of periods of the cosine function in a schedule (the default is to just decrease from the max value to 0 following a half-cosine). """ def lr_lambda(current_step): if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) return lr_lambda def get_cosine_with_hard_restarts_schedule_with_warmup( num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, ) -> Callable[[int], float]: r""" Create a schedule with a learning rate that decreases following the values of the cosine function between the initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases linearly between 0 and the initial lr set in the optimizer. Args: num_warmup_steps (`int`): The number of steps for the warmup phase. num_training_steps (`int`): The total number of training steps. num_cycles (`int`, *optional*, defaults to 1): The number of hard restarts to use. 
""" def lr_lambda(current_step): if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) if progress >= 1.0: return 0.0 return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))) return lr_lambda def get_polynomial_decay_schedule_with_warmup( num_warmup_steps: int, num_training_steps: int, lr_init: float, lr_end: float = 1e-7, power: float = 1.0, ) -> Callable[[int], float]: r""" Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. Args: num_warmup_steps (`int`): The number of steps for the warmup phase. num_training_steps (`int`): The total number of training steps. lr_end (`float`, *optional*, defaults to 1e-7): The end LR. power (`float`, *optional*, defaults to 1.0): Power factor. Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT implementation at https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37 """ if not (lr_init > lr_end): raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})") def lr_lambda(current_step: int): if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) elif current_step > num_training_steps: return lr_end / lr_init # as LambdaLR multiplies by lr_init else: lr_range = lr_init - lr_end decay_steps = num_training_steps - num_warmup_steps pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps decay = lr_range * pct_remaining**power + lr_end return decay / lr_init # as LambdaLR multiplies by lr_init return lr_lambda def _raise_errors_if_packages_not_available(name: str) -> None: name_split = name.split("-") if len(name_split) < 2: return package_name = name_split[1] if package_name == "bnb": if not is_bitsandbytes_available(): raise ImportError( f"Please install bitsandbytes by running `pip install bitsandbytes` to use the {name} optimizer." )