# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import contextlib
import json
import logging
import os
import tempfile
from pathlib import Path
from typing import Optional, Union

import torch
from composer.callbacks.utils import create_interval_scheduler
from composer.core import Callback, Event, State, Time
from composer.core.state import fsdp_state_dict_type_context
from composer.loggers import Logger
from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
from composer.models import HuggingFaceModel
from composer.utils import dist, format_name_with_dist_and_time, parse_uri
from transformers import PreTrainedTokenizerBase

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
from llmfoundry.utils.huggingface_hub_utils import \
    edit_files_for_hf_compatibility

log = logging.getLogger(__name__)


class HuggingFaceCheckpointer(Callback):
    """Save a huggingface formatted checkpoint during training.

    Args:
        save_folder (str): Top level folder to save checkpoints to (can be a URI). It is likely that
            this would be the same as your Trainer's ``save_folder``.
        save_interval (Union[str, int, Time]): The interval describing how often checkpoints should be
            saved. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`.
            Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
            :attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`.
        huggingface_folder_name (str): Folder to save each checkpoint under (can be a format string). Default is ``ba{batch}``.
        precision (str): The precision to save the model in. Default is ``float32``. Options are ``bfloat16``, ``float16``, or ``float32``.
        overwrite (bool): Whether to overwrite previous checkpoints.
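
    Example:
        A minimal usage sketch; the save folder URI, interval, and trainer
        setup below are illustrative, not defaults::

            from composer import Trainer

            checkpointer = HuggingFaceCheckpointer(
                save_folder='s3://my-bucket/my-run',  # hypothetical bucket/path
                save_interval='1000ba',
                precision='bfloat16',
            )
            trainer = Trainer(
                model=model,  # must be a composer HuggingFaceModel
                callbacks=[checkpointer],
            )

        Checkpoints land under ``<save_folder>/huggingface/<folder_name>`` and
        can be reloaded with ``transformers``; MPT checkpoints need
        ``trust_remote_code=True`` because they ship custom modeling code::

            from transformers import AutoModelForCausalLM

            model = AutoModelForCausalLM.from_pretrained(
                '/path/to/huggingface/ba1000',  # hypothetical local copy
                trust_remote_code=True,
            )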
    """

    def __init__(
        self,
        save_folder: str,
        save_interval: Union[str, int, Time],
        huggingface_folder_name: str = 'ba{batch}',
        precision: str = 'float32',
        overwrite: bool = False,
    ):
        self.backend, self.bucket_name, self.save_dir_format_str = parse_uri(
            save_folder)
        self.overwrite = overwrite
        self.precision = precision
        self.dtype = {
            'float32': torch.float32,
            'float16': torch.float16,
            'bfloat16': torch.bfloat16,
        }[precision]
        self.huggingface_folder_name_fstr = os.path.join(
            'huggingface', huggingface_folder_name)
        self.check_interval = create_interval_scheduler(
            save_interval, include_end_of_training=True)
        self.upload_to_object_store = (self.backend != '')
        if self.upload_to_object_store:
            self.remote_ud = RemoteUploaderDownloader(
                bucket_uri=f'{self.backend}://{self.bucket_name}',
                num_concurrent_uploads=4)
        else:
            self.remote_ud = None

        self.last_checkpoint_batch: Optional[Time] = None

    def run_event(self, event: Event, state: State, logger: Logger) -> None:
        # The interval scheduler only returns True at the appropriate events
        if state.get_elapsed_duration() is not None and self.check_interval(
                state,
                event) and self.last_checkpoint_batch != state.timestamp.batch:
            self._save_checkpoint(state, logger)
        elif event == Event.INIT:
            if not isinstance(state.model, HuggingFaceModel):
                raise ValueError(
                    '`HuggingFaceCheckpointer` is only compatible with `HuggingFaceModel`s. '
                    + f'Got {type(state.model)} instead.')
            if self.upload_to_object_store and self.remote_ud is not None:
                self.remote_ud.init(state, logger)
                state.callbacks.append(self.remote_ud)

    def _save_checkpoint(self, state: State, logger: Logger):
        del logger  # unused

        self.last_checkpoint_batch = state.timestamp.batch

        log.info('Saving HuggingFace formatted checkpoint')

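        # Register the MPT config/model classes with transformers' Auto*
        # machinery so the saved checkpoint can later be loaded back via
        # AutoConfig/AutoModelForCausalLM (with trust_remote_code=True, since
        # MPT ships custom modeling code).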
        from transformers.models.auto.configuration_auto import CONFIG_MAPPING
        CONFIG_MAPPING._extra_content['mpt'] = MPTConfig
        MPTConfig.register_for_auto_class()
        MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')

        assert isinstance(state.model, HuggingFaceModel)

        save_dir = format_name_with_dist_and_time(
            str(
                Path(self.save_dir_format_str) /
                self.huggingface_folder_name_fstr), state.run_name,
            state.timestamp)
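        # When uploading to an object store, stage files in a scratch
        # temporary directory; otherwise write directly into the local
        # save_dir.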
        dir_context_mgr = tempfile.TemporaryDirectory(
        ) if self.upload_to_object_store else contextlib.nullcontext(
            enter_result=save_dir)

        with dir_context_mgr as temp_save_dir:
            assert isinstance(temp_save_dir,
                              str)  # pyright doesn't know about enter_result

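            # Under FSDP, this context gathers the sharded weights into a
            # full (unsharded) state dict that can be saved in HF format.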
            with fsdp_state_dict_type_context(state.model.model,
                                              state_dict_type='full'):
                state_dict = state.model.model.state_dict()

                # convert the state dict to the requested precision
                for k, v in state_dict.items():
                    if isinstance(v, torch.Tensor):
                        state_dict[k] = v.to(dtype=self.dtype)

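            # Only global rank 0 writes checkpoint files to disk; the other
            # ranks wait at the barrier below.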
            if dist.get_global_rank() == 0:
                # We raise above if the model is not a HuggingFaceModel, so this assert is safe
                assert hasattr(state.model.model, 'save_pretrained')
                state.model.model.save_pretrained(temp_save_dir,
                                                  state_dict=state_dict)

                if state.model.tokenizer is not None:
                    assert isinstance(state.model.tokenizer,
                                      PreTrainedTokenizerBase)
                    state.model.tokenizer.save_pretrained(temp_save_dir)

                # Only need to edit files for MPT because it has custom code
                if state.model.model.config.model_type == 'mpt':
                    edit_files_for_hf_compatibility(temp_save_dir)

                with open(os.path.join(temp_save_dir, 'config.json'), 'r') as f:
                    edited_config = json.load(f)

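                # Default the saved config to the dependency-free 'torch'
                # attention implementation and CPU init, so the checkpoint
                # loads without the optional flash/triton attention packages.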
                if state.model.model.config.model_type == 'mpt':
                    edited_config['attn_config']['attn_impl'] = 'torch'
                    edited_config['init_device'] = 'cpu'

                edited_config['torch_dtype'] = self.precision
                with open(os.path.join(temp_save_dir, 'config.json'), 'w') as f:
                    json.dump(edited_config, f, indent=4)

                if self.upload_to_object_store:
                    assert self.remote_ud is not None
                    log.info(
                        f'Uploading HuggingFace formatted checkpoint to {self.backend}://{self.bucket_name}/{save_dir}'
                    )
                    for filename in os.listdir(temp_save_dir):
                        self.remote_ud.upload_file(
                            state=state,
                            remote_file_name=os.path.join(save_dir, filename),
                            file_path=Path(os.path.join(temp_save_dir,
                                                        filename)),
                            overwrite=self.overwrite,
                        )

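        # Block until rank 0 has finished writing (and queueing uploads for)
        # the checkpoint, so no rank races ahead.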
        dist.barrier()