File size: 5,868 Bytes
b84549f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import abc
import functools
import inspect
import types
from typing import Any

import json_tricks

from .utils import get_importable_name, get_module_name, import_, reset_uid


def get_init_parameters_or_fail(obj, silently=False):
    """
    Fetch the ``_init_parameters`` attribute stored on ``obj``.

    Parameters
    ----------
    obj
        Any object; it is considered serializable here when it has ``_init_parameters``.
    silently
        If truthy, a missing attribute yields ``None`` instead of an exception.

    Returns
    -------
    The value of ``obj._init_parameters``, or ``None`` when it is absent and ``silently`` is truthy.

    Raises
    ------
    ValueError
        When ``obj`` has no ``_init_parameters`` and ``silently`` is falsy.
    """
    if not hasattr(obj, '_init_parameters'):
        if silently:
            return None
        raise ValueError(f'Object {obj} needs to be serializable but `_init_parameters` is not available. '
                         'If it is a built-in module (like Conv2d), please import it from retiarii.nn. '
                         'If it is a customized module, please to decorate it with @basic_unit. '
                         'For other complex objects (e.g., trainer, optimizer, dataset, dataloader), '
                         'try to use serialize or @serialize_cls.')
    return obj._init_parameters


### This is a patch of json-tricks to make it more useful to us ###


def _serialize_class_instance_encode(obj, primitives=False):
    assert not primitives, 'Encoding with primitives is not supported.'
    try:  # FIXME: raise error
        if hasattr(obj, '__class__'):
            return {
                '__type__': get_importable_name(obj.__class__),
                'arguments': get_init_parameters_or_fail(obj)
            }
    except ValueError:
        pass
    return obj


def _serialize_class_instance_decode(obj):
    if isinstance(obj, dict) and '__type__' in obj and 'arguments' in obj:
        return import_(obj['__type__'])(**obj['arguments'])
    return obj


def _type_encode(obj, primitives=False):
    assert not primitives, 'Encoding with primitives is not supported.'
    if isinstance(obj, type):
        return {'__typename__': get_importable_name(obj, relocate_module=True)}
    if isinstance(obj, (types.FunctionType, types.BuiltinFunctionType)):
        # This is not reliable for cases like closure, `open`, or objects that is callable but not intended to be serialized.
        # https://stackoverflow.com/questions/624926/how-do-i-detect-whether-a-python-variable-is-a-function
        return {'__typename__': get_importable_name(obj, relocate_module=True)}
    return obj


def _type_decode(obj):
    if isinstance(obj, dict) and '__typename__' in obj:
        return import_(obj['__typename__'])
    return obj


# json_tricks entry points with this module's hooks pre-bound via ``functools.partial``.
# Decoding (`loads` / `load`) passes `_serialize_class_instance_decode` and `_type_decode`
# as `extra_obj_pairs_hooks`; encoding (`dumps` / `dump`) passes
# `_serialize_class_instance_encode` and `_type_encode` as `extra_obj_encoders`.
json_loads = functools.partial(json_tricks.loads, extra_obj_pairs_hooks=[_serialize_class_instance_decode, _type_decode])
json_dumps = functools.partial(json_tricks.dumps, extra_obj_encoders=[_serialize_class_instance_encode, _type_encode])
json_load = functools.partial(json_tricks.load, extra_obj_pairs_hooks=[_serialize_class_instance_decode, _type_decode])
json_dump = functools.partial(json_tricks.dump, extra_obj_encoders=[_serialize_class_instance_encode, _type_encode])

### End of json-tricks patch ###


class Translatable(abc.ABC):
    """
    Marker base class for init arguments that must be converted before use.

    Inherit from it and implement ``_translate`` when the inner class needs a different
    parameter from the wrapper class in its init function.
    """

    @abc.abstractmethod
    def _translate(self) -> Any:
        """Return the value that should be passed on in place of this object."""


