import json
import os

import torch
import torch.nn as nn
import transformers
from transformers import Trainer, logging
from transformers.trainer import is_sagemaker_mp_enabled

logger = logging.get_logger(__name__)
def get_num_layer_for_vit_and_qllama(var_name, vit_num_max_layer, llama_num_max_layer):
    """Map a parameter name to the layer id used for layer-wise learning-rate decay."""
    if var_name.startswith('internvl.'):
        var_name = var_name[len('internvl.'):]
    if var_name in ('query_tokens', 'logit_scale'):
        return 0
    if var_name.startswith('clip_projector.'):
        return vit_num_max_layer
    if var_name.startswith('clip_projector2.') or var_name.startswith('itm_head.') or \
            var_name == 'text_projection':
        return llama_num_max_layer
    if var_name.startswith('vision_model.'):
        if 'embeddings.' in var_name:
            return 0
        if 'layers.' in var_name:
            var_name = var_name.split('layers.')[-1]
            layer_id = int(var_name.split('.')[0])
            return layer_id + 1
    if var_name.startswith('qllama.'):
        if 'embed_tokens' in var_name:
            return 0
        if 'layers.' in var_name:
            var_name = var_name.split('layers.')[-1]
            layer_id = int(var_name.split('.')[0])
            return layer_id + 1
        else:
            return llama_num_max_layer
    return 0
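# A minimal sketch of how the mapping above behaves (the parameter names below are
# illustrative, not taken from a real checkpoint); with vit_num_max_layer=26 and
# llama_num_max_layer=34:
#   get_num_layer_for_vit_and_qllama('internvl.vision_model.embeddings.class_embedding', 26, 34)       -> 0
#   get_num_layer_for_vit_and_qllama('internvl.vision_model.encoder.layers.5.mlp.fc1.weight', 26, 34)  -> 6
#   get_num_layer_for_vit_and_qllama('internvl.clip_projector.norm1_q.weight', 26, 34)                 -> 26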
def param_classification(name):
    """Classify a parameter name as 'vit', 'qllama', or 'other' for optimizer grouping."""
    if name.startswith('internvl.'):
        name = name[len('internvl.'):]
    if name in ['query_tokens', 'text_projection', 'logit_scale']:
        return 'qllama'
    elif name.startswith('vision_model.'):
        return 'vit'
    elif name.startswith('qllama.'):
        return 'qllama'
    elif name.startswith('clip_projector.'):
        return 'vit'
    elif name.startswith('clip_projector2.'):
        return 'qllama'
    elif name.startswith('itm_head.'):
        return 'qllama'
    else:
        return 'other'
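# Illustrative classification results (parameter names are made up for the example):
#   param_classification('internvl.vision_model.encoder.layers.0.attn.qkv.weight')  -> 'vit'
#   param_classification('internvl.qllama.model.layers.3.self_attn.q_proj.weight')  -> 'qllama'
#   param_classification('mlp1.0.weight')                                           -> 'other'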
def create_optimizer(self):
    """
    Setup the optimizer.

    We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
    Trainer's init through `optimizers`, or subclass and override this method in a subclass.
    """
    opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
    parameter_groups = {}
    try:  # stage2 model: sub-configs live directly on `opt_model.config`
        vit_num_layers = opt_model.config.vision_config.num_hidden_layers + 2
        qllama_num_layers = opt_model.config.qllama_config.num_hidden_layers + 2
    except AttributeError:  # stage3 model: sub-configs are nested under `opt_model.internvl`
        vit_num_layers = opt_model.internvl.config.vision_config.num_hidden_layers + 2
        qllama_num_layers = opt_model.internvl.config.qllama_config.num_hidden_layers + 2
    print('vit_num_layers:', vit_num_layers)
    print('qllama_num_layers:', qllama_num_layers)

    # Layer-wise decay rates and the qllama lr scale are read from environment variables.
    vit_layer_decay_rate = float(os.getenv('VIT_LAYER_DECAY_RATE', 1.0))
    qllama_layer_decay_rate = float(os.getenv('QLLAMA_LAYER_DECAY_RATE', 1.0))
    qllama_lr_scale = float(os.getenv('QLLAMA_LR_SCALE', 1.0))
    print('vit_layer_decay_rate:', vit_layer_decay_rate)
    print('qllama_layer_decay_rate:', qllama_layer_decay_rate)
    print('qllama_lr_scale:', qllama_lr_scale)
    for name, param in opt_model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        # Biases and 1-D parameters (e.g. norm weights) get no weight decay.
        if len(param.shape) == 1 or name.endswith('.bias'):
            group_name = 'no_decay'
            this_weight_decay = 0.
        else:
            group_name = 'decay'
            this_weight_decay = self.args.weight_decay
        cls = param_classification(name)
        layer_id = get_num_layer_for_vit_and_qllama(name, vit_num_layers, qllama_num_layers)
        group_name = '%s_layer_%d_%s' % (cls, layer_id, group_name)
        if group_name not in parameter_groups:
            if cls == 'vit':
                scale = vit_layer_decay_rate ** (vit_num_layers - layer_id - 1)
            elif cls == 'qllama':
                scale = qllama_layer_decay_rate ** (qllama_num_layers - layer_id - 1)
                scale = scale * qllama_lr_scale
            else:
                scale = 1.0
            scale = min(1.0, scale)
            parameter_groups[group_name] = {
                'weight_decay': this_weight_decay,
                'params': [],
                'param_names': [],
                'lr_scale': scale,
                'group_name': group_name,
                'lr': scale * self.args.learning_rate,
            }
        parameter_groups[group_name]['params'].append(param)
        parameter_groups[group_name]['param_names'].append(name)

    rank = torch.distributed.get_rank()
    if rank == 0:
        to_display = {}
        for key in parameter_groups:
            to_display[key] = {
                'param_names': parameter_groups[key]['param_names'],
                'lr_scale': parameter_groups[key]['lr_scale'],
                'lr': parameter_groups[key]['lr'],
                'weight_decay': parameter_groups[key]['weight_decay'],
            }
        print('Param groups = %s' % json.dumps(to_display, indent=2))

    optimizer_grouped_parameters = list(parameter_groups.values())
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
    self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
    if optimizer_cls.__name__ == 'Adam8bit':
        import bitsandbytes

        manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
        skipped = 0
        for module in opt_model.modules():
            if isinstance(module, nn.Embedding):
                skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                logger.info(f'skipped {module}: {skipped / 2 ** 20}M params')
                manager.register_module_override(module, 'weight', {'optim_bits': 32})
                logger.debug(f'bitsandbytes: will optimize {module} in fp32')
        logger.info(f'skipped: {skipped / 2 ** 20}M params')
    if is_sagemaker_mp_enabled():
        import smdistributed.modelparallel.torch as smp

        self.optimizer = smp.DistributedOptimizer(self.optimizer)
    return self.optimizer
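# A quick numeric sketch of the lr_scale computed above (the decay rate here is a
# hypothetical setting, not a recommended value): with VIT_LAYER_DECAY_RATE=0.9 and
# vit_num_layers=26, a 'vit' parameter at layer_id=1 gets
# 0.9 ** (26 - 1 - 1) = 0.9 ** 24, about 0.08, while one at layer_id=26 (e.g. the
# clip_projector) gets 0.9 ** -1, which min(1.0, scale) caps at 1.0 -- i.e. shallower
# layers are trained with a proportionally smaller learning rate.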
def create_optimizer_custom(self):
    """
    Setup the optimizer.

    We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
    Trainer's init through `optimizers`, or subclass and override this method in a subclass.
    """
    opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
    parameter_groups = {}
    for name, param in opt_model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if len(param.shape) == 1 or name.endswith('.bias'):
            group_name = 'no_decay'
            this_weight_decay = 0.
        else:
            group_name = 'decay'
            this_weight_decay = self.args.weight_decay
        if 'ocr_mlp' in name or 'upsample' in name:
            group_name = '%s_%s' % ('modify', group_name)
        elif 'vision_model' in name:
            group_name = '%s_%s' % ('vit', group_name)
        else:
            group_name = '%s_%s' % ('base', group_name)
        if group_name not in parameter_groups:
            if 'ocr_mlp' in name or 'upsample' in name:
                scale = 1.0
            elif 'vision_model' in name:
                scale = 0.05
            else:
                scale = 1.0
            parameter_groups[group_name] = {
                'weight_decay': this_weight_decay,
                'params': [],
                'param_names': [],
                'lr_scale': scale,
                'group_name': group_name,
                'lr': scale * self.args.learning_rate,
            }
        parameter_groups[group_name]['params'].append(param)
        parameter_groups[group_name]['param_names'].append(name)

    rank = torch.distributed.get_rank()
    if rank == 0:
        to_display = {}
        for key in parameter_groups:
            to_display[key] = {
                'param_names': parameter_groups[key]['param_names'],
                'lr_scale': parameter_groups[key]['lr_scale'],
                'lr': parameter_groups[key]['lr'],
                'weight_decay': parameter_groups[key]['weight_decay'],
            }
        print('Param groups = %s' % json.dumps(to_display, indent=2))

    optimizer_grouped_parameters = list(parameter_groups.values())
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
    self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
    if optimizer_cls.__name__ == 'Adam8bit':
        import bitsandbytes

        manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
        skipped = 0
        for module in opt_model.modules():
            if isinstance(module, nn.Embedding):
                skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                logger.info(f'skipped {module}: {skipped / 2 ** 20}M params')
                manager.register_module_override(module, 'weight', {'optim_bits': 32})
                logger.debug(f'bitsandbytes: will optimize {module} in fp32')
        logger.info(f'skipped: {skipped / 2 ** 20}M params')
    if is_sagemaker_mp_enabled():
        import smdistributed.modelparallel.torch as smp

        self.optimizer = smp.DistributedOptimizer(self.optimizer)
    return self.optimizer
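# With create_optimizer_custom, every trainable parameter falls into one of six groups
# (names follow from the rules above):
#   modify_decay / modify_no_decay  -> lr_scale 1.0   ('ocr_mlp' or 'upsample' params)
#   vit_decay    / vit_no_decay     -> lr_scale 0.05  ('vision_model' params)
#   base_decay   / base_no_decay    -> lr_scale 1.0   (everything else)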
def replace_create_optimizer():
    print('Replace original create_optimizer with custom create_optimizer')
    # transformers.Trainer.create_optimizer = create_optimizer
    transformers.Trainer.create_optimizer = create_optimizer_custom
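# Example usage -- a sketch only; `model`, `training_args`, and `train_dataset` are
# placeholders for whatever the surrounding training script defines:
#
#   from transformers import Trainer
#
#   replace_create_optimizer()           # patch before the Trainer builds its optimizer
#   trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
#   trainer.train()                      # create_optimizer_custom() is used here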