|
"""Pickle related utilities. Perhaps this should be called 'can'.""" |
|
|
|
|
|
|
|
import copy |
|
import pickle |
|
import sys |
|
import typing |
|
import warnings |
|
from types import FunctionType |
|
|
|
|
|
try: |
|
from ipyparallel.serialize import codeutil |
|
except ImportError: |
|
pass |
|
from traitlets.log import get_logger |
|
from traitlets.utils.importstring import import_item |
|
|
|
# Emitted once, when this module is first imported. ``stacklevel=2`` attributes
# the warning to the frame above this module-level call instead of this line.
warnings.warn(
    "ipykernel.pickleutil is deprecated. It has moved to ipyparallel.",
    category=DeprecationWarning,
    stacklevel=2,
)
|
|
|
# ``buffer`` is an alias of the builtin memoryview; CannedArray and CannedBuffer
# wrap raw data through this name.
buffer = memoryview
# Alias of ``type`` used for the "is this a class?" checks in the class canners.
class_type = type

# Protocol handed to ``pickle.dumps`` when an array is pickled rather than buffered.
PICKLE_PROTOCOL = pickle.DEFAULT_PROTOCOL
|
|
|
|
|
def _get_cell_type(a=None): |
|
"""the type of a closure cell doesn't seem to be importable, |
|
so just create one |
|
""" |
|
|
|
def inner(): |
|
return a |
|
|
|
return type(inner.__closure__[0]) |
|
|
|
|
|
cell_type = _get_cell_type() |
|
|
|
|
|
|
|
|
|
|
|
|
|
def interactive(f):
    """Decorator that makes ``f`` look as if it was defined interactively.

    A plain function is rebuilt from its code object, name and positional
    defaults with the ``__main__`` module namespace as its globals, so it
    resolves global names there instead of in its defining module. The
    closure is not carried over to the rebuilt function.

    Whatever ``f`` is, its ``__module__`` is set to ``"__main__"`` before it
    is returned.
    """
    if isinstance(f, FunctionType):
        main_namespace = __import__("__main__").__dict__
        # Rebuild the function so that it gets its own globals mapping.
        f = FunctionType(f.__code__, main_namespace, f.__name__, f.__defaults__)

    # CannedFunction.get_object only imports modules whose name does not start
    # with "__", so this marks the function as not belonging to a real module.
    f.__module__ = "__main__"
    return f
|
|
|
|
|
def use_dill():
    """Switch serialization over to dill.

    adds support for object methods and closures to serialization.

    The module-level ``pickle`` name here is rebound to dill, as is
    ``ipykernel.serialize.pickle`` when that module can be imported. The
    FunctionType entry is removed from ``can_map`` so functions are no longer
    wrapped by the canning layer.
    """
    global pickle

    import dill

    pickle = dill

    try:
        from ipykernel import serialize
    except ImportError:
        serialize = None
    if serialize is not None:
        serialize.pickle = dill

    # Stop special-casing functions in can(); they are left for dill to pickle.
    can_map.pop(FunctionType, None)
