crystal-technologies's picture
Upload 303 files
de4ade4
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
import contextlib
import json
import logging
import os
import tempfile
from pathlib import Path
from typing import Optional, Union
import torch
from composer.callbacks.utils import create_interval_scheduler
from composer.core import Callback, Event, State, Time
from composer.core.state import fsdp_state_dict_type_context
from composer.loggers import Logger
from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
from composer.models import HuggingFaceModel
from composer.utils import dist, format_name_with_dist_and_time, parse_uri
from transformers import PreTrainedTokenizerBase
from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
from llmfoundry.utils.huggingface_hub_utils import \
edit_files_for_hf_compatibility
log = logging.getLogger(__name__)
class HuggingFaceCheckpointer(Callback):
"""Save a huggingface formatted checkpoint during training.
Args:
save_folder (str): Top level folder to save checkpoints to (can be a URI). It is likely that
this would be the same as your save_folder.
save_interval: Union[str, int, Time]: The interval describing how often checkpoints should be
saved. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`.
Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
:attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`.
huggingface_folder_name (str): Folder to save each checkpoint under (can be a format string). Default is ``ba{batch}``.
precision: The precision to save the model in. Default is ``float32``. Options are ``bfloat16``, ``float16``, or ``float32``.
overwrite (bool): Whether to overwrite previous checkpoints.
"""
def __init__(
self,
save_folder: str,
save_interval: Union[str, int, Time],
huggingface_folder_name: str = 'ba{batch}',
precision: str = 'float32',
overwrite: bool = False,
):
self.backend, self.bucket_name, self.save_dir_format_str = parse_uri(
save_folder)
self.overwrite = overwrite
self.precision = precision
self.dtype = {
'float32': torch.float32,
'float16': torch.float16,
'bfloat16': torch.bfloat16,
}[precision]
self.huggingface_folder_name_fstr = os.path.join(
'huggingface', huggingface_folder_name)
self.check_interval = create_interval_scheduler(
save_interval, include_end_of_training=True)
self.upload_to_object_store = (self.backend != '')
if self.upload_to_object_store:
self.remote_ud = RemoteUploaderDownloader(
bucket_uri=f'{self.backend}://{self.bucket_name}',
num_concurrent_uploads=4)
else:
self.remote_ud = None
self.last_checkpoint_batch: Optional[Time] = None
def run_event(self, event: Event, state: State, logger: Logger) -> None:
# The interval scheduler handles only returning True for the appropriate events
if state.get_elapsed_duration() is not None and self.check_interval(
state,
event) and self.last_checkpoint_batch != state.timestamp.batch:
self._save_checkpoint(state, logger)
elif event == Event.INIT:
if not isinstance(state.model, HuggingFaceModel):
raise ValueError(
f'`HuggingFaceCheckpointer` is only compatible with `HuggingFaceModel`s. '
+ f'Got {type(state.model)} instead.')
if self.upload_to_object_store and self.remote_ud is not None:
self.remote_ud.init(state, logger)
state.callbacks.append(self.remote_ud)
def _save_checkpoint(self, state: State, logger: Logger):
del logger # unused
self.last_checkpoint_batch = state.timestamp.batch
log.info('Saving HuggingFace formatted checkpoint')
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
CONFIG_MAPPING._extra_content['mpt'] = MPTConfig
MPTConfig.register_for_auto_class()
MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')
assert isinstance(state.model, HuggingFaceModel)
save_dir = format_name_with_dist_and_time(
str(
Path(self.save_dir_format_str) /
self.huggingface_folder_name_fstr), state.run_name,
state.timestamp)
dir_context_mgr = tempfile.TemporaryDirectory(
) if self.upload_to_object_store else contextlib.nullcontext(
enter_result=save_dir)
with dir_context_mgr as temp_save_dir:
assert isinstance(temp_save_dir,
str) # pyright doesn't know about enter_result
with fsdp_state_dict_type_context(state.model.model,
state_dict_type='full'):
state_dict = state.model.model.state_dict()
# convert the state dict to the requested precision
for k, v in state_dict.items():
if isinstance(v, torch.Tensor):
state_dict[k] = v.to(dtype=self.dtype)
if dist.get_global_rank() == 0:
# We raise above if the model is not a HuggingFaceModel, so this assert is safe
assert hasattr(state.model.model, 'save_pretrained')
state.model.model.save_pretrained(temp_save_dir,
state_dict=state_dict)
if state.model.tokenizer is not None:
assert isinstance(state.model.tokenizer,
PreTrainedTokenizerBase)
state.model.tokenizer.save_pretrained(temp_save_dir)
# Only need to edit files for MPT because it has custom code
if state.model.model.config.model_type == 'mpt':
edit_files_for_hf_compatibility(temp_save_dir)
with open(os.path.join(temp_save_dir, 'config.json'), 'r') as f:
edited_config = json.load(f)
if state.model.model.config.model_type == 'mpt':
edited_config['attn_config']['attn_impl'] = 'torch'
edited_config['init_device'] = 'cpu'
edited_config['torch_dtype'] = self.precision
with open(os.path.join(temp_save_dir, 'config.json'), 'w') as f:
json.dump(edited_config, f, indent=4)
if self.upload_to_object_store:
assert self.remote_ud is not None
# TODO change to log after other pr
log.info(
f'Uploading HuggingFace formatted checkpoint to {self.backend}://{self.bucket_name}/{save_dir}'
)
for filename in os.listdir(temp_save_dir):
self.remote_ud.upload_file(
state=state,
remote_file_name=os.path.join(save_dir, filename),
file_path=Path(os.path.join(temp_save_dir,
filename)),
overwrite=self.overwrite,
)
dist.barrier()