import copy
import io
import os
import pickle
from collections import defaultdict
from typing import Any, Dict, List, Optional

import torch
from fvcore.common.checkpoint import (
    Checkpointer,
    PeriodicCheckpointer,
    _IncompatibleKeys,
    get_missing_parameters_message,
    get_unexpected_parameters_message,
)
from timm.utils import ModelEma

import uniperceiver.utils.comm as comm
from uniperceiver.utils.file_io import PathManager

from .c2_model_loading import align_and_update_state_dicts
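
# NOTE: `PetrelBackend` (used by TorchCheckpointer when `ceph_save=True`) is
# assumed to be provided by the deployment's Ceph/S3 storage SDK; the original
# file references it without an explicit import, so none is invented here.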


class PeriodicEpochCheckpointer(PeriodicCheckpointer):
    def step(self, iteration: int, epoch: int, **kwargs: Any) -> None:
        """
        Perform the appropriate action at the given iteration.

        Args:
            iteration (int): the current iteration, ranged in [0, max_iter-1].
            epoch (int): the current epoch, recorded in the checkpoint name.
            kwargs (Any): extra data to save, same as in
                :meth:`Checkpointer.save`.
        """
        iteration = int(iteration)
        epoch = int(epoch)
        additional_state = {"iteration": iteration}
        additional_state.update(kwargs)

        if (iteration + 1) % self.period == 0:
            self.checkpointer.save(
                "{}_Epoch_{:05d}_Iter_{:07d}".format(self.file_prefix, epoch, iteration),
                **additional_state,
            )

            if self.max_to_keep is not None:
                self.recent_checkpoints.append(self.checkpointer.get_checkpoint_file())
                # Drop the oldest checkpoint once the retention limit is
                # exceeded, but never delete the final checkpoint.
                if len(self.recent_checkpoints) > self.max_to_keep:
                    file_to_delete = self.recent_checkpoints.pop(0)
                    if self.path_manager.exists(file_to_delete) and not file_to_delete.endswith(
                            f"{self.file_prefix}_final.pth"):
                        self.path_manager.rm(file_to_delete)
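
# A minimal usage sketch (hypothetical trainer loop; `checkpointer` is a
# Checkpointer built elsewhere, and `max_iter`/`epoch` come from the caller's
# training schedule):
#
#   periodic = PeriodicEpochCheckpointer(checkpointer, period=5000, max_to_keep=3)
#   for iteration in range(max_iter):
#       ...  # forward / backward / optimizer step
#       periodic.step(iteration, epoch)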


class TorchCheckpointer(Checkpointer):
    """
    Same as :class:`Checkpointer`, but can also handle models from the
    uniperceiver model zoo and apply conversions for legacy models.
    """

    def __init__(
        self,
        model,
        model_ema: ModelEma,
        save_dir="",
        *,
        save_to_disk=None,
        checkpoint_mapping=None,
        mapping=False,
        resume_tau=True,
        ceph_save=False,
        ceph_config=None,
        **checkpointables,
    ):
        is_main_process = comm.is_main_process()
        super().__init__(
            model,
            save_dir,
            save_to_disk=is_main_process if save_to_disk is None else save_to_disk,
            **checkpointables,
        )
        # Keep a reference to the EMA model so callers can access it when
        # checkpointing.
        self.model_ema = model_ema
        self.path_manager = PathManager

        # Invert the ORIGIN -> DEST mapping config into a dict of lists so one
        # pretrained key can be copied to several task-specific keys.
        if checkpoint_mapping is None:
            self.checkpoint_mapping = None
        else:
            self.checkpoint_mapping = defaultdict(list)
            for mapping_pair in checkpoint_mapping:
                self.checkpoint_mapping[mapping_pair['ORIGIN']].append(mapping_pair['DEST'])
        self.mapping = mapping
        self.resume_tau = resume_tau
        self.ceph_save = ceph_save
        if self.ceph_save:
            self.path_prefix = 's3://checkpoints_zjg/'
            self.client = PetrelBackend(path_mapping={}, tcs_conf_path=ceph_config)

    def _load_file(self, filename):
        if filename.endswith(".pkl"):
            with PathManager.open(filename, "rb") as f:
                data = pickle.load(f, encoding="latin1")
            if "model" in data and "__author__" in data:
                # A checkpoint saved in this codebase's own .pkl format.
                self.logger.info("Reading a file from '{}'".format(data["__author__"]))
                return data
            else:
                # Assume a Caffe2 weights file: unwrap the "blobs" dict, drop
                # BN momentum blobs, and mark the result so _load_model() runs
                # heuristic key alignment.
                if "blobs" in data:
                    data = data["blobs"]
                data = {k: v for k, v in data.items() if not k.endswith("_momentum")}
                return {
                    "model": data,
                    "__author__": "Caffe2",
                    "matching_heuristics": True,
                }

        if self.ceph_save:
            # Map the local path to its mirror under the s3 prefix and read
            # the checkpoint through the Ceph client.
            relpath = os.path.relpath(filename, os.getcwd())
            s3url = os.path.join(self.path_prefix, relpath)
            with io.BytesIO(self.client.get(s3url)) as buffer:
                loaded = torch.load(buffer, map_location=torch.device("cpu"))
        else:
            loaded = super()._load_file(filename)
        if "model" not in loaded:
            loaded = {"model": loaded}
        return loaded
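
    # Loader contract: every branch above returns a dict with at least a
    # "model" key; Caffe2-style .pkl files additionally carry "__author__" and
    # "matching_heuristics" so that _load_model() knows to align keys
    # heuristically before the standard strict load.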

    def save(self, name: str, **kwargs: Any) -> None:
        """
        Dump model and checkpointables to a file.

        Args:
            name (str): name of the file.
            kwargs (dict): extra arbitrary data to save.
        """
        if not self.save_dir or not self.save_to_disk:
            return

        data = {}
        data["model"] = self.model.state_dict()
        for key, obj in self.checkpointables.items():
            data[key] = obj.state_dict()
        data.update(kwargs)

        basename = "{}.pth".format(name)
        if self.ceph_save:
            # Serialize into an in-memory buffer and upload it to the s3
            # mirror of the local save path.
            local_save_file = os.path.join(self.save_dir, basename)
            relpath = os.path.relpath(local_save_file, os.getcwd())
            save_file = os.path.join(self.path_prefix, relpath)
            assert os.path.basename(save_file) == basename, basename
            self.logger.info("Saving checkpoint to {}".format(save_file))
            with io.BytesIO() as f:
                torch.save(data, f)
                self.client.put(f.getvalue(), save_file)
        else:
            save_file = os.path.join(self.save_dir, basename)
            assert os.path.basename(save_file) == basename, basename
            self.logger.info("Saving checkpoint to {}".format(save_file))
            with self.path_manager.open(save_file, "wb") as f:
                torch.save(data, f)
        self.tag_last_checkpoint(basename)
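
    # For example (hypothetical name), save("model_Epoch_00001_Iter_0004999",
    # iteration=4999) writes <save_dir>/model_Epoch_00001_Iter_0004999.pth (or
    # its s3 mirror) and records the basename via tag_last_checkpoint() so
    # that resume_or_load() can find it later.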

    def load(self, path: str, checkpointables: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Load from the given checkpoint.

        Args:
            path (str): path or url to the checkpoint. If empty, will not load
                anything.
            checkpointables (list): List of checkpointable names to load. If not
                specified (None), will load all the possible checkpointables.
        Returns:
            dict:
                extra data loaded from the checkpoint that has not been
                processed. For example, those saved with
                :meth:`.save(**extra_data)`.
        """
        if not path:
            self.logger.info("No checkpoint found. Initializing model from scratch")
            return {}
        self.logger.info("[Checkpointer] Loading from {} ...".format(path))
        if not self.ceph_save:
            if not os.path.isfile(path):
                path = self.path_manager.get_local_path(path)
                assert os.path.isfile(path), "Checkpoint {} not found!".format(path)
        # For Ceph storage the local path is mapped to its s3 mirror inside
        # _load_file(), so no local existence check is done here.

        checkpoint = self._load_file(path)
        incompatible = self._load_model(checkpoint)
        if incompatible is not None:
            self._log_incompatible_keys(incompatible)

        for key in self.checkpointables if checkpointables is None else checkpointables:
            if key in checkpoint:
                self.logger.info("Loading {} from {} ...".format(key, path))
                obj = self.checkpointables[key]
                obj.load_state_dict(checkpoint.pop(key))

        # return any further checkpoint data
        return checkpoint

    def _convert_checkpoint(self, checkpoint):
        # Duplicate pretrained weights from each origin task to its mapped
        # destination tasks, renaming the keys accordingly.
        if self.checkpoint_mapping is not None and self.mapping:
            pretrain_checkpoint = checkpoint["model"]
            for origin_task in self.checkpoint_mapping.keys():
                for k in list(pretrain_checkpoint.keys()):
                    if origin_task in k:
                        state_dict_temp = copy.deepcopy(pretrain_checkpoint.pop(k))
                        for subtask in self.checkpoint_mapping[origin_task]:
                            new_key = k.replace(origin_task, subtask)
                            pretrain_checkpoint[new_key] = state_dict_temp
            checkpoint["model"] = pretrain_checkpoint

        # Optionally drop the temperature ("logit_scale") parameters so they
        # are re-initialized rather than resumed.
        if not self.resume_tau:
            pretrain_checkpoint = checkpoint["model"]
            for k in list(pretrain_checkpoint.keys()):
                if "logit_scale" in k:
                    pretrain_checkpoint.pop(k)
            checkpoint["model"] = pretrain_checkpoint
        return checkpoint
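
    # Example mapping config (hypothetical task names): constructing the
    # checkpointer with
    #   checkpoint_mapping=[{'ORIGIN': 'imagenet', 'DEST': 'caption'},
    #                       {'ORIGIN': 'imagenet', 'DEST': 'retrieval'}]
    # and mapping=True copies every checkpoint key containing 'imagenet' to
    # new keys named for 'caption' and 'retrieval'.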

    def _load_model(self, checkpoint):
        if checkpoint.get("matching_heuristics", False):
            self._convert_ndarray_to_tensor(checkpoint["model"])
            # Align checkpoint keys to model keys by name-matching heuristics;
            # the model's state dict is updated in place with matched weights.
            model_state_dict = self.model.state_dict()
            align_and_update_state_dicts(
                model_state_dict,
                checkpoint["model"],
                c2_conversion=checkpoint.get("__author__", None) == "Caffe2",
            )
            checkpoint["model"] = model_state_dict

        checkpoint = self._convert_checkpoint(checkpoint)

        incompatible = super()._load_model(checkpoint)
        if incompatible is None:
            return None

        # The "pixel_mean"/"pixel_std" buffers are typically initialized by
        # the model itself rather than loaded, so suppress them in the
        # missing-key report to avoid log noise.
        model_buffers = dict(self.model.named_buffers(recurse=False))
        for k in ["pixel_mean", "pixel_std"]:
            if k in model_buffers:
                try:
                    incompatible.missing_keys.remove(k)
                except ValueError:
                    pass
        return incompatible

    def _log_incompatible_keys(self, incompatible: _IncompatibleKeys) -> None:
        """
        Log information about the incompatible keys returned by ``_load_model``.
        """
        for k, shape_checkpoint, shape_model in incompatible.incorrect_shapes:
            self.logger.warning(
                "Skip loading parameter '{}' to the model due to incompatible "
                "shapes: {} in the checkpoint but {} in the "
                "model! You might want to double check if this is expected.".format(
                    k, shape_checkpoint, shape_model))
        if incompatible.missing_keys:
            self.logger.info(get_missing_parameters_message(incompatible.missing_keys))
        if incompatible.unexpected_keys:
            self.logger.info(get_unexpected_parameters_message(incompatible.unexpected_keys))

    def resume_or_load(self, path, resume: bool = True, **kwargs):
        # Propagate the extra data (e.g. the saved "iteration") returned by
        # the base implementation.
        return super().resume_or_load(path, resume=resume)
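

# A minimal usage sketch (hypothetical names; `model`, `ema`, and `optimizer`
# come from the caller's training setup):
#
#   checkpointer = TorchCheckpointer(model, ema, save_dir="./output",
#                                    optimizer=optimizer)
#   extra = checkpointer.resume_or_load("./output/model_final.pth", resume=True)
#   start_iter = extra.get("iteration", -1) + 1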