|
|
|
|
|
def use_cloudpickle():
    """Switch serialization over to cloudpickle.

    adds support for object methods and closures to serialization.

    The module-level ``pickle`` name here is rebound to cloudpickle, as is
    ``ipykernel.serialize.pickle`` when that module can be imported. The
    FunctionType entry is removed from ``can_map`` so functions are no longer
    wrapped by the canning layer.
    """
    global pickle

    import cloudpickle

    pickle = cloudpickle

    try:
        from ipykernel import serialize
    except ImportError:
        serialize = None
    if serialize is not None:
        serialize.pickle = cloudpickle

    # Stop special-casing functions in can(); they are left for cloudpickle to pickle.
    can_map.pop(FunctionType, None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CannedObject:
    """A canned object.

    Holds a shallow copy of the wrapped object in which the attributes named
    by ``keys`` have been replaced by their canned form, plus an optional
    canned hook that is run when the object is restored.
    """

    def __init__(self, obj, keys=None, hook=None):
        """can an object for safe pickling

        Parameters
        ----------
        obj
            The object to be canned; a shallow copy (``copy.copy``) is stored.
        keys : list (optional)
            list of attribute names that will be explicitly canned / uncanned
        hook : callable (optional)
            An optional extra callable,
            which can do additional processing of the uncanned object.
            It is invoked as ``hook(obj, g)`` by :meth:`get_object`.

        Notes
        -----
        large data may be offloaded into the buffers list,
        used for zero-copy transfers.
        """
        self.keys = keys or []
        self.obj = copy.copy(obj)
        self.hook = can(hook)
        # Iterate the normalised list, not the raw argument: ``keys`` defaults
        # to None, which is not iterable.
        for key in self.keys:
            setattr(self.obj, key, can(getattr(obj, key)))

        self.buffers = []

    def get_object(self, g=None):
        """Restore and return the stored object.

        Parameters
        ----------
        g : dict (optional)
            Namespace forwarded to ``uncan`` and to the hook; an empty dict
            is used when omitted.

        Returns
        -------
        The stored copy, after each attribute listed in ``keys`` has been
        uncanned in place and the hook (if any) has been uncanned and called.
        """
        if g is None:
            g = {}
        obj = self.obj
        for key in self.keys:
            setattr(obj, key, uncan(getattr(obj, key), g))

        if self.hook:
            # The uncanned hook replaces the canned one on the instance.
            self.hook = uncan(self.hook, g)
            self.hook(obj, g)
        return self.obj
|
|
|
|
|
class Reference(CannedObject):
    """object for wrapping a remote reference by name."""

    def __init__(self, name):
        """Initialize the reference.

        Parameters
        ----------
        name : str
            Expression that :meth:`get_object` evaluates to resolve the reference.

        Raises
        ------
        TypeError
            If ``name`` is not a string.
        """
        if not isinstance(name, str):
            # An f-string is used because ``"%r" % name`` would unpack a tuple
            # ``name`` as format arguments and report the wrong thing.
            raise TypeError(f"illegal name: {name!r}")
        self.name = name
        self.buffers = []

    def __repr__(self):
        """Get the string repr of the reference."""
        return f"<Reference: {self.name!r}>"

    def get_object(self, g=None):
        """Resolve the reference by evaluating its name in ``g``.

        Parameters
        ----------
        g : dict (optional)
            Globals used for the evaluation; an empty dict is used when omitted.

        Returns
        -------
        The result of evaluating ``self.name``.
        """
        if g is None:
            g = {}

        # NOTE(security): ``name`` is evaluated as an arbitrary Python
        # expression. Only uncan References that come from a trusted peer.
        return eval(self.name, g)
|
|
|
|
|
class CannedCell(CannedObject):
    """Canned form of a closure cell, holding the canned cell contents."""

    def __init__(self, cell):
        """Can the value currently stored in ``cell``."""
        self.cell_contents = can(cell.cell_contents)

    def get_object(self, g=None):
        """Uncan the contents and return them wrapped in a new closure cell."""
        restored = uncan(self.cell_contents, g)
        # A lambda closing over ``restored`` has exactly one cell, holding that value.
        return (lambda: restored).__closure__[0]
|
|
|
|
|
class CannedFunction(CannedObject):
    """Canned form of a plain Python function.

    Stores the code object, the canned positional defaults, the canned
    closure cells, the module name and the function name.
    """

    def __init__(self, f):
        """Capture everything needed to rebuild ``f`` in :meth:`get_object`."""
        self._check_type(f)
        self.code = f.__code__

        raw_defaults = f.__defaults__
        self.defaults: typing.Optional[typing.List[typing.Any]] = (
            [can(value) for value in raw_defaults] if raw_defaults else None
        )

        cells = f.__closure__
        self.closure: typing.Any = tuple(can(cell) for cell in cells) if cells else None

        # Functions without a module name are treated as living in __main__.
        self.module = f.__module__ or "__main__"
        self.__name__ = f.__name__
        self.buffers = []

    def _check_type(self, obj):
        """Assert that ``obj`` is a plain function."""
        assert isinstance(obj, FunctionType), "Not a function type"

    def get_object(self, g=None):
        """Rebuild the function.

        When the recorded module name does not start with ``"__"``, that
        module is imported and its namespace is used as the function globals,
        taking precedence over ``g``. Otherwise ``g`` is used, or an empty
        dict when ``g`` is None.
        """
        module_name = self.module
        if not module_name.startswith("__"):
            # Load the function back into the module it came from.
            __import__(module_name)
            g = sys.modules[module_name].__dict__
        elif g is None:
            g = {}

        defaults = None
        if self.defaults:
            defaults = tuple(uncan(item, g) for item in self.defaults)

        closure = None
        if self.closure:
            closure = tuple(uncan(cell, g) for cell in self.closure)

        return FunctionType(self.code, g, self.__name__, defaults, closure)
|
|
|
|
|
class CannedClass(CannedObject):
    """Canned form of a class: its name, canned namespace and canned parents."""

    def __init__(self, cls):
        """Record the name, class dict and method resolution order of ``cls``."""
        self._check_type(cls)
        self.name = cls.__name__
        self.old_style = not isinstance(cls, type)

        # These two entries are left out of the canned namespace.
        skipped = ("__weakref__", "__dict__")
        self._canned_dict = {
            attr: can(value) for attr, value in cls.__dict__.items() if attr not in skipped
        }

        mro = cls.mro() if not self.old_style else []
        # mro[0] is the class itself, so only the remaining entries are parents.
        self.parents = [can(base) for base in mro[1:]]
        self.buffers = []

    def _check_type(self, obj):
        """Assert that ``obj`` is a class."""
        assert isinstance(obj, class_type), "Not a class type"

    def get_object(self, g=None):
        """Recreate the class with ``type()`` from the uncanned parents and namespace."""
        bases = tuple(uncan(parent, g) for parent in self.parents)
        namespace = uncan_dict(self._canned_dict, g=g)
        return type(self.name, bases, namespace)
|
|
|
|
|
class CannedArray(CannedObject):
    """A canned numpy array.

    The array is held either as a pickle (``pickled`` is True) or as a raw
    buffer plus the ``shape`` and ``dtype`` needed to rebuild it.
    """

    def __init__(self, obj):
        """Initialize the can.

        Parameters
        ----------
        obj
            The numpy array to can.
        """
        from numpy import ascontiguousarray

        self.shape = obj.shape
        # Structured dtypes are recorded through their field description.
        self.dtype = obj.dtype.descr if obj.dtype.fields else obj.dtype.str
        self.pickled = False
        if sum(obj.shape) == 0:
            # Shapes that sum to zero, such as () or (0,), are simply pickled.
            self.pickled = True
        elif obj.dtype == "O":
            # Object arrays hold Python references, not raw data.
            self.pickled = True
        elif obj.dtype.fields and any(
            # Each entry is (dtype, offset) or (dtype, offset, title), so it
            # is indexed rather than unpacked into exactly two names.
            field[0] == "O"
            for field in obj.dtype.fields.values()
        ):
            self.pickled = True

        if self.pickled:
            self.buffers = [pickle.dumps(obj, PICKLE_PROTOCOL)]
        else:
            # Make the data contiguous before exposing it through ``buffer``.
            obj = ascontiguousarray(obj, dtype=None)
            self.buffers = [buffer(obj)]

    def get_object(self, g=None):
        """Get the object.

        Returns the unpickled array when ``pickled`` is set, otherwise an
        array built with ``numpy.frombuffer`` over the first buffer and
        reshaped to the recorded shape. ``g`` is unused.
        """
        from numpy import frombuffer

        data = self.buffers[0]
        if self.pickled:
            # NOTE(security): unpickling runs arbitrary code, so buffers must
            # come from a trusted source.
            return pickle.loads(data)
        return frombuffer(data, dtype=self.dtype).reshape(self.shape)
|
|
|
|
|
class CannedBytes(CannedObject):
    """Canned bytes: the payload is kept as the single entry of ``buffers``."""

    @staticmethod
    def wrap(buf: typing.Union[memoryview, bytes, typing.SupportsBytes]) -> bytes:
        """Return ``buf`` as bytes, converting memoryviews and other objects."""
        if isinstance(buf, bytes):
            return buf
        if isinstance(buf, memoryview):
            return buf.tobytes()
        return bytes(buf)

    def __init__(self, obj):
        """Store ``obj`` as the only buffer."""
        self.buffers = [obj]

    def get_object(self, g=None):
        """Return the first buffer passed through ``wrap``; ``g`` is unused."""
        return self.wrap(self.buffers[0])
|
|
|
|
|
class CannedBuffer(CannedBytes):
    """Canned buffer: like CannedBytes, but restores through ``buffer``."""

    # Overrides CannedBytes.wrap, so get_object returns ``buffer(data)``.
    wrap = buffer
|
|
|
|
|
class CannedMemoryView(CannedBytes):
    """Canned memoryview: like CannedBytes, but restores a memoryview."""

    # Overrides CannedBytes.wrap, so get_object returns ``memoryview(data)``.
    wrap = memoryview
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _import_mapping(mapping, original=None):
    """Replace every string key of ``mapping`` by the object it imports to.

    String keys that fail to import are dropped from ``mapping``. The failure
    is logged only when ``original`` is given and does not contain the key.
    """
    log = get_logger()
    log.debug("Importing canning map")

    # Snapshot the string keys first, because the mapping is mutated below.
    dotted_names = [key for key in mapping if isinstance(key, str)]
    for dotted_name in dotted_names:
        try:
            cls = import_item(dotted_name)
        except Exception:
            if original and dotted_name not in original:
                # Only report entries that are not part of the original map.
                log.error("canning class not importable: %r", dotted_name, exc_info=True)
            mapping.pop(dotted_name)
        else:
            mapping[cls] = mapping.pop(dotted_name)
|
|
|
|
|
def istype(obj, check):
    """Strict variant of ``isinstance(obj, check)``.

    Only an exact type match counts, so instances of subclasses are rejected.
    ``check`` may be a single type or a tuple of types.
    """
    candidates = check if isinstance(check, tuple) else (check,)
    actual = type(obj)
    return any(actual is cls for cls in candidates)
|
|
|
|
|
def can(obj):
    """Prepare ``obj`` for pickling.

    ``can_map`` is scanned in order; the first entry whose key is exactly the
    type of ``obj`` supplies the canner applied to it. Objects with no
    matching entry are returned unchanged.
    """
    for kind, canner in can_map.items():
        if isinstance(kind, str):
            # A string key has not been imported yet: resolve the string keys,
            # then start over with the updated map.
            _import_mapping(can_map, _original_can_map)
            return can(obj)
        if istype(obj, kind):
            return canner(obj)

    return obj
|
|
|
|
|
def can_class(obj):
    """Wrap ``obj`` in a CannedClass if it is a class whose module is ``__main__``.

    Anything else is returned unchanged.
    """
    defined_in_main = isinstance(obj, class_type) and obj.__module__ == "__main__"
    return CannedClass(obj) if defined_in_main else obj
|
|
|
|
|
def can_dict(obj):
    """Can the *values* of a dict, returning a new dict with the same keys.

    Objects whose type is not exactly ``dict`` are returned unchanged.
    """
    if not istype(obj, dict):
        return obj
    return {key: can(value) for key, value in obj.items()}
|
|
|
|
|
# Container types whose elements are canned / uncanned one by one.
sequence_types = (list, tuple, set)


def can_sequence(obj):
    """Can every element of a list, tuple or set, keeping the container type.

    Objects whose type is not exactly one of ``sequence_types`` are returned
    unchanged.
    """
    if not istype(obj, sequence_types):
        return obj
    canned_items = [can(item) for item in obj]
    return type(obj)(canned_items)
|
|
|
|
|
def uncan(obj, g=None):
    """Invert canning.

    ``uncan_map`` is scanned in order; the first entry whose key ``obj`` is an
    instance of supplies the uncanner, called as ``uncanner(obj, g)``. Objects
    with no matching entry are returned unchanged.
    """
    for kind, uncanner in uncan_map.items():
        if isinstance(kind, str):
            # A string key has not been imported yet: resolve the string keys,
            # then start over with the updated map.
            _import_mapping(uncan_map, _original_uncan_map)
            return uncan(obj, g)
        if isinstance(obj, kind):
            return uncanner(obj, g)

    return obj
|
|
|
|
|
def uncan_dict(obj, g=None):
    """Uncan the values of a dict, returning a new dict with the same keys.

    Objects whose type is not exactly ``dict`` are returned unchanged.
    """
    if not istype(obj, dict):
        return obj
    return {key: uncan(value, g) for key, value in obj.items()}
|
|
|
|
|
def uncan_sequence(obj, g=None):
    """Uncan every element of a list, tuple or set, keeping the container type.

    Objects whose type is not exactly one of ``sequence_types`` are returned
    unchanged.
    """
    if not istype(obj, sequence_types):
        return obj
    restored_items = [uncan(item, g) for item in obj]
    return type(obj)(restored_items)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Maps a type to the callable that cans instances of exactly that type.
# String keys are dotted import paths, resolved by _import_mapping the first
# time can() reaches them. can() scans the entries in this order.
can_map = {
    "numpy.ndarray": CannedArray,
    FunctionType: CannedFunction,
    bytes: CannedBytes,
    memoryview: CannedMemoryView,
    cell_type: CannedCell,
    class_type: can_class,
}
# Only register a separate entry when ``buffer`` is a distinct type.
if buffer is not memoryview:
    can_map[buffer] = CannedBuffer

# Maps a type to the callable that uncans its instances, called as f(obj, g).
uncan_map: typing.Dict[type, typing.Any] = {
    CannedObject: lambda obj, g: obj.get_object(g),
    dict: uncan_dict,
}

# Snapshots of the default maps. _import_mapping uses them to decide whether
# an import failure concerns an entry added after this point.
_original_can_map = dict(can_map)
_original_uncan_map = dict(uncan_map)
|
|