File size: 10,893 Bytes
35c1cfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import os
from pathlib import Path
from datetime import datetime
import torch
import time
from collections import OrderedDict

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    StateDictType,
    FullStateDictConfig,  # general model non-sharded, non-flattened params
    LocalStateDictConfig,  # flattened params, usable only by FSDP
    # ShardedStateDictConfig, # un-flattened param but shards, usable by other parallel schemes.
)

from torch.distributed.checkpoint import (
    FileSystemReader,
    FileSystemWriter,
    save_state_dict,
    load_state_dict,
)
from torch.distributed.checkpoint.default_planner import (
    DefaultSavePlanner,
    DefaultLoadPlanner,
)


from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
import torch.distributed.checkpoint as dist_cp
import torch.distributed as dist


import logging
logger = logging.getLogger(__name__)


def get_date_of_run():
    """Return a timestamp string used to make saved file names unique.

    The format is ``%Y-%m-%d-%I:%M:%S_%p`` (12-hour clock followed by the
    locale's AM/PM marker), e.g. ``2022-05-07-08:31:12_PM``. The value is
    also logged at INFO level before being returned.
    """
    stamp_format = "%Y-%m-%d-%I:%M:%S_%p"
    stamp = datetime.now().strftime(stamp_format)
    logger.info(f"--> current date and time of run = {stamp}")
    return stamp


# Build the FULL_STATE_DICT config once at import time so every save helper
# in this module reuses the same object instead of recreating it per call.
# offload_to_cpu=True / rank0_only=True: the gathered full state dict is meant
# to be materialized on CPU and only on rank 0 (callers below save only when
# rank == 0) -- see the FSDP FullStateDictConfig docs for exact semantics.
fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)


def load_model_sharded(model, rank, cfg):
    """Load FSDP sharded-state-dict weights into ``model`` in place.

    The checkpoint directory is
    ``<cwd>/<dist_checkpoint_root_folder>/<dist_checkpoint_folder>-<model_name>``.
    If that directory does not exist, nothing is loaded and the function
    returns early. Only rank 0 emits log messages. This should be called on
    every rank, because the sharded load is presumably a collective operation.

    Args:
        model: FSDP-wrapped model whose parameters are overwritten.
        rank: global rank of the calling process.
        cfg: config providing ``dist_checkpoint_root_folder``,
            ``dist_checkpoint_folder`` and ``model_name``.
    """
    is_main = rank == 0
    load_dir = Path.cwd() / (
        f"{cfg.dist_checkpoint_root_folder}/{cfg.dist_checkpoint_folder}-{cfg.model_name}"
    )

    if not load_dir.exists():
        if is_main:
            logger.info("No sharded_state_dict checkpoint directory found...skipping")
        return

    if is_main:
        logger.info(f"loading model from model path: {load_dir} ")
    reader = FileSystemReader(load_dir)

    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        # The model's current sharded state dict is the template that
        # dist_cp.load_state_dict fills in place.
        checkpoint = {"model": model.state_dict()}
        if is_main:
            keys = checkpoint.keys()
            logger.info(f" checkpoint key len = {len(keys)} and \n keys =  {keys}")

        dist_cp.load_state_dict(state_dict=checkpoint, storage_reader=reader)

        if is_main:
            logger.info("checkpoint after load_state_dict()")
            keys = checkpoint.keys()
            logger.info(f" checkpoint key len = {len(keys)} and \n keys =  {keys}")
        model.load_state_dict(checkpoint["model"])

    if is_main:
        logger.info(f"Sharded state checkpoint loaded from {load_dir}")


