File size: 5,868 Bytes
b84549f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import abc
import functools
import inspect
import types
from typing import Any
import json_tricks
from .utils import get_importable_name, get_module_name, import_, reset_uid
def get_init_parameters_or_fail(obj, silently=False):
    """
    Fetch the constructor arguments recorded on ``obj``.

    Parameters
    ----------
    obj : Any
        Object expected to carry an ``_init_parameters`` attribute
        (the wrappers created in this module set it to a dict).
    silently : bool
        When True, return ``None`` instead of raising if the attribute is missing.

    Returns
    -------
    Any
        ``obj._init_parameters`` when present; ``None`` when it is missing and ``silently`` is set.

    Raises
    ------
    ValueError
        If ``_init_parameters`` is missing and ``silently`` is False.
    """
    if not hasattr(obj, '_init_parameters'):
        if silently:
            return None
        raise ValueError(f'Object {obj} needs to be serializable but `_init_parameters` is not available. '
                         'If it is a built-in module (like Conv2d), please import it from retiarii.nn. '
                         'If it is a customized module, please to decorate it with @basic_unit. '
                         'For other complex objects (e.g., trainer, optimizer, dataset, dataloader), '
                         'try to use serialize or @serialize_cls.')
    return obj._init_parameters
### This is a patch of json-tricks to make it more useful to us ###
def _serialize_class_instance_encode(obj, primitives=False):
    """
    json-tricks encoder hook for instances that carry recorded init parameters.

    Such objects are encoded as
    ``{'__type__': <importable class name>, 'arguments': <init parameters>}``;
    any other object is returned unchanged so the remaining encoders can handle it.

    Parameters
    ----------
    obj : Any
        The object being encoded.
    primitives : bool
        json-tricks "primitives" mode; not supported (must be False).
    """
    assert not primitives, 'Encoding with primitives is not supported.'
    # Look the parameters up without raising, so that only the "not serializable" case
    # falls through. Previously a broad ``except ValueError`` also hid errors raised while
    # resolving the class name, and the name was computed even for unserializable objects.
    init_parameters = get_init_parameters_or_fail(obj, silently=True)
    if init_parameters is None:
        return obj
    return {
        '__type__': get_importable_name(obj.__class__),
        'arguments': init_parameters,
    }
def _serialize_class_instance_decode(obj):
if isinstance(obj, dict) and '__type__' in obj and 'arguments' in obj:
return import_(obj['__type__'])(**obj['arguments'])
return obj
def _type_encode(obj, primitives=False):
assert not primitives, 'Encoding with primitives is not supported.'
if isinstance(obj, type):
return {'__typename__': get_importable_name(obj, relocate_module=True)}
if isinstance(obj, (types.FunctionType, types.BuiltinFunctionType)):
# This is not reliable for cases like closure, `open`, or objects that is callable but not intended to be serialized.
# https://stackoverflow.com/questions/624926/how-do-i-detect-whether-a-python-variable-is-a-function
return {'__typename__': get_importable_name(obj, relocate_module=True)}
return obj
def _type_decode(obj):
if isinstance(obj, dict) and '__typename__' in obj:
return import_(obj['__typename__'])
return obj
# Drop-in replacements for the json_tricks (de)serializers that additionally understand the
# class-instance and type/function encodings implemented by the hooks in this module.
json_loads = functools.partial(
    json_tricks.loads,
    extra_obj_pairs_hooks=[_serialize_class_instance_decode, _type_decode],
)
json_dumps = functools.partial(
    json_tricks.dumps,
    extra_obj_encoders=[_serialize_class_instance_encode, _type_encode],
)
json_load = functools.partial(
    json_tricks.load,
    extra_obj_pairs_hooks=[_serialize_class_instance_decode, _type_decode],
)
json_dump = functools.partial(
    json_tricks.dump,
    extra_obj_encoders=[_serialize_class_instance_encode, _type_encode],
)
### End of json-tricks patch ###
class Translatable(abc.ABC):
    """
    Base class for arguments that must be converted before reaching a wrapped class.

    Inherit this class and implement ``_translate`` when the inner class needs a different
    parameter from the wrapper class in its init function. A wrapper built by
    ``_create_wrapper_cls`` that stores init parameters records this object itself, but
    passes the result of ``_translate()`` to the wrapped class's ``__init__``.
    """

    @abc.abstractmethod
    def _translate(self) -> Any:
        """Return the value to hand to the wrapped class in place of this object."""
