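"""Custom optimizer construction for InternVL-style training.

`create_optimizer` builds parameter groups with layer-wise learning-rate decay for the
ViT and QLLaMA branches, `create_optimizer_custom` uses fixed per-group LR scales
(ocr_mlp / upsample, vision_model, everything else), and `replace_create_optimizer`
monkey-patches `transformers.Trainer.create_optimizer` to use the custom version.
"""
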
import json
import os
import torch
import torch.nn as nn
import transformers
from transformers import Trainer, logging
from transformers.trainer import is_sagemaker_mp_enabled
logger = logging.get_logger(__name__)


def get_num_layer_for_vit_and_qllama(var_name, vit_num_max_layer, llama_num_max_layer):
    if var_name.startswith('internvl.'):
        var_name = var_name[len('internvl.'):]

    if var_name in ('query_tokens', 'logit_scale',):
        return 0
    if var_name.startswith('clip_projector.'):
        return vit_num_max_layer
    if var_name.startswith('clip_projector2.') or var_name.startswith('itm_head.') or \
            var_name == 'text_projection':
        return llama_num_max_layer
    if var_name.startswith('vision_model.'):
        if 'embeddings.' in var_name:
            return 0
        if 'layers.' in var_name:
            var_name = var_name.split('layers.')[-1]
            layer_id = int(var_name.split('.')[0])
            return layer_id + 1
    if var_name.startswith('qllama.'):
        if 'embed_tokens' in var_name:
            return 0
        if 'layers.' in var_name:
            var_name = var_name.split('layers.')[-1]
            layer_id = int(var_name.split('.')[0])
            return layer_id + 1
        else:
            return llama_num_max_layer
    return 0
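
# Illustrative mapping (the parameter names below are hypothetical examples):
#   get_num_layer_for_vit_and_qllama('vision_model.embeddings.patch_embedding.weight', 26, 34) -> 0
#   get_num_layer_for_vit_and_qllama('vision_model.encoder.layers.3.mlp.fc1.weight', 26, 34)   -> 4
#   get_num_layer_for_vit_and_qllama('qllama.model.layers.10.self_attn.q_proj.weight', 26, 34) -> 11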


def param_classification(name):
    if name.startswith('internvl.'):
        name = name[len('internvl.'):]

    if name in ['query_tokens', 'text_projection', 'logit_scale']:
        return 'qllama'
    elif name.startswith('vision_model.'):
        return 'vit'
    elif name.startswith('qllama.'):
        return 'qllama'
    elif name.startswith('clip_projector.'):
        return 'vit'
    elif name.startswith('clip_projector2.'):
        return 'qllama'
    elif name.startswith('itm_head.'):
        return 'qllama'
    else:
        return 'other'
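
# Illustrative classification (hypothetical parameter names):
#   param_classification('internvl.vision_model.encoder.layers.0.norm1.weight') -> 'vit'
#   param_classification('internvl.qllama.model.embed_tokens.weight')           -> 'qllama'
#   param_classification('some_other_module.weight')                            -> 'other'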


def create_optimizer(self):
    """
    Set up the optimizer.

    This variant builds parameter groups with layer-wise learning-rate decay for the ViT and
    QLLaMA branches. If you want to use something else, you can pass a tuple in the Trainer's
    init through `optimizers`, or subclass and override this method in a subclass.
    """
    opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
    parameter_groups = {}
    # Layer buckets run from 0 (embeddings) up to num_hidden_layers + 2 (projector / head).
    try:  # stage2 model: configs hang directly off the model
        vit_num_layers = opt_model.config.vision_config.num_hidden_layers + 2
        qllama_num_layers = opt_model.config.qllama_config.num_hidden_layers + 2
    except AttributeError:  # stage3 model: configs live under the `internvl` submodule
        vit_num_layers = opt_model.internvl.config.vision_config.num_hidden_layers + 2
        qllama_num_layers = opt_model.internvl.config.qllama_config.num_hidden_layers + 2
    print('vit_num_layers:', vit_num_layers)
    print('qllama_num_layers:', qllama_num_layers)

    # Layer-wise LR decay rates and the QLLaMA LR scale are configurable via environment variables.
    vit_layer_decay_rate = float(os.getenv('VIT_LAYER_DECAY_RATE', 1.0))
    qllama_layer_decay_rate = float(os.getenv('QLLAMA_LAYER_DECAY_RATE', 1.0))
    qllama_lr_scale = float(os.getenv('QLLAMA_LR_SCALE', 1.0))
    print('vit_layer_decay_rate:', vit_layer_decay_rate)
    print('qllama_layer_decay_rate:', qllama_layer_decay_rate)
    print('qllama_lr_scale:', qllama_lr_scale)

    for name, param in opt_model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        # Biases and 1-D parameters (norm scales) get no weight decay.
        if len(param.shape) == 1 or name.endswith('.bias'):
            group_name = 'no_decay'
            this_weight_decay = 0.
        else:
            group_name = 'decay'
            this_weight_decay = self.args.weight_decay

        cls = param_classification(name)
        layer_id = get_num_layer_for_vit_and_qllama(name, vit_num_layers, qllama_num_layers)
        group_name = '%s_layer_%d_%s' % (cls, layer_id, group_name)
        if group_name not in parameter_groups:
            # Earlier (lower) layers get geometrically smaller LR (for decay rates < 1); clamp at 1.0.
            if cls == 'vit':
                scale = vit_layer_decay_rate ** (vit_num_layers - layer_id - 1)
            elif cls == 'qllama':
                scale = qllama_layer_decay_rate ** (qllama_num_layers - layer_id - 1)
                scale = scale * qllama_lr_scale
            else:
                scale = 1.0
            scale = min(1.0, scale)
            parameter_groups[group_name] = {
                'weight_decay': this_weight_decay,
                'params': [],
                'param_names': [],
                'lr_scale': scale,
                'group_name': group_name,
                'lr': scale * self.args.learning_rate,
            }
        parameter_groups[group_name]['params'].append(param)
        parameter_groups[group_name]['param_names'].append(name)

    rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
    if rank == 0:
        to_display = {}
        for key in parameter_groups:
            to_display[key] = {
                'param_names': parameter_groups[key]['param_names'],
                'lr_scale': parameter_groups[key]['lr_scale'],
                'lr': parameter_groups[key]['lr'],
                'weight_decay': parameter_groups[key]['weight_decay'],
            }
        print('Param groups = %s' % json.dumps(to_display, indent=2))

    optimizer_grouped_parameters = list(parameter_groups.values())
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

    self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
    if optimizer_cls.__name__ == 'Adam8bit':
        import bitsandbytes

        manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

        skipped = 0
        for module in opt_model.modules():
            if isinstance(module, nn.Embedding):
                skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                logger.info(f'skipped {module}: {skipped / 2 ** 20}M params')
                # Keep embedding weights in 32-bit optimizer state when using 8-bit Adam.
                manager.register_module_override(module, 'weight', {'optim_bits': 32})
                logger.debug(f'bitsandbytes: will optimize {module} in fp32')
        logger.info(f'skipped: {skipped / 2 ** 20}M params')

    if is_sagemaker_mp_enabled():
        import smdistributed.modelparallel.torch as smp
        self.optimizer = smp.DistributedOptimizer(self.optimizer)

    return self.optimizer
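
# The layer decay rates and the QLLaMA LR scale above are read from the environment, e.g.
# (hypothetical launch command, values chosen for illustration only):
#   VIT_LAYER_DECAY_RATE=0.9 QLLAMA_LAYER_DECAY_RATE=0.95 QLLAMA_LR_SCALE=0.5 torchrun train.py ...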


def create_optimizer_custom(self):
    """
    Set up the optimizer.

    Builds three parameter groups: 'modify' (ocr_mlp / upsample, full LR), 'vit' (vision tower,
    LR scaled by 0.05) and 'base' (everything else, full LR), each split into decay / no_decay.
    If you want to use something else, you can pass a tuple in the Trainer's init through
    `optimizers`, or subclass and override this method in a subclass.
    """
    opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
    parameter_groups = {}

    for name, param in opt_model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        # Biases and 1-D parameters (norm scales) get no weight decay.
        if len(param.shape) == 1 or name.endswith('.bias'):
            group_name = 'no_decay'
            this_weight_decay = 0.
        else:
            group_name = 'decay'
            this_weight_decay = self.args.weight_decay

        if 'ocr_mlp' in name or 'upsample' in name:
            group_name = '%s_%s' % ('modify', group_name)
        elif 'vision_model' in name:
            group_name = '%s_%s' % ('vit', group_name)
        else:
            group_name = '%s_%s' % ('base', group_name)

        if group_name not in parameter_groups:
            if 'ocr_mlp' in name or 'upsample' in name:
                scale = 1.0
            elif 'vision_model' in name:
                scale = 0.05
            else:
                scale = 1.0
            parameter_groups[group_name] = {
                'weight_decay': this_weight_decay,
                'params': [],
                'param_names': [],
                'lr_scale': scale,
                'group_name': group_name,
                'lr': scale * self.args.learning_rate,
            }
        parameter_groups[group_name]['params'].append(param)
        parameter_groups[group_name]['param_names'].append(name)

    rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
    if rank == 0:
        to_display = {}
        for key in parameter_groups:
            to_display[key] = {
                'param_names': parameter_groups[key]['param_names'],
                'lr_scale': parameter_groups[key]['lr_scale'],
                'lr': parameter_groups[key]['lr'],
                'weight_decay': parameter_groups[key]['weight_decay'],
            }
        print('Param groups = %s' % json.dumps(to_display, indent=2))

    optimizer_grouped_parameters = list(parameter_groups.values())
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

    self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
    if optimizer_cls.__name__ == 'Adam8bit':
        import bitsandbytes

        manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

        skipped = 0
        for module in opt_model.modules():
            if isinstance(module, nn.Embedding):
                skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                logger.info(f'skipped {module}: {skipped / 2 ** 20}M params')
                # Keep embedding weights in 32-bit optimizer state when using 8-bit Adam.
                manager.register_module_override(module, 'weight', {'optim_bits': 32})
                logger.debug(f'bitsandbytes: will optimize {module} in fp32')
        logger.info(f'skipped: {skipped / 2 ** 20}M params')

    if is_sagemaker_mp_enabled():
        import smdistributed.modelparallel.torch as smp
        self.optimizer = smp.DistributedOptimizer(self.optimizer)

    return self.optimizer


def replace_create_optimizer():
    print('Replace original create_optimizer with custom create_optimizer')
    # transformers.Trainer.create_optimizer = create_optimizer
    transformers.Trainer.create_optimizer = create_optimizer_custom
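
# Usage sketch (hypothetical; the module name below is illustrative):
#
#   from custom_optimizer_patch import replace_create_optimizer
#   replace_create_optimizer()          # patch the Trainer class before training starts
#   trainer = transformers.Trainer(model=model, args=training_args, train_dataset=train_dataset)
#   trainer.train()                     # the optimizer is now built by create_optimizer_custom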