import contextlib
import json
import logging
import os
import tempfile
from pathlib import Path
from typing import Optional, Union

import torch
from composer.callbacks.utils import create_interval_scheduler
from composer.core import Callback, Event, State, Time
from composer.core.state import fsdp_state_dict_type_context
from composer.loggers import Logger
from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
from composer.models import HuggingFaceModel
from composer.utils import dist, format_name_with_dist_and_time, parse_uri
from transformers import PreTrainedTokenizerBase

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
from llmfoundry.utils.huggingface_hub_utils import \
    edit_files_for_hf_compatibility

log = logging.getLogger(__name__)


class HuggingFaceCheckpointer(Callback):
    """Save a Hugging Face formatted checkpoint during training.

    Args:
        save_folder (str): Top level folder to save checkpoints to (can be a URI). It is likely
            that this would be the same as the ``save_folder`` passed to the :class:`~composer.Trainer`.
        save_interval (Union[str, int, Time]): The interval describing how often checkpoints
            should be saved. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`.
            Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
            :attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`.
        huggingface_folder_name (str): Folder to save each checkpoint under (can be a format
            string). Default is ``ba{batch}``.
        precision (str): The precision to save the model in. Default is ``float32``. Options are
            ``bfloat16``, ``float16``, or ``float32``.
        overwrite (bool): Whether to overwrite previous checkpoints. Default is ``False``.
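
    Example:
        A minimal usage sketch, assuming ``model`` is already a
        :class:`.HuggingFaceModel` (the bucket URI below is hypothetical):

        .. code-block:: python

            from composer import Trainer

            trainer = Trainer(
                model=model,
                callbacks=[
                    HuggingFaceCheckpointer(
                        save_folder='s3://my-bucket/checkpoints',
                        save_interval='1000ba',
                        precision='bfloat16',
                    ),
                ],
            )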
    """

    def __init__(
        self,
        save_folder: str,
        save_interval: Union[str, int, Time],
        huggingface_folder_name: str = 'ba{batch}',
        precision: str = 'float32',
        overwrite: bool = False,
    ):
        self.backend, self.bucket_name, self.save_dir_format_str = parse_uri(
            save_folder)
        self.overwrite = overwrite
        self.precision = precision
        self.dtype = {
            'float32': torch.float32,
            'float16': torch.float16,
            'bfloat16': torch.bfloat16,
        }[precision]
        self.huggingface_folder_name_fstr = os.path.join(
            'huggingface', huggingface_folder_name)
        self.check_interval = create_interval_scheduler(
            save_interval, include_end_of_training=True)
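
        # A non-empty backend means ``save_folder`` was a URI, so checkpoints
        # are staged locally and then uploaded to the object store.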
        self.upload_to_object_store = (self.backend != '')
        if self.upload_to_object_store:
            self.remote_ud = RemoteUploaderDownloader(
                bucket_uri=f'{self.backend}://{self.bucket_name}',
                num_concurrent_uploads=4)
        else:
            self.remote_ud = None

        self.last_checkpoint_batch: Optional[Time] = None

    def run_event(self, event: Event, state: State, logger: Logger) -> None:
        # Save on the configured interval, and record the batch so the
        # end-of-training event does not write a duplicate checkpoint.
        if state.get_elapsed_duration() is not None and self.check_interval(
                state,
                event) and self.last_checkpoint_batch != state.timestamp.batch:
            self._save_checkpoint(state, logger)
        elif event == Event.INIT:
            if not isinstance(state.model, HuggingFaceModel):
                raise ValueError(
                    f'`HuggingFaceCheckpointer` is only compatible with `HuggingFaceModel`s. '
                    + f'Got {type(state.model)} instead.')
            if self.upload_to_object_store and self.remote_ud is not None:
                self.remote_ud.init(state, logger)
                state.callbacks.append(self.remote_ud)

    def _save_checkpoint(self, state: State, logger: Logger):
        del logger  # unused

        self.last_checkpoint_batch = state.timestamp.batch

        log.info('Saving HuggingFace formatted checkpoint')

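        # Register the MPT classes with transformers' Auto* machinery so the
        # saved checkpoint can be reloaded via AutoModelForCausalLM.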
        from transformers.models.auto.configuration_auto import CONFIG_MAPPING
        CONFIG_MAPPING._extra_content['mpt'] = MPTConfig
        MPTConfig.register_for_auto_class()
        MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')

        assert isinstance(state.model, HuggingFaceModel)

        save_dir = format_name_with_dist_and_time(
            str(
                Path(self.save_dir_format_str) /
                self.huggingface_folder_name_fstr), state.run_name,
            state.timestamp)
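
        # Stage the checkpoint in a temporary directory when uploading to an
        # object store; otherwise write directly into the local save directory.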
        dir_context_mgr = tempfile.TemporaryDirectory(
        ) if self.upload_to_object_store else contextlib.nullcontext(
            enter_result=save_dir)

        with dir_context_mgr as temp_save_dir:
            assert isinstance(temp_save_dir,
                              str)  # pyright doesn't know about enter_result

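            # Gather the full (unsharded) state dict; under FSDP every rank
            # must enter this context even though only rank 0 saves.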
            with fsdp_state_dict_type_context(state.model.model,
                                              state_dict_type='full'):
                state_dict = state.model.model.state_dict()

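                # Convert the gathered state dict to the requested precision.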
                for k, v in state_dict.items():
                    if isinstance(v, torch.Tensor):
                        state_dict[k] = v.to(dtype=self.dtype)

            if dist.get_global_rank() == 0:
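                # Only rank 0 writes the checkpoint files; the other ranks
                # participate solely in the state dict gather above.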
                assert hasattr(state.model.model, 'save_pretrained')
                state.model.model.save_pretrained(temp_save_dir,
                                                  state_dict=state_dict)

                if state.model.tokenizer is not None:
                    assert isinstance(state.model.tokenizer,
                                      PreTrainedTokenizerBase)
                    state.model.tokenizer.save_pretrained(temp_save_dir)

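                # Only MPT checkpoints need file edits, because MPT ships
                # custom modeling code alongside its weights.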
                if state.model.model.config.model_type == 'mpt':
                    edit_files_for_hf_compatibility(temp_save_dir)

                with open(os.path.join(temp_save_dir, 'config.json'), 'r') as f:
                    edited_config = json.load(f)

                # Reset MPT-specific config fields so the saved checkpoint can
                # be loaded on CPU without custom attention kernels.
                if state.model.model.config.model_type == 'mpt':
                    edited_config['attn_config']['attn_impl'] = 'torch'
                    edited_config['init_device'] = 'cpu'

                edited_config['torch_dtype'] = self.precision
                with open(os.path.join(temp_save_dir, 'config.json'), 'w') as f:
                    json.dump(edited_config, f, indent=4)

                if self.upload_to_object_store:
                    assert self.remote_ud is not None
                    log.info(
                        f'Uploading HuggingFace formatted checkpoint to {self.backend}://{self.bucket_name}/{save_dir}'
                    )
                    for filename in os.listdir(temp_save_dir):
                        self.remote_ud.upload_file(
                            state=state,
                            remote_file_name=os.path.join(save_dir, filename),
                            file_path=Path(os.path.join(temp_save_dir,
                                                        filename)),
                            overwrite=self.overwrite,
                        )

        dist.barrier()