Spaces:
Paused
Paused
File size: 1,271 Bytes
938e515 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
# Copyright (c) Facebook, Inc. and its affiliates.
from collections import OrderedDict
from detectron2.checkpoint import DetectionCheckpointer
def _rename_HRNet_weights(weights):
# We detect and rename HRNet weights for DensePose. 1956 and 1716 are values that are
# common to all HRNet pretrained weights, and should be enough to accurately identify them
if (
len(weights["model"].keys()) == 1956
and len([k for k in weights["model"].keys() if k.startswith("stage")]) == 1716
):
hrnet_weights = OrderedDict()
for k in weights["model"].keys():
hrnet_weights["backbone.bottom_up." + str(k)] = weights["model"][k]
return {"model": hrnet_weights}
else:
return weights
class DensePoseCheckpointer(DetectionCheckpointer):
    """
    Checkpointer for DensePose models.

    Behaves exactly like :class:`DetectionCheckpointer`, except that every loaded
    file is passed through ``_rename_HRNet_weights`` so that pretrained HRNet
    weights get their keys re-prefixed before being applied to the model.
    """

    def __init__(self, model, save_dir="", *, save_to_disk=None, **checkpointables):
        # Pure pass-through; keeps the parent's constructor signature explicit.
        super(DensePoseCheckpointer, self).__init__(
            model,
            save_dir,
            save_to_disk=save_to_disk,
            **checkpointables,
        )

    def _load_file(self, filename: str) -> object:
        """
        Load ``filename`` with the parent loader, then apply HRNet key renaming.
        """
        return _rename_HRNet_weights(
            super(DensePoseCheckpointer, self)._load_file(filename)
        )
|