import os
import glob
import shutil
import builtins
import datetime
import functools

import torch
import transformers
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.distributed.fsdp import (
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)
from accelerate import FullyShardedDataParallelPlugin
from accelerate.utils import PrecisionType
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
)

from starvector.util import checkpoint_key
from starvector.model.starvector_arch import StarVectorConfig, StarVectorForCausalLM
from starvector.train.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict


def is_deepspeed(checkpoint_dir):
    # Check for zero_to_fp32.py (generated only in DeepSpeed training)
    return os.path.exists(os.path.join(checkpoint_dir, 'zero_to_fp32.py'))


def consolidate_deepspeed_checkpoint(checkpoint_dir):
    # Merge the sharded ZeRO states into a single fp32 state dict (weights.pt)
    path_state_dict = os.path.join(checkpoint_dir, 'weights.pt')
    if not os.path.exists(path_state_dict):
        convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, path_state_dict)
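
# Example (illustrative sketch, not executed at import time): consolidate a
# DeepSpeed ZeRO checkpoint directory before trying to load it. The directory
# path below is hypothetical.
#
#   ckpt_dir = "logs/checkpoint-1000"
#   if is_deepspeed(ckpt_dir):
#       consolidate_deepspeed_checkpoint(ckpt_dir)   # writes <ckpt_dir>/weights.pt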


def load_checkpoint(model, checkpoint_dir):
    candidate_files = ['weights.pt', 'pytorch_model.bin', 'model.safetensors']

    # Determine the correct file to load
    for candidate in candidate_files:
        path_state_dict = os.path.join(checkpoint_dir, candidate)
        if os.path.exists(path_state_dict):
            break
    else:
        raise FileNotFoundError(f"No checkpoint file found in {checkpoint_dir}")

    # Load the state dict based on file type
    if path_state_dict.endswith('.safetensors'):
        import safetensors.torch
        state_dict = safetensors.torch.load_file(path_state_dict)
    else:
        state_dict = torch.load(path_state_dict)

    # Handle FSDP/DDP 'module.' prefix
    if next(iter(model.state_dict())).startswith('module'):
        new_state_dict = {'module.' + key: val for key, val in state_dict.items()}
    else:
        new_state_dict = state_dict

    # Handle tied weights
    if hasattr(model, 'tie_weights'):
        # Remove the lm_head.weight key if it exists; tie_weights will recreate it
        new_state_dict.pop("model.svg_transformer.transformer.lm_head.weight", None)

    # Load the state dict into the model with strict=False to allow missing keys
    model.load_state_dict(new_state_dict, strict=False)

    # Ensure weights are tied after loading
    if hasattr(model, 'tie_weights'):
        model.tie_weights()
    return model
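
# Example (sketch; the model construction and checkpoint path are illustrative,
# not values fixed by this module):
#
#   model = StarVectorForCausalLM(StarVectorConfig())
#   ckpt_dir = "logs/checkpoint-1000"
#   if is_deepspeed(ckpt_dir):
#       consolidate_deepspeed_checkpoint(ckpt_dir)
#   model = load_checkpoint(model, ckpt_dir)   # picks weights.pt / .bin / .safetensors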


def save_checkpoint(accelerator, model, global_step, logging_dir, checkpoint_limit):
    print(f"Saving checkpoint! Global step: {global_step}")
    save_checkpoint_dir = os.path.join(logging_dir, f"checkpoint-{global_step}")
    os.makedirs(save_checkpoint_dir, exist_ok=True)
    accelerator.wait_for_everyone()
    accelerator.save_state(save_checkpoint_dir)

    # Keep only the most recent `checkpoint_limit` checkpoints
    chkp_dirs = sorted(glob.glob(os.path.join(logging_dir, "checkpoint-*")), key=checkpoint_key)
    chkp_to_remove = chkp_dirs[:-checkpoint_limit]
    if accelerator.is_main_process:
        for chkp in chkp_to_remove:
            try:
                shutil.rmtree(chkp)
            except OSError as e:
                print(f"Could not remove checkpoint {chkp}: {e}")
    print(f"Saved state to {save_checkpoint_dir}")


def push_model_to_hub(model, new_model_name, hf_token=None):
    tokenizer = model.model.svg_transformer.tokenizer

    # Register the model for HF
    AutoConfig.register("starvector", StarVectorConfig)
    AutoModelForCausalLM.register(StarVectorConfig, StarVectorForCausalLM)
    StarVectorConfig.register_for_auto_class()
    StarVectorForCausalLM.register_for_auto_class("AutoModelForCausalLM")

    model.push_to_hub(new_model_name, commit_message=new_model_name, private=True, token=hf_token)
    tokenizer.push_to_hub(new_model_name, commit_message=new_model_name, private=True, token=hf_token)

    processor = model.model.image_encoder.processor
    from starvector.data.base import ImageTrainProcessor
    if not isinstance(processor, ImageTrainProcessor):
        processor.push_to_hub(new_model_name, commit_message=new_model_name, private=True, token=hf_token)
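
# Example (sketch; the repository name and token source are illustrative):
#
#   push_model_to_hub(model, "my-org/starvector-finetune",
#                     hf_token=os.environ.get("HF_TOKEN"))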


def get_optimizer(config, model):
    optimizer_name = config.training.optimizer
    if optimizer_name == "adamw":
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.training.lr,
            betas=(config.training.adam_beta1, config.training.adam_beta2),
            weight_decay=config.training.adam_weight_decay,
            eps=config.training.adam_epsilon,
        )
    elif optimizer_name == "adafactor":
        optimizer = transformers.Adafactor(
            model.parameters(),
            lr=config.training.lr,
            relative_step=False,
            scale_parameter=False,
        )
    else:
        raise ValueError(f"Optimizer {optimizer_name} not supported")
    return optimizer
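
# Example (sketch of a single training step; the forward call and loss access
# are illustrative and depend on the actual model interface):
#
#   optimizer = get_optimizer(config, model)
#   optimizer.zero_grad()
#   loss = model(**batch).loss
#   loss.backward()
#   optimizer.step()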


def init_distributed_mode(config):
    """
    Initializes torch.distributed from the environment variables set by the
    launcher (e.g. torchrun): WORLD_SIZE, RANK and LOCAL_RANK.
    """
    assert all(var in os.environ for var in ['WORLD_SIZE', 'LOCAL_RANK', 'RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])
    dist_url = 'env://'
    dist_backend = 'nccl'

    torch.cuda.set_device(local_rank)
    print(f'| distributed init (rank {rank}): {dist_url}, gpu {local_rank}', flush=True)
    torch.distributed.init_process_group(backend=dist_backend, init_method=dist_url,
                                         world_size=world_size, rank=rank)
    torch.distributed.barrier()
    print_only_on_master(rank == 0)
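
# Example (sketch): launch with torchrun so WORLD_SIZE/RANK/LOCAL_RANK are set,
# then call init_distributed_mode early in the entry point. `train.py` is an
# illustrative script name.
#
#   torchrun --nproc_per_node=8 train.py
#
#   # inside train.py:
#   init_distributed_mode(config)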


def print_only_on_master(is_master):
    """
    Disables printing on non-master processes. Pass `force=True` to print anyway.
    """
    builtin_print = builtins.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        kwargs['flush'] = True
        if is_master or force:
            now = datetime.datetime.now().time()
            builtin_print('[{}] '.format(now), end='')  # prefix with a timestamp
            builtin_print(*args, **kwargs)

    builtins.print = print
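
# Example (sketch): after init_distributed_mode has called this, only rank 0
# prints unless `force=True` is passed.
#
#   print("training started")                  # printed only on rank 0, timestamped
#   print("rank-local warning", force=True)    # printed on every rank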


def setup_train_env_variables(config):
    """
    Set environment variables needed by FSDP and accelerate
    """
    mixed_precision = config.training.model_precision.lower()
    try:
        mixed_precision = PrecisionType(mixed_precision)
    except ValueError:
        raise ValueError(f"Unknown mixed_precision mode: {mixed_precision}. Choose between {PrecisionType.list()}.")
    os.environ["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision)

    if config.fsdp.enable:
        # We have to manually set some of the FSDP arguments as environment variables
        # as these are not exposed by the FSDP Plugin API
        os.environ['ACCELERATE_USE_FSDP'] = "true"
        os.environ['FSDP_USE_ORIG_PARAMS'] = str(config.fsdp.use_orig_params).lower()
        os.environ['FSDP_FORWARD_PREFETCH'] = str(config.fsdp.forward_prefetch).lower()
        if config.fsdp.cpu_ram_efficient_loading and not config.fsdp.sync_module_states:
            raise ValueError("When using `fsdp.cpu_ram_efficient_loading` set `fsdp.sync_module_states` to `True`")
        os.environ['FSDP_CPU_RAM_EFFICIENT_LOADING'] = str(config.fsdp.cpu_ram_efficient_loading).lower()
        os.environ['FSDP_SYNC_MODULE_STATES'] = str(config.fsdp.sync_module_states).lower()
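
# Example (sketch of the config fields read above; the config object itself is
# whatever the training script passes in, shown here only to list the fields):
#
#   config.training.model_precision = "bf16"
#   config.fsdp.enable = True
#   config.fsdp.use_orig_params = True
#   config.fsdp.forward_prefetch = False
#   config.fsdp.cpu_ram_efficient_loading = True
#   config.fsdp.sync_module_states = True
#   setup_train_env_variables(config)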


def load_fsdp_plugin(config, model):
    if config.fsdp.enable:
        # Get the mixed precision dtype
        mixed_precision_dtype = {
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
            "tf32": torch.float32,
        }[config.training.model_precision]
        fsdp_plugin = FullyShardedDataParallelPlugin(
            state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
            auto_wrap_policy=model.model.get_fsdp_wrapping_policy(),
            mixed_precision_policy=MixedPrecision(
                param_dtype=mixed_precision_dtype,
                reduce_dtype=mixed_precision_dtype,
                buffer_dtype=mixed_precision_dtype,
            ),
            sharding_strategy={
                "sdp": ShardingStrategy.SHARD_GRAD_OP,
                "ddp": ShardingStrategy.NO_SHARD,
                "fsdp": ShardingStrategy.FULL_SHARD,
                "hsdp": ShardingStrategy.HYBRID_SHARD,
            }[config.fsdp.sharding_strategy],
            backward_prefetch=config.fsdp.backward_prefetch,
            cpu_offload=config.fsdp.cpu_offload,
        )
    else:
        fsdp_plugin = None
    return fsdp_plugin
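
# Example (sketch; wiring the plugin into accelerate.Accelerator, as is typically
# done in the training script):
#
#   from accelerate import Accelerator
#   fsdp_plugin = load_fsdp_plugin(config, model)
#   accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
#   model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)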


def apply_gradient_checkpointing(model):
    """ Apply gradient checkpointing to Transformer cls of the LLM """
    def check_fn(submodule):
        return isinstance(submodule, model.model.svg_transformer.transformer_layer_cls)

    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=functools.partial(
            checkpoint_wrapper,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        ),
        check_fn=check_fn,
    )
    # Wait for all processes to finish
    torch.distributed.barrier()
    return model
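
# Example (sketch; the config flag is illustrative, and the call assumes
# torch.distributed has already been initialized):
#
#   if config.training.gradient_checkpointing:
#       model = apply_gradient_checkpointing(model)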


def get_module_class_from_name(module, name):
    """
    Gets a class from a module by its name, searching submodules recursively.

    Args:
        module (`torch.nn.Module`): The module to get the class from.
        name (`str`): The name of the class.
    """
    if module.__class__.__name__ == name:
        return module.__class__
    modules_children = list(module.children())
    if len(modules_children) == 0:
        return None
    for child_module in modules_children:
        module_class = get_module_class_from_name(child_module, name)
        if module_class is not None:
            return module_class
    return None
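
# Example (sketch; the layer class name is model-dependent and shown only as an
# illustration of building an FSDP auto-wrap policy from a class name):
#
#   layer_cls = get_module_class_from_name(model, "LlamaDecoderLayer")
#   if layer_cls is not None:
#       wrap_policy = functools.partial(transformer_auto_wrap_policy,
#                                       transformer_layer_cls={layer_cls})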