import json
import os
import torch
import torch.nn as nn
import transformers
from transformers import Trainer, logging
from transformers.trainer import is_sagemaker_mp_enabled
logger = logging.get_logger(__name__)


def get_num_layer_for_vit_and_qllama(var_name, vit_num_max_layer, llama_num_max_layer):
    """Map a parameter name to the layer index used for layer-wise LR decay."""
    if var_name.startswith('internvl.'):
        var_name = var_name[len('internvl.'):]
    if var_name in ('query_tokens', 'logit_scale'):
        return 0
    if var_name.startswith('clip_projector.'):
        return vit_num_max_layer
    if var_name.startswith('clip_projector2.') or var_name.startswith('itm_head.') or \
            var_name == 'text_projection':
        return llama_num_max_layer
    if var_name.startswith('vision_model.'):
        if 'embeddings.' in var_name:
            return 0
        if 'layers.' in var_name:
            var_name = var_name.split('layers.')[-1]
            layer_id = int(var_name.split('.')[0])
            return layer_id + 1
    if var_name.startswith('qllama.'):
        if 'embed_tokens' in var_name:
            return 0
        if 'layers.' in var_name:
            var_name = var_name.split('layers.')[-1]
            layer_id = int(var_name.split('.')[0])
            return layer_id + 1
        else:
            return llama_num_max_layer
    return 0
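
# Worked example (illustrative parameter names, assuming InternVL-style module naming;
# not an exhaustive list of what the model actually exposes):
#   'internvl.vision_model.encoder.layers.5.mlp.fc1.weight'   -> layer id 6
#   'internvl.qllama.model.layers.10.self_attn.q_proj.weight' -> layer id 11
#   'internvl.clip_projector.weight'                          -> vit_num_max_layer (top of the ViT stack)
#   'internvl.query_tokens'                                   -> layer id 0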


def param_classification(name):
    """Classify a parameter as belonging to the ViT branch, the QLLaMA branch, or 'other'."""
    if name.startswith('internvl.'):
        name = name[len('internvl.'):]
    if name in ['query_tokens', 'text_projection', 'logit_scale']:
        return 'qllama'
    elif name.startswith('vision_model.'):
        return 'vit'
    elif name.startswith('qllama.'):
        return 'qllama'
    elif name.startswith('clip_projector.'):
        return 'vit'
    elif name.startswith('clip_projector2.'):
        return 'qllama'
    elif name.startswith('itm_head.'):
        return 'qllama'
    else:
        return 'other'
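
# For example (illustrative names): 'vision_model.embeddings.patch_embedding.weight' and
# 'clip_projector.weight' fall into 'vit', while 'qllama.model.embed_tokens.weight',
# 'text_projection' and 'itm_head.weight' fall into 'qllama'; anything unmatched is 'other'.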


def create_optimizer(self):
    """
    Setup the optimizer.

    We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
    Trainer's init through `optimizers`, or subclass and override this method in a subclass.
    """
    opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

    parameter_groups = {}
    try:  # stage-2 model: configs live directly on the model
        vit_num_layers = opt_model.config.vision_config.num_hidden_layers + 2
        qllama_num_layers = opt_model.config.qllama_config.num_hidden_layers + 2
    except AttributeError:  # stage-3 model: configs live on the wrapped `internvl` module
        vit_num_layers = opt_model.internvl.config.vision_config.num_hidden_layers + 2
        qllama_num_layers = opt_model.internvl.config.qllama_config.num_hidden_layers + 2
    print('vit_num_layers:', vit_num_layers)
    print('qllama_num_layers:', qllama_num_layers)

    vit_layer_decay_rate = float(os.getenv('VIT_LAYER_DECAY_RATE', 1.0))
    qllama_layer_decay_rate = float(os.getenv('QLLAMA_LAYER_DECAY_RATE', 1.0))
    qllama_lr_scale = float(os.getenv('QLLAMA_LR_SCALE', 1.0))
    print('vit_layer_decay_rate:', vit_layer_decay_rate)
    print('qllama_layer_decay_rate:', qllama_layer_decay_rate)
    print('qllama_lr_scale:', qllama_lr_scale)

    for name, param in opt_model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        # 1-D tensors and biases are excluded from weight decay.
        if len(param.shape) == 1 or name.endswith('.bias'):
            group_name = 'no_decay'
            this_weight_decay = 0.
        else:
            group_name = 'decay'
            this_weight_decay = self.args.weight_decay
        cls = param_classification(name)
        layer_id = get_num_layer_for_vit_and_qllama(name, vit_num_layers, qllama_num_layers)
        group_name = '%s_layer_%d_%s' % (cls, layer_id, group_name)
        if group_name not in parameter_groups:
            # Layer-wise LR decay: deeper layers keep a larger scale, shallower layers decay towards zero.
            if cls == 'vit':
                scale = vit_layer_decay_rate ** (vit_num_layers - layer_id - 1)
            elif cls == 'qllama':
                scale = qllama_layer_decay_rate ** (qllama_num_layers - layer_id - 1)
                scale = scale * qllama_lr_scale
            else:
                scale = 1.0
            scale = min(1.0, scale)
            parameter_groups[group_name] = {
                'weight_decay': this_weight_decay,
                'params': [],
                'param_names': [],
                'lr_scale': scale,
                'group_name': group_name,
                'lr': scale * self.args.learning_rate,
            }
        parameter_groups[group_name]['params'].append(param)
        parameter_groups[group_name]['param_names'].append(name)

    rank = torch.distributed.get_rank()
    if rank == 0:
        to_display = {}
        for key in parameter_groups:
            to_display[key] = {
                'param_names': parameter_groups[key]['param_names'],
                'lr_scale': parameter_groups[key]['lr_scale'],
                'lr': parameter_groups[key]['lr'],
                'weight_decay': parameter_groups[key]['weight_decay'],
            }
        print('Param groups = %s' % json.dumps(to_display, indent=2))

    optimizer_grouped_parameters = list(parameter_groups.values())
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
    self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
    if optimizer_cls.__name__ == 'Adam8bit':
        import bitsandbytes

        manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
        skipped = 0
        for module in opt_model.modules():
            if isinstance(module, nn.Embedding):
                skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                logger.info(f'skipped {module}: {skipped / 2 ** 20}M params')
                manager.register_module_override(module, 'weight', {'optim_bits': 32})
                logger.debug(f'bitsandbytes: will optimize {module} in fp32')
        logger.info(f'skipped: {skipped / 2 ** 20}M params')
    if is_sagemaker_mp_enabled():
        import smdistributed.modelparallel.torch as smp
        self.optimizer = smp.DistributedOptimizer(self.optimizer)
    return self.optimizer
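
# LR-scale sketch for `create_optimizer` (assumed values, not defaults shipped with this file):
# with VIT_LAYER_DECAY_RATE=0.9 and vit_num_layers=26, a ViT parameter at layer_id=0
# (patch embeddings) gets scale 0.9 ** 25 ≈ 0.072, layer_id=13 gets 0.9 ** 12 ≈ 0.28,
# and the top-level clip_projector (layer_id=26) gets min(1.0, 0.9 ** -1) = 1.0,
# so shallower layers train with proportionally smaller learning rates.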


def create_optimizer_custom(self):
    """
    Setup the optimizer.

    We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
    Trainer's init through `optimizers`, or subclass and override this method in a subclass.
    """
    opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

    parameter_groups = {}
    for name, param in opt_model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        # 1-D tensors and biases are excluded from weight decay.
        if len(param.shape) == 1 or name.endswith('.bias'):
            group_name = 'no_decay'
            this_weight_decay = 0.
        else:
            group_name = 'decay'
            this_weight_decay = self.args.weight_decay
        # Group by module type: newly added OCR/upsample modules, the ViT backbone, and everything else.
        if 'ocr_mlp' in name or 'upsample' in name:
            group_name = '%s_%s' % ('modify', group_name)
        elif 'vision_model' in name:
            group_name = '%s_%s' % ('vit', group_name)
        else:
            group_name = '%s_%s' % ('base', group_name)
        if group_name not in parameter_groups:
            if 'ocr_mlp' in name or 'upsample' in name:
                scale = 1.0
            elif 'vision_model' in name:
                scale = 0.05
            else:
                scale = 1.0
            parameter_groups[group_name] = {
                'weight_decay': this_weight_decay,
                'params': [],
                'param_names': [],
                'lr_scale': scale,
                'group_name': group_name,
                'lr': scale * self.args.learning_rate,
            }
        parameter_groups[group_name]['params'].append(param)
        parameter_groups[group_name]['param_names'].append(name)

    rank = torch.distributed.get_rank()
    if rank == 0:
        to_display = {}
        for key in parameter_groups:
            to_display[key] = {
                'param_names': parameter_groups[key]['param_names'],
                'lr_scale': parameter_groups[key]['lr_scale'],
                'lr': parameter_groups[key]['lr'],
                'weight_decay': parameter_groups[key]['weight_decay'],
            }
        print('Param groups = %s' % json.dumps(to_display, indent=2))

    optimizer_grouped_parameters = list(parameter_groups.values())
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
    self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
    if optimizer_cls.__name__ == 'Adam8bit':
        import bitsandbytes

        manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
        skipped = 0
        for module in opt_model.modules():
            if isinstance(module, nn.Embedding):
                skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                logger.info(f'skipped {module}: {skipped / 2 ** 20}M params')
                manager.register_module_override(module, 'weight', {'optim_bits': 32})
                logger.debug(f'bitsandbytes: will optimize {module} in fp32')
        logger.info(f'skipped: {skipped / 2 ** 20}M params')
    if is_sagemaker_mp_enabled():
        import smdistributed.modelparallel.torch as smp
        self.optimizer = smp.DistributedOptimizer(self.optimizer)
    return self.optimizer
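
# Grouping sketch for `create_optimizer_custom` (illustrative names): a parameter such as
# 'ocr_mlp.0.weight' or 'upsample.conv.weight' lands in a 'modify_*' group at full LR
# (lr_scale 1.0), 'vision_model.encoder.layers.0.attn.qkv.weight' lands in a 'vit_*' group
# at 5% of the base LR (lr_scale 0.05), and everything else trains at the base LR.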


def replace_create_optimizer():
    print('Replace original create_optimizer with custom create_optimizer')
    # transformers.Trainer.create_optimizer = create_optimizer
    transformers.Trainer.create_optimizer = create_optimizer_custom
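
# Usage sketch (assumption: a standard `transformers.Trainer` training script; the module
# and variable names below are placeholders, not part of this file):
#
#   from trainer_monkey_patch import replace_create_optimizer
#   replace_create_optimizer()  # patch before trainer.train() builds the optimizer
#   trainer = transformers.Trainer(model=model, args=training_args, train_dataset=train_ds)
#   trainer.train()

if __name__ == '__main__':
    # Minimal sanity check of the two pure helpers above; the parameter names are
    # illustrative InternVL-style names and the layer counts (26/34) are assumptions.
    for example_name in [
        'internvl.vision_model.encoder.layers.5.mlp.fc1.weight',
        'internvl.qllama.model.layers.10.self_attn.q_proj.weight',
        'internvl.clip_projector.weight',
        'internvl.query_tokens',
    ]:
        print(example_name,
              param_classification(example_name),
              get_num_layer_for_vit_and_qllama(example_name, 26, 34))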