def save_model_and_optimizer_sharded(model, rank, cfg, optim=None):
    """Save model (and optionally optimizer) state as an FSDP sharded checkpoint.

    The checkpoint is written with ``torch.distributed.checkpoint`` to
    ``<cwd>/<dist_checkpoint_root_folder>/<dist_checkpoint_folder>-<model_name>``.
    The function ends with ``dist.barrier()``, so it must be called on every
    rank. Rank 0 logs the target directory and the elapsed wall time.

    Args:
        model: FSDP-wrapped model.
        rank: global rank of the calling process.
        cfg: config providing the checkpoint folder fields and ``model_name``.
        optim: optional optimizer; when given, its state is stored under
            the ``"optim"`` key next to ``"model"``.
    """
    save_dir = Path.cwd() / (
        f"{cfg.dist_checkpoint_root_folder}/{cfg.dist_checkpoint_folder}-{cfg.model_name}"
    )
    if rank == 0:
        logger.info(f"Saving model to {save_dir}")

    writer = dist_cp.FileSystemWriter(save_dir)
    start = time.perf_counter()

    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        payload = {"model": model.state_dict()}
        if optim is not None:
            # Collected inside the same SHARDED_STATE_DICT context as the
            # model state so both use the same state-dict type.
            payload["optim"] = FSDP.optim_state_dict(model, optim)

        dist_cp.save_state_dict(
            state_dict=payload,
            storage_writer=writer,
            planner=DefaultSavePlanner(),
        )

    dist.barrier()
    elapsed = time.perf_counter() - start
    if rank == 0:
        logger.info(f"Sharded state checkpoint saved to {save_dir}")
        logger.info(f"Checkpoint Time = {elapsed:.4f}\n")
