Spaces:
Sleeping
Sleeping
# Copyright (c) Facebook, Inc. and its affiliates. | |
import logging | |
import os | |
import pickle | |
import torch | |
from fvcore.common.checkpoint import Checkpointer | |
from torch.nn.parallel import DistributedDataParallel | |
import detectron2.utils.comm as comm | |
from detectron2.utils.file_io import PathManager | |
from .c2_model_loading import align_and_update_state_dicts | |
class DetectionCheckpointer(Checkpointer):
    """
    Same as :class:`Checkpointer`, but is able to:

    1. handle models in detectron & detectron2 model zoo, and apply conversions for legacy models.
    2. correctly load checkpoints that are only available on the master worker
    """

    def __init__(self, model, save_dir="", *, save_to_disk=None, **checkpointables):
        """
        Args:
            model: the model to checkpoint; forwarded to :class:`Checkpointer`.
            save_dir (str): forwarded to :class:`Checkpointer`.
            save_to_disk (bool or None): forwarded to :class:`Checkpointer`.
                When None, the value of ``comm.is_main_process()`` is used instead.
            checkpointables: extra keyword arguments forwarded to :class:`Checkpointer`.
        """
        is_main_process = comm.is_main_process()
        super().__init__(
            model,
            save_dir,
            save_to_disk=is_main_process if save_to_disk is None else save_to_disk,
            **checkpointables,
        )
        # Override the path manager with detectron2's own PathManager.
        self.path_manager = PathManager

    def load(self, path, *args, **kwargs):
        """
        Load a checkpoint, tolerating workers that cannot read the file.

        When the model is wrapped in ``DistributedDataParallel``, every worker reports
        whether ``path`` exists locally. Workers that cannot read it skip loading
        (``path`` is replaced by None) and model states are synchronized afterwards
        through the DDP wrapper.

        Args:
            path (str): checkpoint path; a falsy value skips the distributed check.
            *args, **kwargs: forwarded unchanged to ``Checkpointer.load``.

        Returns:
            whatever ``Checkpointer.load`` returns.

        Raises:
            OSError: if the first gathered worker (reported as the main worker in
                the error message) does not have the file.
        """
        need_sync = False
        # Bound unconditionally so later use never depends on which branch ran.
        logger = logging.getLogger(__name__)

        if path and isinstance(self.model, DistributedDataParallel):
            path = self.path_manager.get_local_path(path)
            has_file = os.path.isfile(path)
            all_has_file = comm.all_gather(has_file)
            if not all_has_file[0]:
                raise OSError(f"File {path} not found on main worker.")
            if not all(all_has_file):
                logger.warning(
                    f"Not all workers can read checkpoint {path}. "
                    "Training may fail to fully resume."
                )
                # TODO: broadcast the checkpoint file contents from main
                # worker, and load from it instead.
                need_sync = True
            if not has_file:
                path = None  # don't load if not readable

        ret = super().load(path, *args, **kwargs)

        if need_sync:
            logger.info("Broadcasting model states from main worker ...")
            # NOTE: relies on a private DistributedDataParallel method.
            self.model._sync_params_and_buffers()
        return ret

    def _load_file(self, filename):
        """
        Read a checkpoint file and normalize it to a dict with a "model" entry.

        Handles three cases, chosen by file extension:

        * ``.pkl``: a pickled dict, either already containing "model" and
          "__author__" (returned as-is) or assumed to come from Caffe2 / Detectron1.
        * ``.pyth``: assumed to be a pycls checkpoint containing "model_state".
        * anything else: delegated to ``Checkpointer._load_file``.

        Args:
            filename (str): path of the checkpoint file to read.

        Returns:
            dict: contains at least the key "model". Converted Caffe2 and pycls
            checkpoints also carry "__author__" and ``"matching_heuristics": True``.
        """
        if filename.endswith(".pkl"):
            # NOTE(security): pickle can execute arbitrary code while loading;
            # only open checkpoint files that come from a trusted source.
            with PathManager.open(filename, "rb") as f:
                data = pickle.load(f, encoding="latin1")
            if "model" in data and "__author__" in data:
                # file is in Detectron2 model zoo format
                self.logger.info("Reading a file from '{}'".format(data["__author__"]))
                return data
            else:
                # assume file is from Caffe2 / Detectron1 model zoo
                if "blobs" in data:
                    # Detection models have "blobs", but ImageNet models don't
                    data = data["blobs"]
                # Drop entries whose name ends with "_momentum".
                data = {k: v for k, v in data.items() if not k.endswith("_momentum")}
                return {"model": data, "__author__": "Caffe2", "matching_heuristics": True}
        elif filename.endswith(".pyth"):
            # assume file is from pycls; no one else seems to use the ".pyth" extension
            with PathManager.open(filename, "rb") as f:
                # Map tensors to CPU so that a checkpoint saved on GPU can be read on
                # a CPU-only host and is not materialized on the saving device by
                # every worker. Weights are copied to the model's device on load.
                data = torch.load(f, map_location=torch.device("cpu"))
            assert (
                "model_state" in data
            ), f"Cannot load .pyth file {filename}; pycls checkpoints must contain 'model_state'."
            # Drop entries whose name ends with "num_batches_tracked".
            model_state = {
                k: v
                for k, v in data["model_state"].items()
                if not k.endswith("num_batches_tracked")
            }
            return {"model": model_state, "__author__": "pycls", "matching_heuristics": True}

        loaded = super()._load_file(filename)  # load native pth checkpoint
        if "model" not in loaded:
            # The file holds a bare state dict; wrap it in the expected layout.
            loaded = {"model": loaded}
        return loaded

    def _load_model(self, checkpoint):
        """
        Load ``checkpoint["model"]`` into the model, converting names when requested.

        Args:
            checkpoint (dict): a dict as produced by :meth:`_load_file`. If it has a
                truthy "matching_heuristics" entry, its weights are first aligned to
                the model's state dict by ``align_and_update_state_dicts``.

        Returns:
            the incompatible-keys object returned by ``Checkpointer._load_model``,
            with known-benign entries removed from ``missing_keys`` and
            ``unexpected_keys``.
        """
        if checkpoint.get("matching_heuristics", False):
            self._convert_ndarray_to_tensor(checkpoint["model"])
            # convert weights by name-matching heuristics
            checkpoint["model"] = align_and_update_state_dicts(
                self.model.state_dict(),
                checkpoint["model"],
                c2_conversion=checkpoint.get("__author__", None) == "Caffe2",
            )
        # for non-caffe2 models, use standard ways to load it
        incompatible = super()._load_model(checkpoint)

        model_buffers = dict(self.model.named_buffers(recurse=False))
        for k in ["pixel_mean", "pixel_std"]:
            # Ignore missing key message about pixel_mean/std.
            # Though they may be missing in old checkpoints, they will be correctly
            # initialized from config anyway.
            if k in model_buffers:
                try:
                    incompatible.missing_keys.remove(k)
                except ValueError:
                    pass
        # Iterate over a copy because entries are removed from the list in the loop.
        for k in incompatible.unexpected_keys[:]:
            # Ignore unexpected keys about cell anchors. They exist in old checkpoints
            # but now they are non-persistent buffers and will not be in new checkpoints.
            if "anchor_generator.cell_anchors" in k:
                incompatible.unexpected_keys.remove(k)
        return incompatible