import os
import torch
import transformers
from starvector.util import checkpoint_key
import glob
import shutil
import builtins
import datetime
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.distributed.fsdp import (
MixedPrecision,
ShardingStrategy,
)
import functools
from accelerate import FullyShardedDataParallelPlugin
from accelerate.utils import PrecisionType
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
CheckpointImpl,
apply_activation_checkpointing,
)
from transformers import (
AutoConfig,
AutoModelForCausalLM
)
from starvector.model.starvector_arch import StarVectorConfig, StarVectorForCausalLM
from starvector.train.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
def is_deepspeed(checkpoint_dir):
    # Check for the zero_to_fp32.py file (generated only by DeepSpeed training)
return os.path.exists(os.path.join(checkpoint_dir, 'zero_to_fp32.py'))
def consolidate_deepspeed_checkpoint(checkpoint_dir):
path_state_dict = os.path.join(checkpoint_dir, 'weights.pt')
if not os.path.exists(path_state_dict):
convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, path_state_dict)
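
# Illustrative usage (a sketch; the checkpoint directory is a placeholder for a DeepSpeed
# run directory that contains the auto-generated zero_to_fp32.py):
#   if is_deepspeed("logs/checkpoint-5000"):
#       consolidate_deepspeed_checkpoint("logs/checkpoint-5000")  # writes weights.pt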
def load_checkpoint(model, checkpoint_dir):
candidate_files = ['weights.pt', 'pytorch_model.bin', 'model.safetensors']
# Determine the correct file to load
for candidate in candidate_files:
path_state_dict = os.path.join(checkpoint_dir, candidate)
if os.path.exists(path_state_dict):
break
else:
raise FileNotFoundError(f"No checkpoint file found in {checkpoint_dir}")
# Load the state dict based on file type
if path_state_dict.endswith('.safetensors'):
import safetensors.torch
state_dict = safetensors.torch.load_file(path_state_dict)
else:
state_dict = torch.load(path_state_dict)
# Handle FSDP or module prefix
if list(model.state_dict().keys())[0].startswith('module'):
new_state_dict = {'module.' + key: val for key, val in state_dict.items()}
else:
new_state_dict = state_dict
# Handle Tied Weights
if hasattr(model, 'tie_weights'):
        # Remove the lm_head.weight key if it exists; tie_weights() will recreate it
new_state_dict.pop("model.svg_transformer.transformer.lm_head.weight", None)
# Load the state dict into the model with strict=False to ignore missing keys
model.load_state_dict(new_state_dict, strict=False) # Allow missing keys
    # Ensure weights are tied again after loading
    if hasattr(model, 'tie_weights'):
        model.tie_weights()
return model
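
# Illustrative usage (a sketch; the model construction and checkpoint path are placeholders,
# not taken from this file):
#   model = StarVectorForCausalLM(StarVectorConfig())
#   model = load_checkpoint(model, "logs/checkpoint-5000")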
# NOTE: superseded by the push_model_to_hub(model, new_model_name, hf_token) variant defined
# further below, which reads the tokenizer and processor from the model itself.
def push_model_to_hub(model, new_model_name, tokenizer, processor):
# Register the model for HF
AutoConfig.register("starvector", StarVectorConfig)
AutoModelForCausalLM.register(StarVectorConfig, StarVectorForCausalLM)
StarVectorConfig.register_for_auto_class()
StarVectorForCausalLM.register_for_auto_class("AutoModelForCausalLM")
model.push_to_hub(new_model_name, commit_message=new_model_name, private=True)
tokenizer.push_to_hub(new_model_name, commit_message=new_model_name, private=True)
processor.push_to_hub(new_model_name, commit_message=new_model_name, private=True)
# push_model_to_hub(self.model, new_model_name, self.tokenizer, self.processor)
def save_checkpoint(accelerator, model, global_step, logging_dir, checkpoint_limit):
print("Saving checkpoint! Global Step: " + str(global_step))
save_checkpoint_dir = os.path.join(logging_dir, f"checkpoint-{global_step}")
os.makedirs(save_checkpoint_dir, exist_ok=True)
accelerator.wait_for_everyone()
accelerator.save_state(save_checkpoint_dir)
    chkp_dirs = sorted(glob.glob(os.path.join(logging_dir, "checkpoint-*")), key=checkpoint_key)
chkp_to_remove = chkp_dirs[:-checkpoint_limit]
for chkp in chkp_to_remove:
if accelerator.is_main_process:
try:
shutil.rmtree(chkp)
            except OSError as e:
                print(f"Could not remove checkpoint {chkp}: {e}")
print(f"Saved state to {save_checkpoint_dir}")
def push_model_to_hub(model, new_model_name, hf_token=None):
tokenizer = model.model.svg_transformer.tokenizer
# Register the model for HF
AutoConfig.register("starvector", StarVectorConfig)
AutoModelForCausalLM.register(StarVectorConfig, StarVectorForCausalLM)
StarVectorConfig.register_for_auto_class()
StarVectorForCausalLM.register_for_auto_class("AutoModelForCausalLM")
model.push_to_hub(new_model_name, commit_message=new_model_name, private=True, token=hf_token)
tokenizer.push_to_hub(new_model_name, commit_message=new_model_name, private=True, token=hf_token)
processor = model.model.image_encoder.processor
from starvector.data.base import ImageTrainProcessor
if not isinstance(processor, ImageTrainProcessor):
processor.push_to_hub(new_model_name, commit_message=new_model_name, private=True, token=hf_token)
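
# Illustrative usage (a sketch; the repository name is a placeholder and HF_TOKEN is assumed
# to be set in the environment):
#   push_model_to_hub(model, "my-org/starvector-finetuned", hf_token=os.environ.get("HF_TOKEN"))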
def get_optimizer(config, model):
optimizer = config.training.optimizer
if optimizer == "adamw":
optimizer = torch.optim.AdamW(
model.parameters(),
lr=config.training.lr,
betas=(config.training.adam_beta1, config.training.adam_beta2),
weight_decay=config.training.adam_weight_decay,
eps=config.training.adam_epsilon,
)
elif optimizer == "adafactor":
optimizer = transformers.Adafactor(
model.parameters(),
lr=config.training.lr,
relative_step=False,
scale_parameter=False,
)
else:
raise ValueError(f"Optimizer {optimizer} not supported")
return optimizer
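
# Illustrative usage (a sketch; assumes a config whose training.optimizer is "adamw" or
# "adafactor" and that defines the matching lr/beta/weight-decay/epsilon fields):
#   optimizer = get_optimizer(config, model)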
def init_distributed_mode(config):
"""
    Initializes torch.distributed from the environment variables set by the launcher
"""
assert all(var in os.environ for var in ['WORLD_SIZE', 'LOCAL_RANK', 'RANK'])
world_size = int(os.environ['WORLD_SIZE'])
rank = int(os.environ["RANK"])
local_rank = int(os.environ['LOCAL_RANK'])
dist_url = 'env://'
torch.cuda.set_device(local_rank)
dist_backend = 'nccl'
print('| distributed init (rank {}): {}, gpu {}'.format(
rank, dist_url, local_rank), flush=True)
torch.distributed.init_process_group(backend=dist_backend, init_method=dist_url,
world_size=world_size, rank=rank)
torch.distributed.barrier()
print_only_on_master(rank == 0)
def print_only_on_master(is_master):
"""
    This function disables printing when not in the master process
"""
builtin_print = builtins.print
def print(*args, **kwargs):
force = kwargs.pop('force', False)
kwargs['flush'] = True
if is_master or force:
now = datetime.datetime.now().time()
builtin_print('[{}] '.format(now), end='') # print with time stamp
builtin_print(*args, **kwargs)
builtins.print = print
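
# Illustrative usage (a sketch; WORLD_SIZE, RANK, and LOCAL_RANK are expected to be set by
# the launcher, e.g. torchrun):
#   init_distributed_mode(config)              # also silences print() on non-master ranks
#   print("only visible on rank 0")
#   print("visible on every rank", force=True)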
def setup_train_env_variables(config):
"""
Set environment variables needed by FSDP and accelerate
"""
mixed_precision = config.training.model_precision.lower()
try:
mixed_precision = PrecisionType(mixed_precision)
except ValueError:
raise ValueError(f"Unknown mixed_precision mode: {mixed_precision}. Choose between {PrecisionType.list()}.")
os.environ["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision)
if config.fsdp.enable:
# We have to manually set some of the FSDP arguments as environment variables as these are not exposed by the FSDP Plugin API
        os.environ['ACCELERATE_USE_FSDP'] = "true"
        os.environ['FSDP_USE_ORIG_PARAMS'] = str(config.fsdp.use_orig_params).lower()
        os.environ['FSDP_FORWARD_PREFETCH'] = str(config.fsdp.forward_prefetch).lower()
if config.fsdp.cpu_ram_efficient_loading and not config.fsdp.sync_module_states:
raise ValueError("When using `fsdp.cpu_ram_efficient_loading` set `fsdp.sync_module_states` to `True`")
        os.environ['FSDP_CPU_RAM_EFFICIENT_LOADING'] = str(config.fsdp.cpu_ram_efficient_loading).lower()
        os.environ['FSDP_SYNC_MODULE_STATES'] = str(config.fsdp.sync_module_states).lower()
def load_fsdp_plugin(config, model):
if config.fsdp.enable:
        # get mixed precision dtype
mixed_precision_dtype = {
"fp16": torch.float16,
"bf16": torch.bfloat16,
"tf32": torch.float32,
}[config.training.model_precision]
fsdp_plugin = FullyShardedDataParallelPlugin(
state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
auto_wrap_policy=model.model.get_fsdp_wrapping_policy(),
mixed_precision_policy=MixedPrecision(
param_dtype=mixed_precision_dtype,
reduce_dtype=mixed_precision_dtype,
buffer_dtype=mixed_precision_dtype,
),
sharding_strategy={
"sdp": ShardingStrategy.SHARD_GRAD_OP,
"ddp": ShardingStrategy.NO_SHARD,
"fsdp": ShardingStrategy.FULL_SHARD,
"hsdp": ShardingStrategy.HYBRID_SHARD,
}[config.fsdp.sharding_strategy],
backward_prefetch=config.fsdp.backward_prefetch,
cpu_offload=config.fsdp.cpu_offload,
)
else:
fsdp_plugin = None
return fsdp_plugin
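
# Illustrative usage (a sketch; wiring the plugin into accelerate.Accelerator is an assumption
# based on the standard Accelerator API, not taken from this file):
#   from accelerate import Accelerator
#   setup_train_env_variables(config)
#   fsdp_plugin = load_fsdp_plugin(config, model)
#   accelerator = Accelerator(fsdp_plugin=fsdp_plugin)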
def apply_gradient_checkpointing(model):
""" Apply gradient checkpointing to Transformer cls of the LLM """
def check_fn(submodule):
return isinstance(submodule, model.model.svg_transformer.transformer_layer_cls)
apply_activation_checkpointing(
model,
checkpoint_wrapper_fn=functools.partial(
checkpoint_wrapper,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
),
check_fn=check_fn,
)
# Wait for all processes to finish
torch.distributed.barrier()
return model
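
# Illustrative usage (a sketch; assumes torch.distributed is already initialized and that
# `config.training.gradient_checkpointing` is a hypothetical config flag):
#   if config.training.gradient_checkpointing:
#       model = apply_gradient_checkpointing(model)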
def get_module_class_from_name(module, name):
"""
Gets a class from a module by its name.
Args:
module (`torch.nn.Module`): The module to get the class from.
name (`str`): The name of the class.
"""
modules_children = list(module.children())
if module.__class__.__name__ == name:
return module.__class__
elif len(modules_children) == 0:
return
else:
for child_module in modules_children:
module_class = get_module_class_from_name(child_module, name)
if module_class is not None:
return module_class
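
# Illustrative usage (a sketch; the layer class name "GPTBigCodeBlock" is an example, not
# taken from this file):
#   layer_cls = get_module_class_from_name(model, "GPTBigCodeBlock")
#   auto_wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={layer_cls})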