import os
import torch
import transformers
from starvector.util import checkpoint_key
import glob
import shutil
import builtins
import datetime
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.distributed.fsdp import (
    MixedPrecision,
    ShardingStrategy,
)
import functools
from accelerate import FullyShardedDataParallelPlugin
from accelerate.utils import PrecisionType

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)


from transformers import (
    AutoConfig, 
    AutoModelForCausalLM
)
from starvector.model.starvector_arch import StarVectorConfig, StarVectorForCausalLM
from starvector.train.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict


def is_deepspeed(checkpoint_dir):
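    """Return True if the checkpoint directory was produced by DeepSpeed training."""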
    # Check zero_to_fp32.py file (generated only in deepspeed training)
    return os.path.exists(os.path.join(checkpoint_dir, 'zero_to_fp32.py'))

def consolidate_deepspeed_checkpoint(checkpoint_dir):
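    """Consolidate a DeepSpeed ZeRO checkpoint into a single fp32 state dict saved as weights.pt, skipping the conversion if the file already exists."""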
    path_state_dict = os.path.join(checkpoint_dir, 'weights.pt')
    if not os.path.exists(path_state_dict):
        convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, path_state_dict)

def load_checkpoint(model, checkpoint_dir):
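    """Load model weights from a consolidated checkpoint directory.

    Supports weights.pt, pytorch_model.bin, and model.safetensors files, adds a
    'module.' prefix when the model's own keys are prefixed (e.g. a wrapped model),
    and re-ties tied weights after loading.
    """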
    candidate_files = ['weights.pt', 'pytorch_model.bin', 'model.safetensors']
    
    # Determine the correct file to load
    for candidate in candidate_files:
        path_state_dict = os.path.join(checkpoint_dir, candidate)
        if os.path.exists(path_state_dict):
            break
    else:
        raise FileNotFoundError(f"No checkpoint file found in {checkpoint_dir}")

    # Load the state dict based on file type
    if path_state_dict.endswith('.safetensors'):
        import safetensors.torch
        state_dict = safetensors.torch.load_file(path_state_dict)
    else:
        state_dict = torch.load(path_state_dict)

    # Handle FSDP or module prefix
    if list(model.state_dict().keys())[0].startswith('module'):
        new_state_dict = {'module.' + key: val for key, val in state_dict.items()}
    else:
        new_state_dict = state_dict

    # Handle tied weights: drop the lm_head.weight key if present; tie_weights() will recreate it
    if hasattr(model, 'tie_weights'):
        new_state_dict.pop("model.svg_transformer.transformer.lm_head.weight", None)

    # Load the state dict with strict=False so missing (tied) keys are tolerated
    model.load_state_dict(new_state_dict, strict=False)

    # Re-tie the weights after loading, when the model supports it
    if hasattr(model, 'tie_weights'):
        model.tie_weights()

    return model


def push_model_to_hub(model, new_model_name, tokenizer, processor):
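    """Register the StarVector classes with the HF Auto* factories and push the model, tokenizer, and processor to a private Hub repository."""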
    # Register the model for HF
    AutoConfig.register("starvector", StarVectorConfig)
    AutoModelForCausalLM.register(StarVectorConfig, StarVectorForCausalLM)
    StarVectorConfig.register_for_auto_class()
    StarVectorForCausalLM.register_for_auto_class("AutoModelForCausalLM")
    model.push_to_hub(new_model_name, commit_message=new_model_name, private=True)
    tokenizer.push_to_hub(new_model_name, commit_message=new_model_name, private=True)
    processor.push_to_hub(new_model_name, commit_message=new_model_name, private=True)

# push_model_to_hub(self.model, new_model_name, self.tokenizer, self.processor)
def save_checkpoint(accelerator, model, global_step, logging_dir, checkpoint_limit):
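    """Save the full training state with accelerate into `checkpoint-{global_step}` and prune older checkpoints so that at most `checkpoint_limit` remain."""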
    print("Saving checkpoint! Global Step: " + str(global_step))
    save_checkpoint_dir = os.path.join(logging_dir, f"checkpoint-{global_step}")
    os.makedirs(save_checkpoint_dir, exist_ok=True)
    accelerator.wait_for_everyone()
    accelerator.save_state(save_checkpoint_dir)

    # Prune older checkpoints, keeping only the most recent `checkpoint_limit`
    chkp_dirs = sorted(glob.glob(os.path.join(logging_dir, "checkpoint-*")), key=checkpoint_key)
    chkp_to_remove = chkp_dirs[:-checkpoint_limit]
    for chkp in chkp_to_remove:
        if accelerator.is_main_process:
            try:
                shutil.rmtree(chkp)
            except OSError as e:
                print(f"Could not remove checkpoint {chkp}: {e}")
    print(f"Saved state to {save_checkpoint_dir}")

def push_model_to_hub(model, new_model_name, hf_token=None):
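    """Register the StarVector classes with the HF Auto* factories and push the model, its tokenizer, and the image processor (unless it is the local ImageTrainProcessor) to a private Hub repository."""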
    tokenizer = model.model.svg_transformer.tokenizer
    # Register the model for HF
    AutoConfig.register("starvector", StarVectorConfig)
    AutoModelForCausalLM.register(StarVectorConfig, StarVectorForCausalLM)
    StarVectorConfig.register_for_auto_class()
    StarVectorForCausalLM.register_for_auto_class("AutoModelForCausalLM")


    model.push_to_hub(new_model_name, commit_message=new_model_name, private=True, token=hf_token)
    tokenizer.push_to_hub(new_model_name, commit_message=new_model_name, private=True, token=hf_token)
    
    processor = model.model.image_encoder.processor
    from starvector.data.base import ImageTrainProcessor
    if not isinstance(processor, ImageTrainProcessor):
        processor.push_to_hub(new_model_name, commit_message=new_model_name, private=True, token=hf_token)

def get_optimizer(config, model):
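    """Build the optimizer (AdamW or Adafactor) specified in `config.training`."""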
    optimizer = config.training.optimizer
    if optimizer == "adamw":
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.training.lr,
            betas=(config.training.adam_beta1, config.training.adam_beta2),
            weight_decay=config.training.adam_weight_decay,
            eps=config.training.adam_epsilon,
        )
    elif optimizer == "adafactor":
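        # Use a fixed learning rate: disable Adafactor's relative step size and parameter scaling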
        optimizer = transformers.Adafactor(
            model.parameters(),
            lr=config.training.lr,
            relative_step=False,
            scale_parameter=False,
        )
    else:
        raise ValueError(f"Optimizer {optimizer} not supported")
    return optimizer


def init_distributed_mode(config):
    """
    Initializes torch distributed
    """
    assert all(var in os.environ for var in ['WORLD_SIZE', 'LOCAL_RANK', 'RANK'])

    world_size = int(os.environ['WORLD_SIZE'])
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ['LOCAL_RANK'])
    dist_url = 'env://'

    torch.cuda.set_device(local_rank)
    dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        rank, dist_url, local_rank), flush=True)
    torch.distributed.init_process_group(backend=dist_backend, init_method=dist_url,
                                         world_size=world_size, rank=rank)
    torch.distributed.barrier()
    print_only_on_master(rank == 0)

def print_only_on_master(is_master):
    """
    This function disables printing when not in master process
    """
    builtin_print = builtins.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        kwargs['flush'] = True
        if is_master or force:
            now = datetime.datetime.now().time()
            builtin_print('[{}] '.format(now), end='')  # print with time stamp
            builtin_print(*args, **kwargs)
    
    builtins.print = print
    
def setup_train_env_variables(config):
    """
    Set environment variables needed by FSDP and accelerate
    """
    mixed_precision = config.training.model_precision.lower()

    try:
        mixed_precision = PrecisionType(mixed_precision)
    except ValueError:
        raise ValueError(f"Unknown mixed_precision mode: {mixed_precision}. Choose between {PrecisionType.list()}.")

    os.environ["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision)

    if config.fsdp.enable:
        # We have to manually set some of the FSDP arguments as environment variables as these are not exposed by the FSDP Plugin API
        os.environ['ACCELERATE_USE_FSDP'] = "true"
        os.environ['FSDP_USE_ORIG_PARAMS'] = str(config.fsdp.use_orig_params).lower()
        os.environ['FSDP_FORWARD_PREFETCH'] = str(config.fsdp.forward_prefetch).lower()

        if config.fsdp.cpu_ram_efficient_loading and not config.fsdp.sync_module_states:
            raise ValueError("When using `fsdp.cpu_ram_efficient_loading` set `fsdp.sync_module_states` to `True`")

        os.environ['FSDP_CPU_RAM_EFFICIENT_LOADING'] = str(config.fsdp.cpu_ram_efficient_loading).lower()
        os.environ['FSDP_SYNC_MODULE_STATES'] = str(config.fsdp.sync_module_states).lower()

def load_fsdp_plugin(config, model):
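    """Build an accelerate FullyShardedDataParallelPlugin from the config, or return None when FSDP is disabled."""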
    if config.fsdp.enable:
        # get mixed precision dtype
        mixed_precision_dtype = {
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
            "tf32": torch.float32,
        }[config.training.model_precision]

        fsdp_plugin = FullyShardedDataParallelPlugin(
            state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
            auto_wrap_policy=model.model.get_fsdp_wrapping_policy(),
            mixed_precision_policy=MixedPrecision(
                param_dtype=mixed_precision_dtype,
                reduce_dtype=mixed_precision_dtype,
                buffer_dtype=mixed_precision_dtype,
            ),
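            # Map config shorthand to FSDP sharding strategies
            # (ddp = no sharding, sdp = shard grad/optimizer states, fsdp = full shard, hsdp = hybrid shard)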
            sharding_strategy={
                "sdp": ShardingStrategy.SHARD_GRAD_OP,
                "ddp": ShardingStrategy.NO_SHARD,
                "fsdp": ShardingStrategy.FULL_SHARD,
                "hsdp": ShardingStrategy.HYBRID_SHARD,
            }[config.fsdp.sharding_strategy],
            backward_prefetch=config.fsdp.backward_prefetch,
            cpu_offload=config.fsdp.cpu_offload,
        )
    else:
        fsdp_plugin = None
    
    return fsdp_plugin


def apply_gradient_checkpointing(model):
    """ Apply gradient checkpointing to Transformer cls of the LLM """

    def check_fn(submodule):
        return isinstance(submodule, model.model.svg_transformer.transformer_layer_cls)

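    # Wrap matching transformer blocks with non-reentrant activation checkpointing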
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=functools.partial(
            checkpoint_wrapper,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        ),
        check_fn=check_fn,
    )

    # Wait for all processes to finish
    torch.distributed.barrier()

    return model

def get_module_class_from_name(module, name):
    """
    Gets a class from a module by its name.

    Args:
        module (`torch.nn.Module`): The module to get the class from.
        name (`str`): The name of the class.
    """
    modules_children = list(module.children())
    if module.__class__.__name__ == name:
        return module.__class__
    elif len(modules_children) == 0:
        return
    else:
        for child_module in modules_children:
            module_class = get_module_class_from_name(child_module, name)
            if module_class is not None:
                return module_class