import os
import glob
import shutil
import builtins
import datetime
import functools

import torch
import transformers
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.distributed.fsdp import (
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)
from accelerate import FullyShardedDataParallelPlugin
from accelerate.utils import PrecisionType
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
)

from starvector.util import checkpoint_key
from starvector.model.starvector_arch import StarVectorConfig, StarVectorForCausalLM
from starvector.train.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

def is_deepspeed(checkpoint_dir):
    # Check for zero_to_fp32.py (generated only by DeepSpeed training)
    return os.path.exists(os.path.join(checkpoint_dir, 'zero_to_fp32.py'))


def consolidate_deepspeed_checkpoint(checkpoint_dir):
    path_state_dict = os.path.join(checkpoint_dir, 'weights.pt')
    if not os.path.exists(path_state_dict):
        convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, path_state_dict)
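
# A minimal usage sketch (hypothetical checkpoint path, kept as a comment so it
# is not executed at import time):
#   ckpt_dir = "runs/exp1/checkpoint-500"
#   if is_deepspeed(ckpt_dir):
#       consolidate_deepspeed_checkpoint(ckpt_dir)  # writes ckpt_dir/weights.pt once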

def load_checkpoint(model, checkpoint_dir):
    candidate_files = ['weights.pt', 'pytorch_model.bin', 'model.safetensors']

    # Determine the correct file to load
    for candidate in candidate_files:
        path_state_dict = os.path.join(checkpoint_dir, candidate)
        if os.path.exists(path_state_dict):
            break
    else:
        raise FileNotFoundError(f"No checkpoint file found in {checkpoint_dir}")

    # Load the state dict based on the file type
    if path_state_dict.endswith('.safetensors'):
        import safetensors.torch
        state_dict = safetensors.torch.load_file(path_state_dict)
    else:
        state_dict = torch.load(path_state_dict)

    # Handle the FSDP/DDP 'module.' prefix
    if list(model.state_dict().keys())[0].startswith('module'):
        new_state_dict = {'module.' + key: val for key, val in state_dict.items()}
    else:
        new_state_dict = state_dict

    # Handle tied weights: drop lm_head.weight if present; tie_weights() restores it
    if hasattr(model, 'tie_weights'):
        new_state_dict.pop("model.svg_transformer.transformer.lm_head.weight", None)

    # Load with strict=False to tolerate the missing (tied) keys
    model.load_state_dict(new_state_dict, strict=False)

    # Ensure weights are tied after loading
    if hasattr(model, 'tie_weights'):
        model.tie_weights()
    return model
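
# Usage sketch (assumes `model` is a StarVector-style model and the directory
# contains one of weights.pt / pytorch_model.bin / model.safetensors):
#   model = load_checkpoint(model, "runs/exp1/checkpoint-500")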

def save_checkpoint(accelerator, model, global_step, logging_dir, checkpoint_limit):
    print(f"Saving checkpoint! Global Step: {global_step}")
    save_checkpoint_dir = os.path.join(logging_dir, f"checkpoint-{global_step}")
    os.makedirs(save_checkpoint_dir, exist_ok=True)
    accelerator.wait_for_everyone()
    accelerator.save_state(save_checkpoint_dir)

    # Keep only the `checkpoint_limit` most recent checkpoints
    chkp_dirs = sorted(glob.glob(os.path.join(logging_dir, "checkpoint-*")), key=checkpoint_key)
    chkp_to_remove = chkp_dirs[:-checkpoint_limit]
    for chkp in chkp_to_remove:
        if accelerator.is_main_process:
            try:
                shutil.rmtree(chkp)
            except OSError as e:
                print(f"Could not remove checkpoint {chkp}: {e}")
    print(f"Saved state to {save_checkpoint_dir}")
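
# Usage sketch (hypothetical values; keeps only the 3 most recent checkpoints):
#   save_checkpoint(accelerator, model, global_step=1000,
#                   logging_dir="runs/exp1", checkpoint_limit=3)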

def push_model_to_hub(model, new_model_name, hf_token=None):
    tokenizer = model.model.svg_transformer.tokenizer

    # Register the model with the HF Auto* classes
    AutoConfig.register("starvector", StarVectorConfig)
    AutoModelForCausalLM.register(StarVectorConfig, StarVectorForCausalLM)
    StarVectorConfig.register_for_auto_class()
    StarVectorForCausalLM.register_for_auto_class("AutoModelForCausalLM")

    model.push_to_hub(new_model_name, commit_message=new_model_name, private=True, token=hf_token)
    tokenizer.push_to_hub(new_model_name, commit_message=new_model_name, private=True, token=hf_token)

    # Only push the processor if it is not the local ImageTrainProcessor
    from starvector.data.base import ImageTrainProcessor
    processor = model.model.image_encoder.processor
    if not isinstance(processor, ImageTrainProcessor):
        processor.push_to_hub(new_model_name, commit_message=new_model_name, private=True, token=hf_token)
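
# Usage sketch (hypothetical repo name; requires a HF token with write access):
#   push_model_to_hub(model, "my-org/my-starvector-model",
#                     hf_token=os.environ.get("HF_TOKEN"))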

def get_optimizer(config, model):
    optimizer_name = config.training.optimizer
    if optimizer_name == "adamw":
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.training.lr,
            betas=(config.training.adam_beta1, config.training.adam_beta2),
            weight_decay=config.training.adam_weight_decay,
            eps=config.training.adam_epsilon,
        )
    elif optimizer_name == "adafactor":
        optimizer = transformers.Adafactor(
            model.parameters(),
            lr=config.training.lr,
            relative_step=False,
            scale_parameter=False,
        )
    else:
        raise ValueError(f"Optimizer {optimizer_name} not supported")
    return optimizer
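
# Usage sketch (assumes a config object exposing the training fields read above):
#   config.training.optimizer = "adamw"  # or "adafactor"
#   optimizer = get_optimizer(config, model)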

def init_distributed_mode(config):
    """
    Initializes torch distributed
    """
    assert all(var in os.environ for var in ['WORLD_SIZE', 'LOCAL_RANK', 'RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ['LOCAL_RANK'])
    dist_url = 'env://'
    torch.cuda.set_device(local_rank)
    dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        rank, dist_url, local_rank), flush=True)
    torch.distributed.init_process_group(backend=dist_backend, init_method=dist_url,
                                         world_size=world_size, rank=rank)
    torch.distributed.barrier()
    print_only_on_master(rank == 0)
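
# WORLD_SIZE, RANK and LOCAL_RANK are set by the launcher, e.g. (hypothetical
# entry point):
#   torchrun --nproc_per_node=8 train.py
# after which init_distributed_mode(config) is called once on every rank.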

def print_only_on_master(is_master):
    """
    This function disables printing when not in the master process
    """
    builtin_print = builtins.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        kwargs['flush'] = True
        if is_master or force:
            now = datetime.datetime.now().time()
            builtin_print('[{}] '.format(now), end='')  # prefix with a timestamp
            builtin_print(*args, **kwargs)

    builtins.print = print
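
# Once installed, non-master ranks print nothing unless forced:
#   print("visible on rank 0 only")
#   print("visible on every rank", force=True)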

def setup_train_env_variables(config):
    """
    Set environment variables needed by FSDP and accelerate
    """
    mixed_precision = config.training.model_precision.lower()
    try:
        mixed_precision = PrecisionType(mixed_precision)
    except ValueError:
        raise ValueError(f"Unknown mixed_precision mode: {mixed_precision}. Choose between {PrecisionType.list()}.")
    os.environ["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision)

    if config.fsdp.enable:
        # We have to manually set some of the FSDP arguments as environment
        # variables, as these are not exposed by the FSDP plugin API
        os.environ['ACCELERATE_USE_FSDP'] = "true"
        os.environ['FSDP_USE_ORIG_PARAMS'] = str(config.fsdp.use_orig_params).lower()
        os.environ['FSDP_FORWARD_PREFETCH'] = str(config.fsdp.forward_prefetch).lower()
        if config.fsdp.cpu_ram_efficient_loading and not config.fsdp.sync_module_states:
            raise ValueError("When using `fsdp.cpu_ram_efficient_loading` set `fsdp.sync_module_states` to `True`")
        os.environ['FSDP_CPU_RAM_EFFICIENT_LOADING'] = str(config.fsdp.cpu_ram_efficient_loading).lower()
        os.environ['FSDP_SYNC_MODULE_STATES'] = str(config.fsdp.sync_module_states).lower()
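
# Sketch of the resulting environment (assuming bf16 precision and fsdp.enable):
#   setup_train_env_variables(config)
#   # -> ACCELERATE_MIXED_PRECISION=bf16, ACCELERATE_USE_FSDP=true, ...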

def load_fsdp_plugin(config, model):
    if config.fsdp.enable:
        # Get the mixed precision dtype
        mixed_precision_dtype = {
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
            "tf32": torch.float32,
        }[config.training.model_precision]
        fsdp_plugin = FullyShardedDataParallelPlugin(
            state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
            auto_wrap_policy=model.model.get_fsdp_wrapping_policy(),
            mixed_precision_policy=MixedPrecision(
                param_dtype=mixed_precision_dtype,
                reduce_dtype=mixed_precision_dtype,
                buffer_dtype=mixed_precision_dtype,
            ),
            sharding_strategy={
                "sdp": ShardingStrategy.SHARD_GRAD_OP,
                "ddp": ShardingStrategy.NO_SHARD,
                "fsdp": ShardingStrategy.FULL_SHARD,
                "hsdp": ShardingStrategy.HYBRID_SHARD,
            }[config.fsdp.sharding_strategy],
            backward_prefetch=config.fsdp.backward_prefetch,
            cpu_offload=config.fsdp.cpu_offload,
        )
    else:
        fsdp_plugin = None
    return fsdp_plugin
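
# Usage sketch: the plugin is handed to the Accelerator constructor.
#   from accelerate import Accelerator
#   accelerator = Accelerator(fsdp_plugin=load_fsdp_plugin(config, model))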

def apply_gradient_checkpointing(model):
    """Apply gradient checkpointing to the Transformer layer class of the LLM"""
    def check_fn(submodule):
        return isinstance(submodule, model.model.svg_transformer.transformer_layer_cls)

    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=functools.partial(
            checkpoint_wrapper,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        ),
        check_fn=check_fn,
    )
    # Wait for all processes to finish
    torch.distributed.barrier()
    return model
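
# Usage sketch (call on every rank, since it ends with a collective barrier):
#   model = apply_gradient_checkpointing(model)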

def get_module_class_from_name(module, name):
    """
    Gets a class from a module by its name.

    Args:
        module (`torch.nn.Module`): The module to get the class from.
        name (`str`): The name of the class.
    """
    modules_children = list(module.children())
    if module.__class__.__name__ == name:
        return module.__class__
    elif len(modules_children) == 0:
        return None
    else:
        for child_module in modules_children:
            module_class = get_module_class_from_name(child_module, name)
            if module_class is not None:
                return module_class
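
# Usage sketch (hypothetical layer class name; returns the class, or None):
#   layer_cls = get_module_class_from_name(model, "GPTBigCodeBlock")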