def _create_wrapper_cls(cls, store_init_parameters=True, reset_mutation_uid=False):
    """
    Build a subclass of ``cls`` whose constructor can record its arguments.

    Parameters
    ----------
    cls : type
        The class to wrap.
    store_init_parameters : bool
        If True, every instance gets ``_init_parameters``: a dict of the arguments it was
        constructed with (keyword arguments as given, positional arguments matched to the
        parameter names of ``cls.__init__``). ``Translatable`` arguments are stored as-is but
        replaced by their ``_translate()`` result before being forwarded to ``cls.__init__``.
        If False, ``_init_parameters`` is an empty dict and arguments are forwarded untouched.
    reset_mutation_uid : bool
        If True, ``reset_uid('mutation')`` is called at the start of every construction.

    Returns
    -------
    type
        The wrapper class, carrying the module, name, qualified name and docstrings of ``cls``.
    """
    class wrapper(cls):
        def __init__(self, *args, **kwargs):
            if reset_mutation_uid:
                reset_uid('mutation')
            if store_init_parameters:
                # Drop the first parameter (``self``).
                # NOTE(review): ``*args`` / ``**kwargs`` parameters of ``cls.__init__`` are listed
                # here by name as well, so variadic constructors are not recorded faithfully.
                argname_list = list(inspect.signature(cls.__init__).parameters.keys())[1:]
                full_args = dict(kwargs)
                assert len(args) <= len(argname_list), f'Length of {args} is greater than length of {argname_list}.'
                # Positional values override same-named keyword values, as before.
                full_args.update(zip(argname_list, args))

                # Record the original arguments, but forward translated ones to ``cls``.
                args = [value._translate() if isinstance(value, Translatable) else value
                        for value in args]
                kwargs = {key: value._translate() if isinstance(value, Translatable) else value
                          for key, value in kwargs.items()}

                self._init_parameters = full_args
            else:
                self._init_parameters = {}
            super().__init__(*args, **kwargs)

    wrapper.__module__ = get_module_name(cls)
    wrapper.__name__ = cls.__name__
    wrapper.__qualname__ = cls.__qualname__
    # Class docstrings are not inherited by subclasses; copy it so help() and doc tools
    # show the wrapped class's documentation instead of None.
    wrapper.__doc__ = cls.__doc__
    wrapper.__init__.__doc__ = cls.__init__.__doc__
    return wrapper
def serialize_cls(cls):
    """
    Create a serializable version of ``cls``.

    Instances of the returned class record their constructor arguments in
    ``_init_parameters``, which the json-tricks encoder hook of this module reads.
    """
    wrapped_cls = _create_wrapper_cls(cls, store_init_parameters=True, reset_mutation_uid=False)
    return wrapped_cls
def transparent_serialize(cls):
    """
    Wrap a module without recording its parameters. For internal use only.

    Instances get an empty ``_init_parameters`` dict and receive their arguments unchanged.
    """
    wrapped_cls = _create_wrapper_cls(cls, reset_mutation_uid=False, store_init_parameters=False)
    return wrapped_cls
def serialize(cls, *args, **kwargs):
    """
    Create a serializable instance inline, without decorating the class. For example,

    .. code-block:: python

        self.op = serialize(MyCustomOp, hidden_units=128)
    """
    serializable_cls = serialize_cls(cls)
    instance = serializable_cls(*args, **kwargs)
    return instance
def basic_unit(cls):
    """
    Wrap a module as a basic unit, to stop it from parsing and make it mutate-able.

    Parameters
    ----------
    cls : type
        A subclass of ``torch.nn.Module``.

    Returns
    -------
    type
        The serializable wrapper produced by ``serialize_cls``.

    Raises
    ------
    TypeError
        If ``cls`` is not a subclass of ``torch.nn.Module``.
    """
    # Imported locally; presumably so this module can be imported without torch.
    import torch.nn as nn
    # An explicit check instead of ``assert`` keeps the validation active under ``python -O``.
    # TypeError matches what ``issubclass`` already raised for non-class arguments.
    if not (isinstance(cls, type) and issubclass(cls, nn.Module)):
        raise TypeError('When using @basic_unit, the class must be a subclass of nn.Module.')
    return serialize_cls(cls)
def model_wrapper(cls):
    """
    Wrap the model if you are using the pure-python execution engine.

    The wrapper serves two purposes:

    1. Capture the init parameters of the python class so that it can be re-instantiated
       in another process.
    2. Reset the uid in the ``mutation`` namespace on every instantiation, so that each model
       counts from zero. Can be useful in unittest and other multi-model scenarios.
    """
    wrapper_options = dict(store_init_parameters=True, reset_mutation_uid=True)
    return _create_wrapper_cls(cls, **wrapper_options)
|