Spaces:
Running
Running
"""isort:skip_file""" | |
from pickle import ( # type: ignore[attr-defined] | |
_compat_pickle, | |
_extension_registry, | |
_getattribute, | |
_Pickler, | |
EXT1, | |
EXT2, | |
EXT4, | |
GLOBAL, | |
Pickler, | |
PicklingError, | |
STACK_GLOBAL, | |
) | |
from struct import pack | |
from types import FunctionType | |
from .importer import Importer, ObjMismatchError, ObjNotFoundError, sys_importer | |
class PackagePickler(_Pickler):
    """Package-aware pickler.

    This behaves the same as a normal pickler, except it uses an `Importer`
    to find objects and modules to save.
    """

    def __init__(self, importer: "Importer", *args, **kwargs):
        """Create a pickler that resolves globals through ``importer``.

        Args:
            importer: object whose ``get_name(obj, name)`` and
                ``import_module(module_name)`` methods are used by
                ``save_global`` to locate objects and load their modules.
            *args, **kwargs: forwarded unchanged to ``pickle._Pickler``
                (output file, ``protocol``, ...).
        """
        self.importer = importer
        super().__init__(*args, **kwargs)

        # Take a fresh copy of _Pickler.dispatch for every instance so it is
        # always up to date. We previously hit issues where a library (e.g.
        # dill) mutates _Pickler.dispatch: a copy taken when this module was
        # imported still contained that library's entries after the library
        # removed them again, leaving PackagePickler with a stale dispatch
        # table that could cause unwanted behavior.
        self.dispatch = _Pickler.dispatch.copy()  # type: ignore[misc]
        # Route plain Python functions to the importer-aware save_global below.
        # NOTE(review): presumably the inherited entry points straight at
        # _Pickler.save_global, which would bypass self.importer.
        self.dispatch[FunctionType] = PackagePickler.save_global  # type: ignore[assignment]

    def save_global(self, obj, name=None):
        """Write a by-name reference to the global ``obj``.

        Unfortunately the pickler code is factored in a way that forces us to
        copy/paste ``pickle._Pickler.save_global``. The only functional change
        is marked CHANGED below: the module and qualified name are resolved via
        ``self.importer`` instead of ``__import__``.

        Args:
            obj: the object to reference (function, class, ...).
            name: optional name hint, forwarded to ``importer.get_name``.

        Raises:
            PicklingError: if the importer cannot find ``obj`` or finds a
                different object under its name, or if, under protocol < 3,
                the module or attribute name is not ASCII-encodable.
        """
        write = self.write  # type: ignore[attr-defined]

        # CHANGED: import module from module environment instead of __import__
        try:
            module_name, name = self.importer.get_name(obj, name)
        except (ObjNotFoundError, ObjMismatchError) as err:
            raise PicklingError(f"Can't pickle {obj}: {err}") from None

        module = self.importer.import_module(module_name)
        parent = self._lookup_parent(module, name)
        # END CHANGED

        if self.proto >= 2:  # type: ignore[attr-defined]
            # Names present in pickle's extension registry are written as a
            # compact integer code (EXT1/EXT2/EXT4) instead of by name.
            code = _extension_registry.get((module_name, name))
            if code:
                assert code > 0
                if code <= 0xFF:
                    write(EXT1 + pack("<B", code))
                elif code <= 0xFFFF:
                    write(EXT2 + pack("<H", code))
                else:
                    write(EXT4 + pack("<i", code))
                return

        lastname = name.rpartition(".")[2]
        if parent is module:
            name = lastname
        # Non-ASCII identifiers are supported only with protocols >= 3.
        if self.proto >= 4:  # type: ignore[attr-defined]
            self.save(module_name)  # type: ignore[attr-defined]
            self.save(name)  # type: ignore[attr-defined]
            write(STACK_GLOBAL)
        elif parent is not module:
            # Protocol < 4 with a nested attribute (e.g. a method on a class):
            # emit it as getattr(parent, lastname).
            self.save_reduce(getattr, (parent, lastname))  # type: ignore[attr-defined]
        elif self.proto >= 3:  # type: ignore[attr-defined]
            write(
                GLOBAL
                + bytes(module_name, "utf-8")
                + b"\n"
                + bytes(name, "utf-8")
                + b"\n"
            )
        else:
            if self.fix_imports:  # type: ignore[attr-defined]
                # Rename via _compat_pickle's reverse mappings (presumably
                # Python 3 -> Python 2 names for older readers).
                r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING
                r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING
                if (module_name, name) in r_name_mapping:
                    module_name, name = r_name_mapping[(module_name, name)]
                elif module_name in r_import_mapping:
                    module_name = r_import_mapping[module_name]
            try:
                write(
                    GLOBAL
                    + bytes(module_name, "ascii")
                    + b"\n"
                    + bytes(name, "ascii")
                    + b"\n"
                )
            except UnicodeEncodeError:
                # Report the (possibly remapped) identifier that failed to
                # encode, not the module object's repr.
                raise PicklingError(
                    "can't pickle global identifier '%s.%s' using "
                    "pickle protocol %i" % (module_name, name, self.proto)  # type: ignore[attr-defined]
                ) from None

        self.memoize(obj)  # type: ignore[attr-defined]

    @staticmethod
    def _lookup_parent(module, dotted_name):
        """Return the object that directly owns the last part of ``dotted_name``.

        Walks ``dotted_name`` attribute by attribute starting at ``module``.
        This matches the ``parent`` half of what ``pickle._getattribute``
        returned on the CPython versions this class was written against. That
        helper is private, and newer CPython releases changed its signature
        (pre-split path in, object only out), so the walk is done locally.

        Raises:
            AttributeError: if a path component is ``<locals>`` or is missing.
        """
        obj = parent = module
        for part in dotted_name.split("."):
            if part == "<locals>":
                raise AttributeError(
                    f"Can't get local attribute {dotted_name!r} on {obj!r}"
                )
            parent = obj
            try:
                obj = getattr(obj, part)
            except AttributeError:
                raise AttributeError(
                    f"Can't get attribute {dotted_name!r} on {obj!r}"
                ) from None
        return parent
def create_pickler(data_buf, importer, protocol=4):
    """Build a pickler that writes into ``data_buf`` using ``protocol``.

    A custom ``importer`` needs the package-aware ``PackagePickler`` so that
    objects are resolved through it. When ``importer`` is the default
    ``sys_importer``, objects resolve through the normal import system. In
    that case the faster C implementation of pickle is returned instead.
    """
    uses_system_imports = importer is sys_importer
    if not uses_system_imports:
        return PackagePickler(importer, data_buf, protocol=protocol)
    # Plain import machinery: the stock (C-accelerated) Pickler is enough.
    return Pickler(data_buf, protocol=protocol)