File size: 9,710 Bytes
f57bf3e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 |
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Facilities for pickling Python code alongside other data.
The pickled code is automatically imported into a separate Python module
during unpickling. This way, any previously exported pickles will remain
usable even if the original code is no longer available, or if the current
version of the code is not consistent with what was originally pickled."""
import sys
import pickle
import io
import inspect
import copy
import uuid
import types
import dnnlib
#----------------------------------------------------------------------------
_version = 6 # internal version number
_decorators = set() # {decorator_class, ...}
_import_hooks = [] # [hook_function, ...]
_module_to_src_dict = dict() # {module: src, ...}
_src_to_module_dict = dict() # {src: module, ...}
#----------------------------------------------------------------------------
def persistent_class(orig_class):
    r"""Class decorator that extends a given class to save its source code
    when pickled.

    Example:

        from torch_utils import persistence

        @persistence.persistent_class
        class MyNetwork(torch.nn.Module):
            def __init__(self, num_inputs, num_outputs):
                super().__init__()
                self.fc = MyLayer(num_inputs, num_outputs)
                ...

        @persistence.persistent_class
        class MyLayer(torch.nn.Module):
            ...

    When pickled, any instance of `MyNetwork` and `MyLayer` will save its
    source code alongside other internal state (e.g., parameters, buffers,
    and submodules). This way, any previously exported pickle will remain
    usable even if the class definitions have been modified or are no
    longer available.

    The decorator saves the source code of the entire Python module
    containing the decorated class. It does *not* save the source code of
    any imported modules. Thus, the imported modules must be available
    during unpickling, including `torch_utils.persistence` itself.

    It is ok to call functions defined in the same module from the
    decorated class. However, if the decorated class depends on other
    classes defined in the same module, they must be decorated as well.
    This is illustrated in the above example in the case of `MyLayer`.

    It is also possible to employ the decorator just-in-time before
    calling the constructor. For example:

        cls = MyLayer
        if want_to_make_it_persistent:
            cls = persistence.persistent_class(cls)
        layer = cls(num_inputs, num_outputs)

    As an additional feature, the decorator also keeps track of the
    arguments that were used to construct each instance of the decorated
    class. The arguments can be queried via `obj.init_args` and
    `obj.init_kwargs`, and they are automatically pickled alongside other
    object state. A typical use case is to first unpickle a previous
    instance of a persistent class, and then upgrade it to use the latest
    version of the source code:

        with open('old_pickle.pkl', 'rb') as f:
            old_net = pickle.load(f)
        new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
        misc.copy_params_and_buffers(old_net, new_net, require_all=True)

    Args:
        orig_class: The class to make persistent. Its defining module must be
            present in `sys.modules`.

    Returns:
        A subclass of `orig_class` that carries the same `__name__`, or
        `orig_class` itself if it is already persistent.
    """
    assert isinstance(orig_class, type)
    if is_persistent(orig_class):
        return orig_class

    assert orig_class.__module__ in sys.modules
    orig_module = sys.modules[orig_class.__module__]
    orig_module_src = _module_to_src(orig_module)

    class Decorator(orig_class):
        _orig_module_src = orig_module_src
        _orig_class_name = orig_class.__name__

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Deep copies so later mutation by the caller cannot alter the recorded arguments.
            self._init_args = copy.deepcopy(args)
            self._init_kwargs = copy.deepcopy(kwargs)
            # Unpickling looks the class up by name in the re-executed module source,
            # so it must be reachable as a module-level attribute.
            assert orig_class.__name__ in orig_module.__dict__
            # Fail at construction time rather than much later when the object is pickled.
            _check_pickleable(self.__reduce__())

        @property
        def init_args(self):
            return copy.deepcopy(self._init_args)

        @property
        def init_kwargs(self):
            return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))

        def __reduce__(self):
            fields = list(super().__reduce__())
            fields += [None] * max(3 - len(fields), 0) # pad to (func, args, state)
            # If a persistent base class already produced the persistent form, keep it as-is.
            if fields[0] is not _reconstruct_persistent_obj:
                meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
                fields[0] = _reconstruct_persistent_obj # reconstruct func
                fields[1] = (meta,) # reconstruct args
                fields[2] = None # state dict
            return tuple(fields)

    # Keep the user-visible class name identical to the decorated class.
    Decorator.__name__ = orig_class.__name__
    _decorators.add(Decorator)
    return Decorator
#----------------------------------------------------------------------------
def is_persistent(obj):
    r"""Return True when `obj` is a persistent class or an instance of one,
    i.e., when pickling it also stores the source code of its module.
    """
    try:
        is_decorator_class = obj in _decorators
    except TypeError:
        # Unhashable objects cannot be set members; decide by their type alone.
        is_decorator_class = False
    return is_decorator_class or type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
#----------------------------------------------------------------------------
def import_hook(hook):
    r"""Register an import hook that is called whenever a persistent object
    is being unpickled. A typical use case is to patch the pickled source
    code to avoid errors and inconsistencies when the API of some imported
    module has changed.

    The hook should have the following signature:

        hook(meta) -> modified meta

    `meta` is an instance of `dnnlib.EasyDict` with the following fields:

        type:       Type of the persistent object, e.g. `'class'`.
        version:    Internal version number of `torch_utils.persistence`.
        module_src: Original source code of the Python module.
        class_name: Class name in the original Python module.
        state:      Internal state of the object.

    Hooks are invoked in registration order. The hook is returned unchanged,
    so this function can be used as a decorator without rebinding the
    decorated name to None.

    Example:

        @persistence.import_hook
        def wreck_my_network(meta):
            if meta.class_name == 'MyNetwork':
                print('MyNetwork is being imported. I will wreck it!')
                meta.module_src = meta.module_src.replace("True", "False")
            return meta
    """
    assert callable(hook)
    _import_hooks.append(hook)
    return hook
#----------------------------------------------------------------------------
def _reconstruct_persistent_obj(meta):
    r"""Hook that is called internally by the `pickle` module to unpickle
    a persistent object.

    Args:
        meta: The dict built by `Decorator.__reduce__`, with keys `type`,
            `version`, `module_src`, `class_name` and `state`.

    Returns:
        A new instance of the persistent class with its pickled state applied.
    """
    meta = dnnlib.EasyDict(meta)
    # `__reduce__` pads an absent state with None; treat that as an empty state
    # instead of crashing inside EasyDict(None).
    meta.state = dnnlib.EasyDict(meta.state if meta.state is not None else {})
    for hook in _import_hooks:
        meta = hook(meta)
        assert meta is not None # hooks must return the (possibly modified) meta

    assert meta.version == _version
    module = _src_to_module(meta.module_src)

    assert meta.type == 'class'
    orig_class = module.__dict__[meta.class_name]
    # Usually already persistent (decorated inside the source); otherwise decorate now.
    decorator_class = persistent_class(orig_class)
    # Bypass __init__: the state is restored directly below.
    obj = decorator_class.__new__(decorator_class)

    setstate = getattr(obj, '__setstate__', None)
    if callable(setstate):
        setstate(meta.state) # pylint: disable=not-callable
    else:
        obj.__dict__.update(meta.state)
    return obj
#----------------------------------------------------------------------------
def _module_to_src(module):
    r"""Return the source text of `module`, filling both module<->source
    caches the first time a module is seen.
    """
    cached_src = _module_to_src_dict.get(module)
    if cached_src is not None:
        return cached_src

    src = inspect.getsource(module)
    _module_to_src_dict[module] = src
    _src_to_module_dict[src] = module
    return src
def _src_to_module(src):
    r"""Get or create a Python module for the given source code.

    Modules are cached by their source text, so the source is executed at most
    once per process. If executing it fails, the registration is rolled back and
    the exception propagates. A later call then retries instead of reusing a
    half-initialized module.

    NOTE: this executes source code stored inside a pickle. Only unpickle files
    from trusted origins.
    """
    module = _src_to_module_dict.get(src, None)
    if module is None:
        module_name = "_imported_module_" + uuid.uuid4().hex
        module = types.ModuleType(module_name)
        # Register before exec: decorators inside `src` call persistent_class() during exec.
        # That looks the module up in sys.modules and fetches its source via
        # _module_to_src, which must hit the cache rather than fall back to inspect.getsource.
        sys.modules[module_name] = module
        _module_to_src_dict[module] = src
        _src_to_module_dict[src] = module
        try:
            exec(src, module.__dict__) # pylint: disable=exec-used
        except BaseException:
            sys.modules.pop(module_name, None)
            _module_to_src_dict.pop(module, None)
            _src_to_module_dict.pop(src, None)
            raise
    return module
#----------------------------------------------------------------------------
def _check_pickleable(obj):
r"""Check that the given object is pickleable, raising an exception if
it is not. This function is expected to be considerably more efficient
than actually pickling the object.
"""
def recurse(obj):
if isinstance(obj, (list, tuple, set)):
return [recurse(x) for x in obj]
if isinstance(obj, dict):
return [[recurse(x), recurse(y)] for x, y in obj.items()]
if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
return None # Python primitive types are pickleable.
if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']:
return None # NumPy arrays and PyTorch tensors are pickleable.
if is_persistent(obj):
return None # Persistent objects are pickleable, by virtue of the constructor check.
return obj
with io.BytesIO() as f:
pickle.dump(recurse(obj), f)
#----------------------------------------------------------------------------
|