def save_model_checkpoint(model, optimizer, rank, cfg, epoch=1):
    """Save the full (unsharded) model state dict to a single ``.pt`` file.

    Every calling rank enters the FULL_STATE_DICT context (configured by the
    module-level ``fullstate_save_policy``) to gather the state dict. Only
    rank 0 writes the file, to
    ``<cwd>/<dist_checkpoint_root_folder>/<dist_checkpoint_folder>-<model_name>/<model_name>-<epoch>.pt``.

    Args:
        model: FSDP-wrapped model.
        optimizer: unused; kept for call-site compatibility.
        rank: global rank of the calling process.
        cfg: config providing the checkpoint folder fields and ``model_name``.
        epoch: epoch number embedded in the file name.
    """
    with FSDP.state_dict_type(
        model, StateDictType.FULL_STATE_DICT, fullstate_save_policy
    ):
        full_state = model.state_dict()
        logger.info(f"saving process: rank {rank}  done w model state_dict\n")

    if rank != 0:
        return

    logger.info("--> saving model ...")
    save_dir = Path.cwd() / (
        f"{cfg.dist_checkpoint_root_folder}/{cfg.dist_checkpoint_folder}-{cfg.model_name}"
    )
    save_dir.mkdir(parents=True, exist_ok=True)
    save_full_path = f"{save_dir}/{cfg.model_name}-{epoch}.pt"

    torch.save(full_state, save_full_path)
    logger.info(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")

def save_model_checkpoint_deepspeed(model, cfg, checkpoint_name="checkpoint"):
    """Save a checkpoint via ``model.save_checkpoint`` under ``cfg.output_dir``.

    The target directory ``<cfg.output_dir>/<checkpoint_name>`` is created if
    needed. Frozen parameters are excluded (``exclude_frozen_parameters=True``).

    Args:
        model: object exposing ``save_checkpoint`` (presumably a DeepSpeed
            engine -- confirm against callers).
        cfg: config providing ``output_dir``.
        checkpoint_name: name of the sub-directory for this checkpoint.
    """
    logger.info("--> saving model ...")
    target_dir = os.path.join(cfg.output_dir, checkpoint_name)
    os.makedirs(target_dir, exist_ok=True)
    model.save_checkpoint(save_dir=target_dir, exclude_frozen_parameters=True)
    logger.info(f"encoder saved at {target_dir}")
      
def save_model_checkpoint_peft(
    model,
    optimizer,
    rank,
    cfg,
    checkpoint_name="checkpoint",
    save_trainable_only=True,
):
    """Save the model state dict to ``<cfg.output_dir>/<checkpoint_name>/model.pt``.

    When ``cfg.enable_ddp`` is truthy, ``model.module`` (the wrapped module)
    is saved instead of ``model`` itself.

    Args:
        model: model to save, possibly DDP-wrapped.
        optimizer: unused; kept for call-site compatibility.
        rank: unused; kept for call-site compatibility.
        cfg: config providing ``output_dir`` and ``enable_ddp``.
        checkpoint_name: name of the sub-directory for this checkpoint.
        save_trainable_only: if true, keep only entries whose parameter has
            ``requires_grad`` set; otherwise save the whole state dict.
    """
    logger.info("--> saving model ...")
    save_dir = os.path.join(cfg.output_dir, checkpoint_name)
    os.makedirs(save_dir, exist_ok=True)
    save_full_path = os.path.join(save_dir, "model.pt")

    inner = model.module if cfg.enable_ddp else model
    # NOTE(review): tensors are saved on whatever device they live on; nothing
    # here moves them to CPU.
    full_state = inner.state_dict()
    if save_trainable_only:
        to_save = OrderedDict(
            (name, full_state[name])
            for name, param in inner.named_parameters()
            if param.requires_grad
        )
    else:
        to_save = full_state
    torch.save(to_save, save_full_path)
    logger.info(f"encoder saved at {save_full_path}")
    
def save_model_checkpoint_peft_full_shard(model, optimizer, rank, cfg, epoch=0):
    """Gather a full FSDP state dict and save the trainable components.

    Every calling rank enters the FULL_STATE_DICT context (configured by the
    module-level ``fullstate_save_policy``). Rank 0 then writes into
    ``<cfg.output_dir>/<cfg.model_name>/<epoch + 1>``:

    * the ``llm.*`` entries via ``model.llm.save_pretrained``, only when
      ``cfg.freeze_llm`` is false;
    * ``model.pt`` containing the ``encoder.*`` entries (only when
      ``cfg.freeze_encoder`` is false) plus all ``encoder_projector.*`` entries.

    All ranks then wait on ``dist.barrier()`` before returning.

    Args:
        model: FSDP-wrapped model exposing an ``llm`` submodule.
        optimizer: unused; kept for call-site compatibility.
        rank: global rank of the calling process.
        cfg: config providing ``output_dir``, ``model_name``, ``freeze_llm``
            and ``freeze_encoder``.
        epoch: zero-based epoch; the output folder is named ``epoch + 1``.
    """
    with FSDP.state_dict_type(
        model, StateDictType.FULL_STATE_DICT, fullstate_save_policy
    ):
        full_state = model.state_dict()
        logger.info(f"saving process: rank {rank}  done w model state_dict\n")

    if rank == 0:
        logger.info("--> saving model ...")
        save_dir = os.path.join(cfg.output_dir, cfg.model_name, str(epoch + 1))
        os.makedirs(save_dir, exist_ok=True)

        if not cfg.freeze_llm:
            llm_dict = {k: v for k, v in full_state.items() if k.startswith("llm.")}
            # NOTE(review): these keys keep their "llm." prefix although they
            # are handed to the model.llm submodule -- confirm the matching
            # loader expects prefixed keys.
            model.llm.save_pretrained(save_directory=save_dir, state_dict=llm_dict)
            logger.info(f"llm saved at {save_dir}")

        save_full_path = os.path.join(save_dir, "model.pt")
        encoder_dict = {}
        if not cfg.freeze_encoder:
            encoder_dict.update(
                (k, v) for k, v in full_state.items() if k.startswith("encoder.")
            )
        # The projector is saved regardless of the freeze flags.
        encoder_dict.update(
            (k, v)
            for k, v in full_state.items()
            if k.startswith("encoder_projector.")
        )
        torch.save(encoder_dict, save_full_path)
        logger.info(f"encoder saved at {save_full_path}")

        logger.info(f"model checkpoint saved for epoch {epoch + 1}\n")

    dist.barrier()

def load_model_checkpoint(model, rank, cfg):
    """Load a full (unsharded) checkpoint into ``model`` on rank 0, on CPU.

    Must be called *before* the model is wrapped with FSDP. Ranks other than
    0 return immediately without touching ``model``.

    The checkpoint path is
    ``<cwd>/<cfg.checkpoint_folder>/<cfg.checkpoint_model_filename>``. If that
    file is missing, an INFO message is logged and nothing is loaded.

    Args:
        model: unwrapped model that receives the weights.
        rank: global rank of the calling process.
        cfg: config providing ``checkpoint_folder`` and
            ``checkpoint_model_filename``.
    """
    if rank != 0:
        return

    full_state_dict_model_path = (
        Path.cwd() / cfg.checkpoint_folder / cfg.checkpoint_model_filename
    )
    if not full_state_dict_model_path.is_file():
        logger.info(
            f"model checkpoint {full_state_dict_model_path} not present. Returning..."
        )
        return

    # map_location="cpu" keeps the loaded tensors in host memory, as this
    # function promises, even if the file was written from GPU tensors.
    # Without it, such tensors would be restored onto the GPU, or fail to
    # load on a CPU-only host.
    model_checkpoint = torch.load(full_state_dict_model_path, map_location="cpu")
    model.load_state_dict(model_checkpoint)

    logger.info("model checkpoint loaded to rank0 cpu")


def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
    """Consolidate the FSDP optimizer state and save it from rank 0.

    ``FSDP.full_optim_state_dict`` is called on every calling rank. Rank 0
    then writes the result to
    ``<cwd>/<dist_checkpoint_root_folder>/<dist_checkpoint_folder>-<model_name>/optimizer-<model_name>-<epoch>.pt``.

    Args:
        model: FSDP-wrapped model the optimizer belongs to.
        optimizer: optimizer whose state is saved.
        rank: global rank of the calling process.
        cfg: config providing the checkpoint folder fields and ``model_name``.
        epoch: epoch number embedded in the file name.
    """
    logger.info(f"--> optim state call on rank {rank}\n")

    # Pull all sharded optimizer states together (intended to land on rank 0).
    optim_state = FSDP.full_optim_state_dict(model, optimizer)
    logger.info(f"optim state dict ready on {rank} and len of {len(optim_state)}\n")

    if rank != 0:
        return

    save_dir = Path.cwd() / (
        f"{cfg.dist_checkpoint_root_folder}/{cfg.dist_checkpoint_folder}-{cfg.model_name}"
    )
    save_dir.mkdir(parents=True, exist_ok=True)
    opt_save_full_path = save_dir / f"optimizer-{cfg.model_name}-{epoch}.pt"

    logger.info("--> saving optimizer state...")
    torch.save(optim_state, opt_save_full_path)
    logger.info(f"--> saved {opt_save_full_path} to disk")


def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank):
    """Load a full optimizer-state checkpoint and scatter it across ranks.

    Only rank 0 reads the file. ``FSDP.scatter_full_optim_state_dict`` then
    hands each rank its own shard. Call this on every rank, because the
    scatter is a collective operation.

    Args:
        model: FSDP-wrapped model the optimizer state belongs to.
        optimizer_checkpoint_path: path (``str`` or ``pathlib.Path``) to a
            full optimizer state dict, e.g. one written by
            ``save_optimizer_checkpoint``.
        rank: global rank of the calling process.

    Returns:
        This rank's sharded optimizer state dict, to be applied with
        ``optimizer.load_state_dict(...)``; ``None`` if the checkpoint file
        does not exist.
    """
    # Accept plain strings as well as Path objects; .is_file() needs a Path.
    optimizer_checkpoint_path = Path(optimizer_checkpoint_path)
    if not optimizer_checkpoint_path.is_file():
        logger.info(
            f"warning - optimizer checkpoint not present {optimizer_checkpoint_path}. Returning. "
        )
        return None

    # Only rank 0 holds the full state; the other ranks pass None.
    full_osd = torch.load(optimizer_checkpoint_path) if rank == 0 else None

    sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model)
    logger.info(f"optimizer shard loaded on rank {rank}")
    # Scattering alone does not modify any optimizer, so return the shard
    # and let the caller apply it.
    return sharded_osd

def load_sharded_model_single_gpu(model, model_path):
    """Load a ``torch.distributed.checkpoint`` sharded checkpoint in one process.

    ``no_dist=True`` is passed so the load runs without cross-rank
    coordination (single-process use).

    Args:
        model: unwrapped model; its current state dict is the load template
            and it is updated in place.
        model_path: directory containing the sharded checkpoint.

    Returns:
        The same ``model`` instance, now holding the loaded weights.
    """
    # A single reader is built and actually used. Previously one reader was
    # created and discarded while a second one was constructed inline.
    reader = FileSystemReader(model_path)

    # The model's current state dict is the template filled in place.
    state_dict = {"model": model.state_dict()}

    dist_cp.load_state_dict(
        state_dict=state_dict,
        storage_reader=reader,
        no_dist=True,
    )

    model.load_state_dict(state_dict["model"])

    logger.info(f"Sharded state checkpoint loaded from {model_path}")
    return model