def _create_wrapper_cls(cls, store_init_parameters=True, reset_mutation_uid=False):
    """
    Build a subclass of ``cls`` whose ``__init__`` records its arguments in ``_init_parameters``.

    Parameters
    ----------
    cls
        The class to wrap.
    store_init_parameters
        If truthy, the arguments passed to ``__init__`` are stored by name in
        ``self._init_parameters`` and any ``Translatable`` argument is replaced by the result of
        its ``_translate()`` before being forwarded to ``cls.__init__``.
        Otherwise ``_init_parameters`` is an empty dict and arguments are forwarded untouched.
    reset_mutation_uid
        If truthy, ``reset_uid('mutation')`` is called at the start of every ``__init__``.

    Returns
    -------
    The wrapper class, carrying the ``__name__`` and ``__qualname__`` of ``cls``, the module name
    given by ``get_module_name(cls)``, and the ``__init__`` docstring of ``cls``.
    """

    def _translated(value):
        # Swap a Translatable for what it translates to; leave everything else alone.
        return value._translate() if isinstance(value, Translatable) else value

    class wrapper(cls):
        def __init__(self, *args, **kwargs):
            if reset_mutation_uid:
                reset_uid('mutation')

            if not store_init_parameters:
                self._init_parameters = {}
            else:
                # Parameter names of the wrapped __init__, without the leading ``self``.
                argname_list = list(inspect.signature(cls.__init__).parameters)[1:]
                assert len(args) <= len(argname_list), f'Length of {args} is greater than length of {argname_list}.'

                # Recorded values are the ones the caller passed (before translation);
                # positional values take precedence over a keyword of the same name.
                recorded = {**kwargs, **dict(zip(argname_list, args))}

                # translate parameters: positionals first, then keywords
                args = [_translated(positional) for positional in args]
                kwargs = {key: _translated(keyword) for key, keyword in kwargs.items()}

                self._init_parameters = recorded

            super().__init__(*args, **kwargs)

    # Make the wrapper present itself as the wrapped class.
    wrapper.__module__ = get_module_name(cls)
    wrapper.__name__ = cls.__name__
    wrapper.__qualname__ = cls.__qualname__
    wrapper.__init__.__doc__ = cls.__init__.__doc__

    return wrapper


def serialize_cls(cls):
    """
    Create a serializable class: a wrapper of ``cls`` that records its init parameters.
    """
    serializable_cls = _create_wrapper_cls(cls)
    return serializable_cls


def transparent_serialize(cls):
    """
    Wrap a module without recording its init parameters. For internal use only.
    """
    record_parameters = False
    return _create_wrapper_cls(cls, store_init_parameters=record_parameters)


def serialize(cls, *args, **kwargs):
    """
    Create a serializable instance inline, without using a decorator. For example,

    .. code-block:: python

        self.op = serialize(MyCustomOp, hidden_units=128)

    ``cls`` is wrapped with ``serialize_cls`` and the wrapper is instantiated with
    the remaining positional and keyword arguments.
    """
    serializable_cls = serialize_cls(cls)
    return serializable_cls(*args, **kwargs)


def basic_unit(cls):
    """
    Wrap a module as a basic unit, to stop it from being parsed and to make it mutate-able.

    ``cls`` must be a subclass of ``torch.nn.Module``; the result is ``serialize_cls(cls)``.
    """
    # Imported here rather than at module level, as in the original code.
    from torch.nn import Module
    assert issubclass(cls, Module), 'When using @basic_unit, the class must be a subclass of nn.Module.'
    wrapped_unit = serialize_cls(cls)
    return wrapped_unit


def model_wrapper(cls):
    """
    Wrap the model if you are using pure-python execution engine.

    The wrapper serves two purposes:

        1. Capture the init parameters of the python class so that it can be re-instantiated in another process.
        2. Reset uid in the `mutation` namespace so that each model counts from zero.
           Can be useful in unittest and other multi-model scenarios.
    """
    wrapped_model_cls = _create_wrapper_cls(cls, reset_mutation_uid=True)
    return wrapped_model_cls