# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import contextlib
import json
import logging
import os
import tempfile
from pathlib import Path
from typing import Optional, Union

import torch
from composer.callbacks.utils import create_interval_scheduler
from composer.core import Callback, Event, State, Time
from composer.core.state import fsdp_state_dict_type_context
from composer.loggers import Logger
from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
from composer.models import HuggingFaceModel
from composer.utils import dist, format_name_with_dist_and_time, parse_uri
from transformers import PreTrainedTokenizerBase

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
from llmfoundry.utils.huggingface_hub_utils import \
    edit_files_for_hf_compatibility

log = logging.getLogger(__name__)


class HuggingFaceCheckpointer(Callback):
"""Save a huggingface formatted checkpoint during training.
Args:
save_folder (str): Top level folder to save checkpoints to (can be a URI). It is likely that
this would be the same as your save_folder.
save_interval: Union[str, int, Time]: The interval describing how often checkpoints should be
saved. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`.
Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
:attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`.
huggingface_folder_name (str): Folder to save each checkpoint under (can be a format string). Default is ``ba{batch}``.
precision: The precision to save the model in. Default is ``float32``. Options are ``bfloat16``, ``float16``, or ``float32``.
overwrite (bool): Whether to overwrite previous checkpoints.
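
    Example:
        A minimal usage sketch (illustrative only; ``model`` and ``train_dataloader`` are
        assumed to be defined elsewhere, and the bucket URI is hypothetical)::

            from composer import Trainer

            hf_checkpointer = HuggingFaceCheckpointer(
                save_folder='s3://my-bucket/my-run/checkpoints',
                save_interval='1000ba',
                precision='bfloat16',
            )
            trainer = Trainer(
                model=model,  # a composer HuggingFaceModel, e.g. ComposerMPTCausalLM
                train_dataloader=train_dataloader,
                max_duration='10000ba',
                callbacks=[hf_checkpointer],
            )
            trainer.fit()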
"""
def __init__(
        self,
        save_folder: str,
        save_interval: Union[str, int, Time],
        huggingface_folder_name: str = 'ba{batch}',
        precision: str = 'float32',
        overwrite: bool = False,
    ):
        self.backend, self.bucket_name, self.save_dir_format_str = parse_uri(
            save_folder)
        self.overwrite = overwrite
        self.precision = precision
        self.dtype = {
            'float32': torch.float32,
            'float16': torch.float16,
            'bfloat16': torch.bfloat16,
        }[precision]
        self.huggingface_folder_name_fstr = os.path.join(
            'huggingface', huggingface_folder_name)
        self.check_interval = create_interval_scheduler(
            save_interval, include_end_of_training=True)
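        # A non-empty backend means ``save_folder`` was a remote URI, so checkpoint
        # files will be uploaded to object storage via a RemoteUploaderDownloader.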
        self.upload_to_object_store = (self.backend != '')
        if self.upload_to_object_store:
            self.remote_ud = RemoteUploaderDownloader(
                bucket_uri=f'{self.backend}://{self.bucket_name}',
                num_concurrent_uploads=4)
        else:
            self.remote_ud = None

        self.last_checkpoint_batch: Optional[Time] = None

    def run_event(self, event: Event, state: State, logger: Logger) -> None:
        # The interval scheduler handles only returning True for the appropriate events
        if state.get_elapsed_duration() is not None and self.check_interval(
                state,
                event) and self.last_checkpoint_batch != state.timestamp.batch:
            self._save_checkpoint(state, logger)
        elif event == Event.INIT:
            if not isinstance(state.model, HuggingFaceModel):
                raise ValueError(
                    f'`HuggingFaceCheckpointer` is only compatible with `HuggingFaceModel`s. '
                    + f'Got {type(state.model)} instead.')
            if self.upload_to_object_store and self.remote_ud is not None:
                self.remote_ud.init(state, logger)
                state.callbacks.append(self.remote_ud)

    def _save_checkpoint(self, state: State, logger: Logger):
        del logger  # unused

        self.last_checkpoint_batch = state.timestamp.batch

        log.info('Saving HuggingFace formatted checkpoint')

        from transformers.models.auto.configuration_auto import CONFIG_MAPPING
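        # Register the MPT classes with the transformers Auto* machinery so the saved
        # checkpoint can later be reloaded via AutoConfig / AutoModelForCausalLM
        # (with trust_remote_code=True).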
        CONFIG_MAPPING._extra_content['mpt'] = MPTConfig
        MPTConfig.register_for_auto_class()
        MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')

        assert isinstance(state.model, HuggingFaceModel)

        save_dir = format_name_with_dist_and_time(
            str(
                Path(self.save_dir_format_str) /
                self.huggingface_folder_name_fstr), state.run_name,
            state.timestamp)
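        # When uploading to object storage, write the checkpoint into a temporary
        # directory first; otherwise, write directly into the local save_dir.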
        dir_context_mgr = tempfile.TemporaryDirectory(
        ) if self.upload_to_object_store else contextlib.nullcontext(
            enter_result=save_dir)

        with dir_context_mgr as temp_save_dir:
            assert isinstance(temp_save_dir,
                              str)  # pyright doesn't know about enter_result
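
            # Gather a full (unsharded) state dict so it can be saved in Hugging Face
            # format, even when the model is wrapped with FSDP.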
            with fsdp_state_dict_type_context(state.model.model,
                                              state_dict_type='full'):
                state_dict = state.model.model.state_dict()

                # convert the state dict to the requested precision
                for k, v in state_dict.items():
                    if isinstance(v, torch.Tensor):
                        state_dict[k] = v.to(dtype=self.dtype)
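
            # Only global rank 0 writes the checkpoint files; the other ranks only
            # participate in gathering the full state dict above.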
            if dist.get_global_rank() == 0:
                # We raise above if the model is not a HuggingFaceModel, so this assert is safe
                assert hasattr(state.model.model, 'save_pretrained')
                state.model.model.save_pretrained(temp_save_dir,
                                                  state_dict=state_dict)

                if state.model.tokenizer is not None:
                    assert isinstance(state.model.tokenizer,
                                      PreTrainedTokenizerBase)
                    state.model.tokenizer.save_pretrained(temp_save_dir)

                # Only need to edit files for MPT because it has custom code
                if state.model.model.config.model_type == 'mpt':
                    edit_files_for_hf_compatibility(temp_save_dir)

                with open(os.path.join(temp_save_dir, 'config.json'), 'r') as f:
                    edited_config = json.load(f)
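
                # Edit the saved config so the checkpoint loads cleanly outside of
                # llm-foundry: for MPT, default to the torch attention implementation
                # and CPU initialization; for all models, record the save precision.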
                if state.model.model.config.model_type == 'mpt':
                    edited_config['attn_config']['attn_impl'] = 'torch'
                    edited_config['init_device'] = 'cpu'

                edited_config['torch_dtype'] = self.precision

                with open(os.path.join(temp_save_dir, 'config.json'), 'w') as f:
                    json.dump(edited_config, f, indent=4)

                if self.upload_to_object_store:
                    assert self.remote_ud is not None
                    # TODO change to log after other pr
                    log.info(
                        f'Uploading HuggingFace formatted checkpoint to {self.backend}://{self.bucket_name}/{save_dir}'
                    )
                    for filename in os.listdir(temp_save_dir):
                        self.remote_ud.upload_file(
                            state=state,
                            remote_file_name=os.path.join(save_dir, filename),
                            file_path=Path(os.path.join(temp_save_dir,
                                                        filename)),
                            overwrite=self.overwrite,
                        )
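
        # Synchronize all ranks so that none of them races ahead while rank 0 is
        # still writing files (the remote uploader handles the actual uploads in
        # the background).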
        dist.barrier()