Spaces:
Running
Running
# Copyright (c) Microsoft Corporation. | |
# Licensed under the MIT license. | |
import abc | |
import functools | |
import inspect | |
import types | |
from typing import Any | |
import json_tricks | |
from .utils import get_importable_name, get_module_name, import_, reset_uid | |
def get_init_parameters_or_fail(obj, silently=False): | |
if hasattr(obj, '_init_parameters'): | |
return obj._init_parameters | |
elif silently: | |
return None | |
else: | |
raise ValueError(f'Object {obj} needs to be serializable but `_init_parameters` is not available. ' | |
'If it is a built-in module (like Conv2d), please import it from retiarii.nn. ' | |
'If it is a customized module, please to decorate it with @basic_unit. ' | |
'For other complex objects (e.g., trainer, optimizer, dataset, dataloader), ' | |
'try to use serialize or @serialize_cls.') | |
### This is a patch of json-tricks to make it more useful to us ### | |
def _serialize_class_instance_encode(obj, primitives=False): | |
assert not primitives, 'Encoding with primitives is not supported.' | |
try: # FIXME: raise error | |
if hasattr(obj, '__class__'): | |
return { | |
'__type__': get_importable_name(obj.__class__), | |
'arguments': get_init_parameters_or_fail(obj) | |
} | |
except ValueError: | |
pass | |
return obj | |
def _serialize_class_instance_decode(obj): | |
if isinstance(obj, dict) and '__type__' in obj and 'arguments' in obj: | |
return import_(obj['__type__'])(**obj['arguments']) | |
return obj | |
def _type_encode(obj, primitives=False): | |
assert not primitives, 'Encoding with primitives is not supported.' | |
if isinstance(obj, type): | |
return {'__typename__': get_importable_name(obj, relocate_module=True)} | |
if isinstance(obj, (types.FunctionType, types.BuiltinFunctionType)): | |
# This is not reliable for cases like closure, `open`, or objects that is callable but not intended to be serialized. | |
# https://stackoverflow.com/questions/624926/how-do-i-detect-whether-a-python-variable-is-a-function | |
return {'__typename__': get_importable_name(obj, relocate_module=True)} | |
return obj | |
def _type_decode(obj): | |
if isinstance(obj, dict) and '__typename__' in obj: | |
return import_(obj['__typename__']) | |
return obj | |
json_loads = functools.partial(json_tricks.loads, extra_obj_pairs_hooks=[_serialize_class_instance_decode, _type_decode]) | |
json_dumps = functools.partial(json_tricks.dumps, extra_obj_encoders=[_serialize_class_instance_encode, _type_encode]) | |
json_load = functools.partial(json_tricks.load, extra_obj_pairs_hooks=[_serialize_class_instance_decode, _type_decode]) | |
json_dump = functools.partial(json_tricks.dump, extra_obj_encoders=[_serialize_class_instance_encode, _type_encode]) | |
### End of json-tricks patch ### | |
class Translatable(abc.ABC): | |
""" | |
Inherit this class and implement ``translate`` when the inner class needs a different | |
parameter from the wrapper class in its init function. | |
""" | |
def _translate(self) -> Any: | |
pass | |
def _create_wrapper_cls(cls, store_init_parameters=True, reset_mutation_uid=False): | |
class wrapper(cls): | |
def __init__(self, *args, **kwargs): | |
if reset_mutation_uid: | |
reset_uid('mutation') | |
if store_init_parameters: | |
argname_list = list(inspect.signature(cls.__init__).parameters.keys())[1:] | |
full_args = {} | |
full_args.update(kwargs) | |
assert len(args) <= len(argname_list), f'Length of {args} is greater than length of {argname_list}.' | |
for argname, value in zip(argname_list, args): | |
full_args[argname] = value | |
# translate parameters | |
args = list(args) | |
for i, value in enumerate(args): | |
if isinstance(value, Translatable): | |
args[i] = value._translate() | |
for i, value in kwargs.items(): | |
if isinstance(value, Translatable): | |
kwargs[i] = value._translate() | |
self._init_parameters = full_args | |
else: | |
self._init_parameters = {} | |
super().__init__(*args, **kwargs) | |
wrapper.__module__ = get_module_name(cls) | |
wrapper.__name__ = cls.__name__ | |
wrapper.__qualname__ = cls.__qualname__ | |
wrapper.__init__.__doc__ = cls.__init__.__doc__ | |
return wrapper | |
def serialize_cls(cls): | |
""" | |
To create an serializable class. | |
""" | |
return _create_wrapper_cls(cls) | |
def transparent_serialize(cls): | |
""" | |
Wrap a module but does not record parameters. For internal use only. | |
""" | |
return _create_wrapper_cls(cls, store_init_parameters=False) | |
def serialize(cls, *args, **kwargs): | |
""" | |
To create an serializable instance inline without decorator. For example, | |
.. code-block:: python | |
self.op = serialize(MyCustomOp, hidden_units=128) | |
""" | |
return serialize_cls(cls)(*args, **kwargs) | |
def basic_unit(cls): | |
""" | |
To wrap a module as a basic unit, to stop it from parsing and make it mutate-able. | |
""" | |
import torch.nn as nn | |
assert issubclass(cls, nn.Module), 'When using @basic_unit, the class must be a subclass of nn.Module.' | |
return serialize_cls(cls) | |
def model_wrapper(cls): | |
""" | |
Wrap the model if you are using pure-python execution engine. | |
The wrapper serves two purposes: | |
1. Capture the init parameters of python class so that it can be re-instantiated in another process. | |
2. Reset uid in `mutation` namespace so that each model counts from zero. Can be useful in unittest and other multi-model scenarios. | |
""" | |
return _create_wrapper_cls(cls, reset_mutation_uid=True) | |