import json
import os

import torch
import torch.nn as nn
import transformers
from transformers import Trainer, logging
from transformers.trainer import is_sagemaker_mp_enabled

logger = logging.get_logger(__name__)


def get_num_layer_for_vit_and_qllama(var_name, vit_num_max_layer, llama_num_max_layer):
    """Map a parameter name to a layer index used for layer-wise learning-rate decay."""
    if var_name.startswith('internvl.'):
        var_name = var_name[len('internvl.'):]
    if var_name in ('query_tokens', 'logit_scale'):
        return 0
    if var_name.startswith('clip_projector.'):
        return vit_num_max_layer
    if var_name.startswith('clip_projector2.') or var_name.startswith('itm_head.') or \
            var_name == 'text_projection':
        return llama_num_max_layer
    if var_name.startswith('vision_model.'):
        if 'embeddings.' in var_name:
            return 0
        if 'layers.' in var_name:
            var_name = var_name.split('layers.')[-1]
            layer_id = int(var_name.split('.')[0])
            return layer_id + 1
    if var_name.startswith('qllama.'):
        if 'embed_tokens' in var_name:
            return 0
        if 'layers.' in var_name:
            var_name = var_name.split('layers.')[-1]
            layer_id = int(var_name.split('.')[0])
            return layer_id + 1
        else:
            return llama_num_max_layer
    return 0


def param_classification(name):
    """Classify a parameter as belonging to the ViT branch, the QLLaMA branch, or neither."""
    if name.startswith('internvl.'):
        name = name[len('internvl.'):]
    if name in ['query_tokens', 'text_projection', 'logit_scale']:
        return 'qllama'
    elif name.startswith('vision_model.'):
        return 'vit'
    elif name.startswith('qllama.'):
        return 'qllama'
    elif name.startswith('clip_projector.'):
        return 'vit'
    elif name.startswith('clip_projector2.'):
        return 'qllama'
    elif name.startswith('itm_head.'):
        return 'qllama'
    else:
        return 'other'


def create_optimizer(self):
    """
    Setup the optimizer with layer-wise learning-rate decay for the ViT and QLLaMA branches.

    We provide a reasonable default that works well. If you want to use something else, you can
    pass a tuple in the Trainer's init through `optimizers`, or subclass and override this method
    in a subclass.
    """
    opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
    parameter_groups = {}
    try:
        # stage2 model: the config lives directly on the model
        vit_num_layers = opt_model.config.vision_config.num_hidden_layers + 2
        qllama_num_layers = opt_model.config.qllama_config.num_hidden_layers + 2
    except AttributeError:
        # stage3 model: the InternVL model is wrapped in an `internvl` submodule
        vit_num_layers = opt_model.internvl.config.vision_config.num_hidden_layers + 2
        qllama_num_layers = opt_model.internvl.config.qllama_config.num_hidden_layers + 2
    print('vit_num_layers:', vit_num_layers)
    print('qllama_num_layers:', qllama_num_layers)

    # Layer-wise decay rates and the QLLaMA LR scale are read from environment variables.
    vit_layer_decay_rate = float(os.getenv('VIT_LAYER_DECAY_RATE', 1.0))
    qllama_layer_decay_rate = float(os.getenv('QLLAMA_LAYER_DECAY_RATE', 1.0))
    qllama_lr_scale = float(os.getenv('QLLAMA_LR_SCALE', 1.0))
    print('vit_layer_decay_rate:', vit_layer_decay_rate)
    print('qllama_layer_decay_rate:', qllama_layer_decay_rate)
    print('qllama_lr_scale:', qllama_lr_scale)

    for name, param in opt_model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if len(param.shape) == 1 or name.endswith('.bias'):
            # Biases and 1-D parameters (e.g. norm weights) are excluded from weight decay.
            group_name = 'no_decay'
            this_weight_decay = 0.
        else:
            group_name = 'decay'
            this_weight_decay = self.args.weight_decay

        cls = param_classification(name)
        layer_id = get_num_layer_for_vit_and_qllama(name, vit_num_layers, qllama_num_layers)
        group_name = '%s_layer_%d_%s' % (cls, layer_id, group_name)
        if group_name not in parameter_groups:
            if cls == 'vit':
                scale = vit_layer_decay_rate ** (vit_num_layers - layer_id - 1)
            elif cls == 'qllama':
                scale = qllama_layer_decay_rate ** (qllama_num_layers - layer_id - 1)
                scale = scale * qllama_lr_scale
            else:
                scale = 1.0
            scale = min(1.0, scale)
            parameter_groups[group_name] = {
                'weight_decay': this_weight_decay,
                'params': [],
                'param_names': [],
                'lr_scale': scale,
                'group_name': group_name,
                'lr': scale * self.args.learning_rate,
            }
        parameter_groups[group_name]['params'].append(param)
        parameter_groups[group_name]['param_names'].append(name)

    rank = torch.distributed.get_rank()
    if rank == 0:
        to_display = {}
        for key in parameter_groups:
            to_display[key] = {
                'param_names': parameter_groups[key]['param_names'],
                'lr_scale': parameter_groups[key]['lr_scale'],
                'lr': parameter_groups[key]['lr'],
                'weight_decay': parameter_groups[key]['weight_decay'],
            }
        print('Param groups = %s' % json.dumps(to_display, indent=2))

    optimizer_grouped_parameters = list(parameter_groups.values())
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

    self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
    if optimizer_cls.__name__ == 'Adam8bit':
        import bitsandbytes

        manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

        skipped = 0
        for module in opt_model.modules():
            if isinstance(module, nn.Embedding):
                skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                logger.info(f'skipped {module}: {skipped / 2 ** 20}M params')
                manager.register_module_override(module, 'weight', {'optim_bits': 32})
                logger.debug(f'bitsandbytes: will optimize {module} in fp32')
        logger.info(f'skipped: {skipped / 2 ** 20}M params')

    if is_sagemaker_mp_enabled():
        import smdistributed.modelparallel.torch as smp

        self.optimizer = smp.DistributedOptimizer(self.optimizer)

    return self.optimizer


def create_optimizer_custom(self):
    """
    Setup the optimizer.

    We provide a reasonable default that works well. If you want to use something else, you can
    pass a tuple in the Trainer's init through `optimizers`, or subclass and override this method
    in a subclass.
    """
    opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
    parameter_groups = {}

    for name, param in opt_model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if len(param.shape) == 1 or name.endswith('.bias'):
            group_name = 'no_decay'
            this_weight_decay = 0.
        else:
            group_name = 'decay'
            this_weight_decay = self.args.weight_decay

        # Group parameters into 'modify' (ocr_mlp / upsample), 'vit' (vision_model) and 'base'.
        if 'ocr_mlp' in name or 'upsample' in name:
            group_name = '%s_%s' % ('modify', group_name)
        elif 'vision_model' in name:
            group_name = '%s_%s' % ('vit', group_name)
        else:
            group_name = '%s_%s' % ('base', group_name)

        if group_name not in parameter_groups:
            if 'ocr_mlp' in name or 'upsample' in name:
                scale = 1.0
            elif 'vision_model' in name:
                # vision_model parameters get 5% of the base learning rate
                scale = 0.05
            else:
                scale = 1.0
            parameter_groups[group_name] = {
                'weight_decay': this_weight_decay,
                'params': [],
                'param_names': [],
                'lr_scale': scale,
                'group_name': group_name,
                'lr': scale * self.args.learning_rate,
            }
        parameter_groups[group_name]['params'].append(param)
        parameter_groups[group_name]['param_names'].append(name)

    rank = torch.distributed.get_rank()
    if rank == 0:
        to_display = {}
        for key in parameter_groups:
            to_display[key] = {
                'param_names': parameter_groups[key]['param_names'],
                'lr_scale': parameter_groups[key]['lr_scale'],
                'lr': parameter_groups[key]['lr'],
                'weight_decay': parameter_groups[key]['weight_decay'],
            }
        print('Param groups = %s' % json.dumps(to_display, indent=2))

    optimizer_grouped_parameters = list(parameter_groups.values())
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

    self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
    if optimizer_cls.__name__ == 'Adam8bit':
        import bitsandbytes

        manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

        skipped = 0
        for module in opt_model.modules():
            if isinstance(module, nn.Embedding):
                skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                logger.info(f'skipped {module}: {skipped / 2 ** 20}M params')
                manager.register_module_override(module, 'weight', {'optim_bits': 32})
                logger.debug(f'bitsandbytes: will optimize {module} in fp32')
        logger.info(f'skipped: {skipped / 2 ** 20}M params')

    if is_sagemaker_mp_enabled():
        import smdistributed.modelparallel.torch as smp

        self.optimizer = smp.DistributedOptimizer(self.optimizer)

    return self.optimizer


def replace_create_optimizer():
    """Monkey-patch `transformers.Trainer.create_optimizer` with the custom implementation."""
    print('Replace original create_optimizer with custom create_optimizer')
    # transformers.Trainer.create_optimizer = create_optimizer
    transformers.Trainer.create_optimizer = create_optimizer_custom
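

# Usage sketch (not part of the original module): a training script would call
# `replace_create_optimizer()` once, before constructing the Trainer, so that the
# patched `create_optimizer` is in place when the Trainer builds its optimizer.
# If the layer-wise-decay variant (`create_optimizer`) is wanted instead, switch the
# commented-out assignment above and set VIT_LAYER_DECAY_RATE, QLLAMA_LAYER_DECAY_RATE
# and QLLAMA_LR_SCALE in the environment. The quick check below only verifies that the
# patch is applied; the model/dataset wiring of a real training run is omitted.
if __name__ == '__main__':
    replace_create_optimizer()
    assert transformers.Trainer.create_optimizer is create_optimizer_custom
    print('Trainer.create_optimizer is now the custom implementation')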