import contextlib
import os
import tempfile
from pathlib import Path

import torch
from composer.core import Callback, State
from composer.core.state import (fsdp_get_optim_state_dict,
                                 fsdp_state_dict_type_context)
from composer.loggers import Logger
from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
from composer.utils import (dist, format_name_with_dist_and_time, parse_uri,
                            reproducibility)


class MonolithicCheckpointSaver(Callback):
"""Save a monolithic checkpoint every N batches. |
|
|
|
Args: |
|
save_folder (str): Folder to save checkpoints to (can be a URI) |
|
filename (str): Filename to save checkpoints to. |
|
batch_interval (int): Number of batches between checkpoints. |
|
overwrite (bool): Whether to overwrite previous checkpoints. |
|
keep_optimizer(bool): Whether to save the optimizer state in the monolithic checkpoint. |
|
""" |
|
|
|
def __init__(self, |
|
save_folder: str, |
|
batch_interval: int, |
|
filename: str = 'ep{epoch}-ba{batch}.pt', |
|
overwrite: bool = False, |
|
keep_optimizers: bool = False): |
|
self.backend, self.bucket_name, self.save_dir_format_str = parse_uri( |
|
save_folder) |
|
self.filename_format_str = filename |
|
self.batch_interval = batch_interval |
|
self.upload_to_object_store = (self.backend != '') |
|
self.overwrite = overwrite |
|
self.keep_optimizers = keep_optimizers |
|
if self.upload_to_object_store: |
|
self.remote_ud = RemoteUploaderDownloader( |
|
bucket_uri=f'{self.backend}://{self.bucket_name}') |
|
else: |
|
self.remote_ud = None |
|
|
|
def init(self, state: State, logger: Logger) -> None: |
|
if self.upload_to_object_store and self.remote_ud is not None: |
|
self.remote_ud.init(state, logger) |
|
|
|
|
|
state.callbacks.append(self.remote_ud) |
|
|
|
def batch_checkpoint(self, state: State, logger: Logger) -> None: |
|
if state.timestamp.batch.value % self.batch_interval == 0: |
|
self._save_checkpoint(state, logger) |
|
|
|
def fit_end(self, state: State, logger: Logger) -> None: |
|
if state.timestamp.batch.value % self.batch_interval != 0: |
|
self._save_checkpoint(state, logger) |
|
|
|
def _save_checkpoint(self, state: State, logger: Logger) -> None: |
|
del logger |
|
|
|
filename = format_name_with_dist_and_time(self.filename_format_str, |
|
state.run_name, |
|
state.timestamp) |
|
save_dir = format_name_with_dist_and_time(self.save_dir_format_str, |
|
state.run_name, |
|
state.timestamp) |
|
dir_context_mgr = tempfile.TemporaryDirectory( |
|
) if self.upload_to_object_store else contextlib.nullcontext( |
|
enter_result=save_dir) |
|
with dir_context_mgr as temp_save_dir: |
|
|
|
assert isinstance(temp_save_dir, str) |
|
|
|
save_path = str(Path(temp_save_dir) / Path(filename)) |
|
dirname = os.path.dirname(save_path) |
|
if dirname: |
|
os.makedirs(dirname, exist_ok=True) |
|
state_dict = { |
|
'state': state.state_dict(), |
|
'rng': reproducibility.get_rng_state() |
|
} |
|
|
|
state_dict['state'].pop('optimizers') |
|
state_dict['state'].pop('model') |
|
|
|
|
|
with fsdp_state_dict_type_context(state.model, |
|
state_dict_type='full'): |
|
state_dict['state']['model'] = state.model.state_dict() |
|
|
|
|
|
if self.keep_optimizers: |
|
optimizer = state.optimizers[0] |
|
state_dict['state']['optimizers'] = { |
|
type(optimizer).__qualname__: |
|
fsdp_get_optim_state_dict(state.model, |
|
optimizer, |
|
state_dict_type='full') |
|
} |
|
if dist.get_global_rank() == 0: |
|
torch.save(state_dict, save_path) |
|
|
|
if self.upload_to_object_store and self.remote_ud is not None and dist.get_global_rank( |
|
) == 0: |
|
remote_file_name = str(Path(save_dir) / Path(filename)) |
|
self.remote_ud.upload_file(state=state, |
|
remote_file_name=remote_file_name, |
|
file_path=Path(save_path), |
|
overwrite=self.overwrite) |
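
# A minimal usage sketch (not part of this module): registering the callback
# with a Composer ``Trainer``. The model, dataloader, and bucket URI below are
# placeholders, not names defined here.
#
#     from composer import Trainer
#
#     trainer = Trainer(
#         model=model,
#         train_dataloader=train_dataloader,
#         max_duration='1ep',
#         callbacks=[
#             MonolithicCheckpointSaver(save_folder='s3://my-bucket/ckpts',
#                                       batch_interval=1000,
#                                       keep_optimizers=True),
#         ],
#     )
#     trainer.fit()
#
# The resulting file can be loaded on a single device with plain
# ``torch.load``; the unsharded weights live under
# ``state_dict['state']['model']``.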