import copy
import torch
import itertools
from enum import Enum
from uniperceiver.config import CfgNode
from uniperceiver.utils.registry import Registry
from uniperceiver.utils import comm
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union

SOLVER_REGISTRY = Registry("SOLVER")
SOLVER_REGISTRY.__doc__ = """
Registry for SOLVER.
"""

_GradientClipperInput = Union[torch.Tensor, Iterable[torch.Tensor]]
_GradientClipper = Callable[[_GradientClipperInput], None]


def _create_gradient_clipper(
        cfg: CfgNode) -> Tuple[Optional[_GradientClipper], Optional[_GradientClipper]]:
    # Returns a (per_param_clipper, global_clipper) pair; exactly one entry is set,
    # depending on cfg.SOLVER.GRAD_CLIP_TYPE.
    def clip_grad_norm(p: _GradientClipperInput):
        torch.nn.utils.clip_grad_norm_(p, cfg.SOLVER.GRAD_CLIP, cfg.SOLVER.NORM_TYPE)

    def clip_grad_value(p: _GradientClipperInput):
        torch.nn.utils.clip_grad_value_(p, cfg.SOLVER.GRAD_CLIP)

    _GRADIENT_CLIP_TYPE_TO_CLIPPER = {
        'value': clip_grad_value,
        'norm': clip_grad_norm,
    }
    clipper = _GRADIENT_CLIP_TYPE_TO_CLIPPER[cfg.SOLVER.GRAD_CLIP_TYPE]
    if cfg.SOLVER.GRAD_CLIP_TYPE == 'value':
        return clipper, None
    else:
        return None, clipper


def get_default_optimizer_params(
    model: torch.nn.Module,
    base_lr: Optional[float] = None,
    weight_decay: Optional[float] = None,
    weight_decay_norm: Optional[float] = None,
    bias_lr_factor: Optional[float] = 1.0,
    weight_decay_bias: Optional[float] = None,
    overrides: Optional[Dict[str, Dict[str, float]]] = None,
):
    """
    Get the default per-parameter optimizer settings (lr / weight decay),
    with optional per-name `overrides`.
    """
    if weight_decay_bias is None:
        weight_decay_bias = weight_decay
    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        # NaiveSyncBatchNorm inherits from BatchNorm2d
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )
    params: List[Dict[str, Any]] = []
    memo: Set[torch.nn.parameter.Parameter] = set()

    no_decay_list = {}
    if hasattr(model, 'no_weight_decay'):
        no_decay_list = model.no_weight_decay()

    for module_name, module in model.named_modules():
        no_decay = False
        if module_name in no_decay_list:
            no_decay = True
        for module_param_name, value in module.named_parameters(recurse=False):
            if not value.requires_grad:
                continue
            # Avoid duplicating parameters
            if value in memo:
                continue
            memo.add(value)

            schedule_params = {
                "lr": base_lr,
                "weight_decay": weight_decay,
            }
            if isinstance(module, norm_module_types):
                schedule_params["weight_decay"] = weight_decay_norm
            elif module_param_name == "bias":
                # NOTE: unlike Detectron v1, we now default BIAS_LR_FACTOR to 1.0
                # and WEIGHT_DECAY_BIAS to WEIGHT_DECAY so that bias optimizer
                # hyperparameters are by default exactly the same as for regular
                # weights.
                schedule_params["lr"] = base_lr * bias_lr_factor
                schedule_params["weight_decay"] = weight_decay_bias
            if no_decay or (module_param_name in no_decay_list):
                schedule_params["weight_decay"] = 0.
            if overrides is not None and module_param_name in overrides:
                schedule_params.update(overrides[module_param_name])
            params += [
                {
                    "params": [value],
                    "lr": schedule_params["lr"],
                    "weight_decay": schedule_params["weight_decay"],
                }
            ]

    return params
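
# Illustrative sketch (not part of the training flow): the list returned by
# get_default_optimizer_params is already in the per-parameter-group format expected
# by torch.optim optimizers, so it can be consumed directly. `my_model` and the
# numeric values below are placeholders, not values taken from any config.
#
#   param_groups = get_default_optimizer_params(
#       my_model,
#       base_lr=1e-4,
#       weight_decay=0.05,
#       weight_decay_norm=0.0,
#       bias_lr_factor=1.0,
#       weight_decay_bias=0.0,
#   )
#   opt = torch.optim.AdamW(param_groups, lr=1e-4)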


def get_layer_id(module_name, num_layers):
    """
    Assign a layer id to a parameter according to its module name.
    Modified from BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
    """
    if module_name.split('.')[0] in [
            'video_embed', 'token_embed', 'prompt_embed', 'visual_embed',
            'cls_token',
    ]:
        return 0
    elif module_name.startswith('encoder'):
        return int(module_name.split('.')[2]) + 1
    elif module_name.startswith('predictor'):
        return num_layers
    else:
        raise NotImplementedError('please check this layer')
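
# Illustrative sketch of how the layer id feeds the layer-wise lr decay used in
# create_seperate_moe_param_groups below: embeddings (layer id 0) get the smallest
# scale and the predictor (layer id num_layers) gets scale 1.0. The module name and
# the 12-layer / 0.9-decay values are hypothetical.
#
#   num_layers = 12 + 1
#   layer_decay = 0.9
#   layer_scales = [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]
#   scale = layer_scales[get_layer_id('encoder.layers.3.attn', num_layers)]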


def create_seperate_moe_param_groups(
    model,
    base_lr: Optional[float] = None,
    weight_decay: Optional[float] = None,
    weight_decay_norm: Optional[float] = None,
    bias_lr_factor: Optional[float] = 1.0,
    wg_lr_facetor: Optional[float] = 1.0,
    weight_decay_bias: Optional[float] = None,
    weight_decay_embedding: Optional[float] = None,
    weight_decay_wg: Optional[float] = None,
    cfg: dict = None,
):
    """
    Create one optimizer param group per parameter, carrying its lr (scaled by
    layer-wise lr decay), weight decay and DeepSpeed-MoE flag.
    """
    try:
        from deepspeed.moe.utils import is_moe_param
    except ImportError:
        # fallback when deepspeed is not installed
        def is_moe_param(param: torch.Tensor) -> bool:
            if hasattr(param, "allreduce") and not param.allreduce:
                return True
            return False

    params: List[Dict[str, Any]] = []
    memo: Set[torch.nn.parameter.Parameter] = set()

    num_layers = cfg.MODEL.BERT.NUM_HIDDEN_LAYERS + 1
    layer_decay = cfg.SOLVER.LAYER_LR_DECAY
    layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))

    if weight_decay_bias is None:
        weight_decay_bias = weight_decay
    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        # NaiveSyncBatchNorm inherits from BatchNorm2d
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )

    no_decay_list = {}
    if hasattr(model, 'no_weight_decay'):
        no_decay_list = model.no_weight_decay()
    wg_list = {}
    if hasattr(model, 'expert_gate_group'):
        wg_list = model.expert_gate_group()

    for module_name, module in model.named_modules():
        no_decay = False
        if module_name in no_decay_list:
            no_decay = True
        is_wg_param = False
        for wg_name in wg_list:
            if wg_name in module_name:
                is_wg_param = True
                continue
        for module_param_name, value in module.named_parameters(recurse=False):
            # layer_id = get_layer_id(module_name, num_layers)
            this_scale = layer_scales[get_layer_id(
                module_name, num_layers)] if layer_decay < 1.0 else 1.0
            # if isinstance(module, torch.nn.Embedding):
            #     print(module_name, module_param_name)
            if not value.requires_grad:
                continue
            # Avoid duplicating parameters
            if value in memo:
                continue
            memo.add(value)

            schedule_params = {
                "lr": base_lr,
                "weight_decay": weight_decay,
                "moe": False,
            }
            if is_moe_param(value):
                schedule_params['moe'] = True
            if no_decay or (module_param_name in no_decay_list):
                schedule_params["weight_decay"] = 0.
            elif is_wg_param and isinstance(
                    module, torch.nn.Linear) and module_param_name != "bias":
                # only add linear weights in gate function
                schedule_params["lr"] = base_lr * wg_lr_facetor
                schedule_params["weight_decay"] = weight_decay_wg
            elif isinstance(module, torch.nn.Embedding):
                schedule_params['weight_decay'] = weight_decay_embedding
            elif isinstance(module, norm_module_types):
                if not cfg.SOLVER.WEIGHT_DECAY_NORMBIAS_WEIGHT and module_param_name == "bias":
                    # ln bias uses the same params as linear bias
                    schedule_params["lr"] = base_lr * bias_lr_factor
                    schedule_params['weight_decay'] = weight_decay_bias
                else:
                    schedule_params['weight_decay'] = weight_decay_norm
            elif module_param_name == "bias" or value.ndim == 1:
                schedule_params["lr"] = base_lr * bias_lr_factor
                schedule_params['weight_decay'] = weight_decay_bias

            params += [{
                "params": [value],
                "lr": max(schedule_params["lr"] * this_scale,
                          cfg.LR_SCHEDULER.get('MIN_LR', 1e-6)),
                "moe": schedule_params['moe'],
                "weight_decay": schedule_params["weight_decay"],
                "name": f'{module_name}.{module_param_name}'
            }]

    return params
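
# Illustrative sketch: this is the grouping build_optimizer() uses below. Each returned
# entry carries "lr", "weight_decay", a "moe" flag and the parameter "name". `my_model`
# and `my_cfg` are hypothetical placeholders for a constructed model and a fully
# populated config (MODEL.BERT.NUM_HIDDEN_LAYERS, SOLVER.LAYER_LR_DECAY, etc.).
#
#   groups = create_seperate_moe_param_groups(
#       my_model,
#       base_lr=my_cfg.SOLVER.BASE_LR,
#       weight_decay=my_cfg.SOLVER.WEIGHT_DECAY,
#       weight_decay_norm=0.0, weight_decay_bias=0.0,
#       weight_decay_embedding=0.0, weight_decay_wg=0.0,
#       cfg=my_cfg,
#   )
#   for g in groups[:5]:
#       print(g["name"], g["lr"], g["weight_decay"], g["moe"])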


def create_group_moe_param_groups(
    model,
    base_lr: Optional[float] = None,
    weight_decay: Optional[float] = None,
    weight_decay_norm: Optional[float] = None,
    bias_lr_factor: Optional[float] = 1.0,
    wg_lr_facetor: Optional[float] = 1.0,
    weight_decay_bias: Optional[float] = None,
    weight_decay_embedding: Optional[float] = None,
    weight_decay_wg: Optional[float] = None,
    cfg: dict = None,
):
    """
    Like create_seperate_moe_param_groups, but parameters that share the same
    (lr, weight_decay, moe) setting are merged into a single param group.
    """
    from deepspeed.moe.utils import is_moe_param

    # params: List[Dict[str, Any]] = []
    memo: Set[torch.nn.parameter.Parameter] = set()

    if weight_decay_bias is None:
        weight_decay_bias = weight_decay
    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )

    group_params_dict = {}

    no_decay_list = {}
    if hasattr(model, 'no_weight_decay'):
        no_decay_list = model.no_weight_decay()
    wg_list = {}
    if hasattr(model, 'expert_gate_group'):
        wg_list = model.expert_gate_group()

    for module_name, module in model.named_modules():
        no_decay = False
        if module_name in no_decay_list:
            no_decay = True
        is_wg_param = False
        for wg_name in wg_list:
            if wg_name in module_name:
                is_wg_param = True
                continue
        for module_param_name, value in module.named_parameters(recurse=False):
            if not value.requires_grad:
                continue
            # Avoid duplicating parameters
            if value in memo:
                continue
            memo.add(value)

            # default setting
            lr_of_this_param = base_lr
            wd_of_this_param = weight_decay
            moe_of_this_param = False
            if is_moe_param(value):
                moe_of_this_param = True
            if no_decay or (module_param_name in no_decay_list):
                wd_of_this_param = 0.
            elif is_wg_param and isinstance(
                    module, torch.nn.Linear) and module_param_name != "bias":
                # only add linear weights in gate function
                lr_of_this_param = base_lr * wg_lr_facetor
                wd_of_this_param = weight_decay_wg
            elif isinstance(module, torch.nn.Embedding):
                wd_of_this_param = weight_decay_embedding
            elif isinstance(module, norm_module_types):
                if not cfg.SOLVER.WEIGHT_DECAY_NORMBIAS_WEIGHT and module_param_name == "bias":
                    # ln bias uses the same params as linear bias
                    lr_of_this_param = base_lr * bias_lr_factor
                    wd_of_this_param = weight_decay_bias
                else:
                    wd_of_this_param = weight_decay_norm
            elif module_param_name == "bias":
                lr_of_this_param = base_lr * bias_lr_factor
                wd_of_this_param = weight_decay_bias

            param_group_name = f'lr_{lr_of_this_param}_wd_{wd_of_this_param}_moe_{moe_of_this_param}'
            if param_group_name not in group_params_dict:
                group_params_dict[param_group_name] = {
                    'params': [],
                    "lr": lr_of_this_param,
                    "weight_decay": wd_of_this_param,
                    'moe': moe_of_this_param,
                    'name': param_group_name,
                    'params_name': [],
                }
            group_params_dict[param_group_name]['params'].append(value)
            group_params_dict[param_group_name]['params_name'].append(
                f'{module_name}.{module_param_name}')

    valid_params_groups = list(group_params_dict.values())

    return valid_params_groups
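
# Illustrative note: unlike create_seperate_moe_param_groups, parameters that end up
# with the same (lr, weight_decay, moe) combination are merged into one group keyed by
# a string such as 'lr_0.0001_wd_0.05_moe_False', which keeps the number of optimizer
# param groups small. Hypothetical inspection (`my_model`, `my_cfg` are placeholders):
#
#   groups = create_group_moe_param_groups(my_model, base_lr=1e-4, weight_decay=0.05, cfg=my_cfg)
#   print([g["name"] for g in groups])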


def create_moe_param_groups(
    model,
    base_lr: Optional[float] = None,
    weight_decay: Optional[float] = None,
    weight_decay_norm: Optional[float] = None,
    bias_lr_factor: Optional[float] = 1.0,
    wg_lr_facetor: Optional[float] = 1.0,
    weight_decay_bias: Optional[float] = None,
    weight_decay_embedding: Optional[float] = None,
    weight_decay_wg: Optional[float] = None,
):
    """
    Group parameters into a fixed set of named buckets (weight decay / no decay /
    norm / bias / gate, each with a MoE counterpart) instead of per-parameter groups.
    """
    from deepspeed.moe.utils import is_moe_param

    if weight_decay_bias is None:
        weight_decay_bias = weight_decay
    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )
    if weight_decay_embedding == 0.0:
        norm_module_types = norm_module_types + (torch.nn.Embedding, )
    else:
        # if weight_decay_embedding is not 0.0, we set its weight_decay as normal weights
        # assert weight_decay_embedding == weight_decay
        pass

    params_with_weight_decay = {
        'params': [],
        'name': 'weight_decay_params',
        'params_name': [],
    }
    params_without_weight_decay = {
        'params': [],
        "weight_decay": 0.0,
        'name': 'without_weight_decay_params',
        'params_name': [],
    }
    bias_params = {
        'params': [],
        "lr": base_lr * bias_lr_factor,
        "weight_decay": weight_decay_bias,
        'name': 'bias_params',
        'params_name': [],
    }
    wg_params = {
        'params': [],
        "lr": base_lr * wg_lr_facetor,
        "weight_decay": weight_decay_wg,
        'name': 'wg_params',
        'params_name': [],
    }
    norm_params = {
        'params': [],
        "weight_decay": weight_decay_norm,
        'name': 'norm_params',
        'params_name': [],
    }
    moe_params_with_weight_decay = {
        'params': [],
        'moe': True,
        'name': 'weight_decay_moe_params',
        'params_name': [],
    }
    moe_params_without_weight_decay = {
        'params': [],
        "weight_decay": 0.0,
        'moe': True,
        'name': 'without_weight_decay_moe_params',
        'params_name': [],
    }
    moe_bias_params = {
        'params': [],
        "lr": base_lr * bias_lr_factor,
        "weight_decay": weight_decay_bias,
        'moe': True,
        'name': 'bias_moe_params',
        'params_name': [],
    }
    moe_norm_params = {
        'params': [],
        "weight_decay": weight_decay_norm,
        'moe': True,
        'name': 'norm_moe_params',
        'params_name': [],
    }

    params_groups = [
        params_with_weight_decay, params_without_weight_decay, norm_params,
        bias_params, wg_params, moe_params_with_weight_decay,
        moe_params_without_weight_decay, moe_norm_params, moe_bias_params
    ]

    no_decay_list = {}
    if hasattr(model, 'no_weight_decay'):
        no_decay_list = model.no_weight_decay()
    wg_list = {}
    if hasattr(model, 'expert_gate_group'):
        wg_list = model.expert_gate_group()

    memo: Set[torch.nn.parameter.Parameter] = set()

    for module_name, module in model.named_modules():
        no_decay = False
        if module_name in no_decay_list:
            no_decay = True
        is_wg_param = False
        for wg_name in wg_list:
            if wg_name in module_name:
                is_wg_param = True
                continue
        for module_param_name, value in module.named_parameters(recurse=False):
            if not value.requires_grad:
                continue
            # Avoid duplicating parameters
            if value in memo:
                continue
            memo.add(value)

            if is_moe_param(value):
                if no_decay or (module_param_name in no_decay_list):
                    moe_params_without_weight_decay['params'].append(value)
                elif isinstance(module, norm_module_types):
                    moe_norm_params['params'].append(value)
                elif module_param_name == "bias":
                    moe_bias_params['params'].append(value)
                else:
                    moe_params_with_weight_decay['params'].append(value)
            else:
                if no_decay or (module_param_name in no_decay_list):
                    params_without_weight_decay['params'].append(value)
                    params_without_weight_decay['params_name'].append(
                        f'{module_name}.{module_param_name}')
                elif is_wg_param and isinstance(module, torch.nn.Linear) and module_param_name != "bias":
                    # only add linear weights in gate function
                    wg_params['params'].append(value)
                    wg_params['params_name'].append(
                        f'{module_name}.{module_param_name}')
                elif isinstance(module, norm_module_types):
                    norm_params['params'].append(value)
                    norm_params['params_name'].append(
                        f'{module_name}.{module_param_name}')
                elif module_param_name == "bias":
                    bias_params['params'].append(value)
                    bias_params['params_name'].append(
                        f'{module_name}.{module_param_name}')
                else:
                    params_with_weight_decay['params'].append(value)
                    params_with_weight_decay['params_name'].append(
                        f'{module_name}.{module_param_name}')

    valid_params_groups = [
        group for group in params_groups if len(group['params']) > 0
    ]

    return valid_params_groups


def _generate_optimizer_class_with_gradient_clipping(
    optimizer: Type[torch.optim.Optimizer],
    *,
    per_param_clipper: Optional[_GradientClipper] = None,
    global_clipper: Optional[_GradientClipper] = None,
) -> Type[torch.optim.Optimizer]:
    """
    Dynamically creates a new type that inherits the type of a given instance
    and overrides the `step` method to add gradient clipping.
    """
    assert (
        per_param_clipper is None or global_clipper is None
    ), "Not allowed to use both per-parameter clipping and global clipping"

    def optimizer_wgc_step(self, closure=None):
        if per_param_clipper is not None:
            for group in self.param_groups:
                for p in group["params"]:
                    per_param_clipper(p)
        else:
            # global clipper for future use with detr
            # (https://github.com/facebookresearch/detr/pull/287)
            all_params = itertools.chain(*[g["params"] for g in self.param_groups])
            global_clipper(all_params)
        super(type(self), self).step(closure)

    OptimizerWithGradientClip = type(
        optimizer.__name__ + "WithGradientClip",
        (optimizer,),
        {"step": optimizer_wgc_step},
    )
    return OptimizerWithGradientClip


def maybe_add_gradient_clipping(
    cfg: CfgNode, optimizer: Type[torch.optim.Optimizer]
) -> Type[torch.optim.Optimizer]:
    """
    If gradient clipping is enabled through config options, wraps the existing
    optimizer type to become a new dynamically created class OptimizerWithGradientClip
    that inherits the given optimizer and overrides the `step` method to
    include gradient clipping.

    Args:
        cfg: CfgNode, configuration options
        optimizer: type. A subclass of torch.optim.Optimizer

    Return:
        type: either the input `optimizer` (if gradient clipping is disabled), or
            a subclass of it with gradient clipping included in the `step` method.
    """
    if cfg.SOLVER.GRAD_CLIP <= 0:
        return optimizer
    if isinstance(optimizer, torch.optim.Optimizer):
        optimizer_type = type(optimizer)
    else:
        assert issubclass(optimizer, torch.optim.Optimizer), optimizer
        optimizer_type = optimizer

    per_param_clipper, global_clipper = _create_gradient_clipper(cfg)
    OptimizerWithGradientClip = _generate_optimizer_class_with_gradient_clipping(
        optimizer_type, per_param_clipper=per_param_clipper, global_clipper=global_clipper
    )
    if isinstance(optimizer, torch.optim.Optimizer):
        optimizer.__class__ = OptimizerWithGradientClip  # a bit hacky, not recommended
        return optimizer
    else:
        return OptimizerWithGradientClip
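
# Illustrative sketch: maybe_add_gradient_clipping can wrap either an optimizer class
# or an instance. Assuming a cfg with SOLVER.GRAD_CLIP > 0 and SOLVER.GRAD_CLIP_TYPE
# set to 'value' or 'norm' (plus SOLVER.NORM_TYPE for the 'norm' case), a hypothetical
# use would be:
#
#   AdamWWithClip = maybe_add_gradient_clipping(cfg, torch.optim.AdamW)
#   optimizer = AdamWWithClip(params, lr=cfg.SOLVER.BASE_LR)
#   # each optimizer.step() now clips gradients before updating parameters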


def build_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer:
    """
    Build an optimizer from config.
    """
    # params = get_default_optimizer_params(
    #     model,
    #     base_lr=cfg.SOLVER.BASE_LR,
    #     weight_decay=cfg.SOLVER.WEIGHT_DECAY,
    #     weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
    #     bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
    #     weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
    # )
    params = create_seperate_moe_param_groups(
        model,
        base_lr=cfg.SOLVER.BASE_LR,
        weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
        bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
        wg_lr_facetor=cfg.SOLVER.WG_LR_FACTOR,
        weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
        weight_decay_embedding=cfg.SOLVER.WEIGHT_DECAY_EMBEDDING,
        weight_decay_wg=cfg.SOLVER.WEIGHT_DECAY_WG,
        cfg=cfg,
    )

    if cfg.SOLVER.NAME == 'LAMB':
        from uniperceiver.optim import LAMB
        optimizer = LAMB(
            params,
            lr=cfg.SOLVER.BASE_LR,
            betas=cfg.SOLVER.BETAS,
            eps=cfg.SOLVER.EPS,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        )
    else:
        optimizer = torch.optim.AdamW(
            params,
            lr=cfg.SOLVER.BASE_LR,
            betas=cfg.SOLVER.BETAS,
            eps=cfg.SOLVER.EPS,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        )

    # optimizer = SOLVER_REGISTRY.get(cfg.SOLVER.NAME)
    # return maybe_add_gradient_clipping(cfg, optimizer)(cfg, params)
    return optimizer
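
# Illustrative sketch of the entry point above: given a fully populated cfg
# (SOLVER.NAME, SOLVER.BASE_LR, SOLVER.BETAS, SOLVER.EPS, the SOLVER.WEIGHT_DECAY_*
# options and LR_SCHEDULER.MIN_LR) and a constructed model, a trainer would call:
#
#   optimizer = build_optimizer(cfg, model)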