import builtins
import importlib
import importlib.machinery
import importlib.util
import inspect
import io
import linecache
import os
import types
from contextlib import contextmanager
from typing import Any, BinaryIO, Callable, cast, Dict, Iterable, List, Optional, Union
from weakref import WeakValueDictionary

import torch
from torch.serialization import _get_restore_location, _maybe_decode_ascii
from ._directory_reader import DirectoryReader
from ._importlib import (
    _calc___package__,
    _normalize_line_endings,
    _normalize_path,
    _resolve_name,
    _sanity_check,
)
from ._mangling import demangle, PackageMangler
from ._package_unpickler import PackageUnpickler
from .file_structure_representation import _create_directory_from_file_list, Directory
from .glob_group import GlobPattern
from .importer import Importer

__all__ = ["PackageImporter"]

# This is a list of imports that are implicitly allowed even if they haven't
# been marked as extern. It works around the fact that torch implicitly
# depends on numpy, which torch.package can't track as a dependency.
# https://github.com/pytorch/MultiPy/issues/46
IMPLICIT_IMPORT_ALLOWLIST: Iterable[str] = [
    "numpy",
    "numpy.core",
    "numpy.core._multiarray_umath",
    # An FX GraphModule may depend on the builtins module, and users usually
    # don't extern builtins, so we allow it by default.
    "builtins",
]


class PackageImporter(Importer):
    """Importers allow you to load code written to packages by :class:`PackageExporter`.
    Code is loaded in a hermetic way, using files from the package
    rather than the normal Python import system. This allows
    for the packaging of PyTorch model code and data so that it can be run
    on a server or used in the future for transfer learning.

    The importer for packages ensures that code in the module can only be loaded from
    within the package, except for modules explicitly listed as external during export.
    The file ``extern_modules`` in the zip archive lists all the modules that a package externally depends on.
    This prevents "implicit" dependencies, where the package runs locally because it is importing
    a locally-installed package, but then fails when the package is copied to another machine.
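
    Example (a minimal sketch; ``"package.pt"``, ``my_package``, and
    ``model.pkl`` are hypothetical names for an archive created with
    :class:`PackageExporter`)::

        >>> # xdoctest: +SKIP
        >>> importer = PackageImporter("package.pt")
        >>> mod = importer.import_module("my_package.my_module")
        >>> obj = importer.load_pickle("my_package", "model.pkl")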
""" | |
"""The dictionary of already loaded modules from this package, equivalent to ``sys.modules`` but | |
local to this importer. | |
""" | |
modules: Dict[str, types.ModuleType] | |

    def __init__(
        self,
        file_or_buffer: Union[str, torch._C.PyTorchFileReader, os.PathLike, BinaryIO],
        module_allowed: Callable[[str], bool] = lambda module_name: True,
    ):
        """Open ``file_or_buffer`` for importing. This checks that the imported package only requires modules
        allowed by ``module_allowed``.

        Args:
            file_or_buffer: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
                a string, or an ``os.PathLike`` object containing a filename.
            module_allowed (Callable[[str], bool], optional): A method to determine if an externally provided module
                should be allowed. Can be used to ensure packages loaded do not depend on modules that the server
                does not support. Defaults to allowing anything.

        Raises:
            ImportError: If the package will use a disallowed module.
        """
        torch._C._log_api_usage_once("torch.package.PackageImporter")

        self.zip_reader: Any
        if isinstance(file_or_buffer, torch._C.PyTorchFileReader):
            self.filename = "<pytorch_file_reader>"
            self.zip_reader = file_or_buffer
        elif isinstance(file_or_buffer, (os.PathLike, str)):
            self.filename = os.fspath(file_or_buffer)
            if not os.path.isdir(self.filename):
                self.zip_reader = torch._C.PyTorchFileReader(self.filename)
            else:
                self.zip_reader = DirectoryReader(self.filename)
        else:
            self.filename = "<binary>"
            self.zip_reader = torch._C.PyTorchFileReader(file_or_buffer)

        torch._C._log_api_usage_metadata(
            "torch.package.PackageImporter.metadata",
            {
                "serialization_id": self.zip_reader.serialization_id(),
                "file_name": self.filename,
            },
        )

        self.root = _PackageNode(None)
        self.modules = {}
        self.extern_modules = self._read_extern()

        for extern_module in self.extern_modules:
            if not module_allowed(extern_module):
                raise ImportError(
                    f"package '{file_or_buffer}' needs the external module '{extern_module}' "
                    f"but that module has been disallowed"
                )
            self._add_extern(extern_module)

        for fname in self.zip_reader.get_all_records():
            self._add_file(fname)

        self.patched_builtins = builtins.__dict__.copy()
        self.patched_builtins["__import__"] = self.__import__
        # Allow packaged modules to reference their PackageImporter
        self.modules["torch_package_importer"] = self  # type: ignore[assignment]

        self._mangler = PackageMangler()

        # used for reduce deserialization
        self.storage_context: Any = None
        self.last_map_location = None

        # used for torch.serialization._load
        self.Unpickler = lambda *args, **kwargs: PackageUnpickler(self, *args, **kwargs)

    def import_module(self, name: str, package=None):
        """Load a module from the package if it hasn't already been loaded, and then return
        the module. Modules are loaded locally
        to the importer and will appear in ``self.modules`` rather than ``sys.modules``.

        Args:
            name (str): Fully qualified name of the module to load.
            package ([type], optional): Unused, but present to match the signature of importlib.import_module. Defaults to ``None``.

        Returns:
            types.ModuleType: The (possibly already) loaded module.
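
        Example (a hedged sketch; ``my_package.my_module`` is a hypothetical
        module interned in the archive)::

            >>> # xdoctest: +SKIP
            >>> mod = importer.import_module("my_package.my_module")
            >>> mod is importer.modules["my_package.my_module"]
            True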
""" | |
# We should always be able to support importing modules from this package. | |
# This is to support something like: | |
# obj = importer.load_pickle(...) | |
# importer.import_module(obj.__module__) <- this string will be mangled | |
# | |
# Note that _mangler.demangle will not demangle any module names | |
# produced by a different PackageImporter instance. | |
name = self._mangler.demangle(name) | |
return self._gcd_import(name) | |

    def load_binary(self, package: str, resource: str) -> bytes:
        """Load raw bytes.

        Args:
            package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
            resource (str): The unique name for the resource.

        Returns:
            bytes: The loaded data.
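
        Example (a sketch; the resource name is hypothetical)::

            >>> # xdoctest: +SKIP
            >>> raw = importer.load_binary("my_package", "weights.bin")
            >>> isinstance(raw, bytes)
            True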
""" | |
path = self._zipfile_path(package, resource) | |
return self.zip_reader.get_record(path) | |

    def load_text(
        self,
        package: str,
        resource: str,
        encoding: str = "utf-8",
        errors: str = "strict",
    ) -> str:
        """Load a string.

        Args:
            package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
            resource (str): The unique name for the resource.
            encoding (str, optional): Passed to ``decode``. Defaults to ``'utf-8'``.
            errors (str, optional): Passed to ``decode``. Defaults to ``'strict'``.

        Returns:
            str: The loaded text.
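
        Example (a sketch; the resource name is hypothetical)::

            >>> # xdoctest: +SKIP
            >>> cfg = importer.load_text("my_package", "config.json")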
""" | |
data = self.load_binary(package, resource) | |
return data.decode(encoding, errors) | |

    def load_pickle(self, package: str, resource: str, map_location=None) -> Any:
        """Unpickles the resource from the package, loading any modules that are needed to construct the objects
        using :meth:`import_module`.

        Args:
            package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
            resource (str): The unique name for the resource.
            map_location: Passed to ``torch.load`` to determine how tensors are mapped to devices. Defaults to ``None``.

        Returns:
            Any: The unpickled object.
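
        Example (a sketch; assumes the exporter saved a pickle under
        ``my_package/model.pkl``)::

            >>> # xdoctest: +SKIP
            >>> model = importer.load_pickle("my_package", "model.pkl", map_location="cpu")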
""" | |
pickle_file = self._zipfile_path(package, resource) | |
restore_location = _get_restore_location(map_location) | |
loaded_storages = {} | |
loaded_reduces = {} | |
storage_context = torch._C.DeserializationStorageContext() | |
def load_tensor(dtype, size, key, location, restore_location): | |
name = f"{key}.storage" | |
if storage_context.has_storage(name): | |
storage = storage_context.get_storage(name, dtype)._typed_storage() | |
else: | |
tensor = self.zip_reader.get_storage_from_record( | |
".data/" + name, size, dtype | |
) | |
if isinstance(self.zip_reader, torch._C.PyTorchFileReader): | |
storage_context.add_storage(name, tensor) | |
storage = tensor._typed_storage() | |
loaded_storages[key] = restore_location(storage, location) | |
def persistent_load(saved_id): | |
assert isinstance(saved_id, tuple) | |
typename = _maybe_decode_ascii(saved_id[0]) | |
data = saved_id[1:] | |
if typename == "storage": | |
storage_type, key, location, size = data | |
dtype = storage_type.dtype | |
if key not in loaded_storages: | |
load_tensor( | |
dtype, | |
size, | |
key, | |
_maybe_decode_ascii(location), | |
restore_location, | |
) | |
storage = loaded_storages[key] | |
# TODO: Once we decide to break serialization FC, we can | |
# stop wrapping with TypedStorage | |
return torch.storage.TypedStorage( | |
wrap_storage=storage._untyped_storage, dtype=dtype, _internal=True | |
) | |
elif typename == "reduce_package": | |
# to fix BC breaking change, objects on this load path | |
# will be loaded multiple times erroneously | |
if len(data) == 2: | |
func, args = data | |
return func(self, *args) | |
reduce_id, func, args = data | |
if reduce_id not in loaded_reduces: | |
loaded_reduces[reduce_id] = func(self, *args) | |
return loaded_reduces[reduce_id] | |
else: | |
f"Unknown typename for persistent_load, expected 'storage' or 'reduce_package' but got '{typename}'" | |
# Load the data (which may in turn use `persistent_load` to load tensors) | |
data_file = io.BytesIO(self.zip_reader.get_record(pickle_file)) | |
unpickler = self.Unpickler(data_file) | |
unpickler.persistent_load = persistent_load # type: ignore[assignment] | |
def set_deserialization_context(): | |
# to let reduce_package access deserializaiton context | |
self.storage_context = storage_context | |
self.last_map_location = map_location | |
try: | |
yield | |
finally: | |
self.storage_context = None | |
self.last_map_location = None | |
with set_deserialization_context(): | |
result = unpickler.load() | |
# TODO from zdevito: | |
# This stateful weird function will need to be removed in our efforts | |
# to unify the format. It has a race condition if multiple python | |
# threads try to read independent files | |
torch._utils._validate_loaded_sparse_tensors() | |
return result | |

    def id(self):
        """
        Returns internal identifier that torch.package uses to distinguish :class:`PackageImporter` instances.
        Looks like::

            <torch_package_0>
        """
        return self._mangler.parent_name()

    def file_structure(
        self, *, include: "GlobPattern" = "**", exclude: "GlobPattern" = ()
    ) -> Directory:
        """Returns a file structure representation of the package's zipfile.

        Args:
            include (Union[List[str], str]): An optional string e.g. ``"my_package.my_subpackage"``, or optional list of strings
                for the names of the files to be included in the zipfile representation. This can also be
                a glob-style pattern, as described in :meth:`PackageExporter.mock`.
            exclude (Union[List[str], str]): An optional pattern that excludes files whose names match the pattern.

        Returns:
            :class:`Directory`
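
        Example (a sketch using hypothetical glob patterns)::

            >>> # xdoctest: +SKIP
            >>> tree = importer.file_structure(include="my_package.**", exclude="**/*.txt")
            >>> print(tree)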
""" | |
return _create_directory_from_file_list( | |
self.filename, self.zip_reader.get_all_records(), include, exclude | |
) | |

    def python_version(self):
        """Returns the version of Python that was used to create this package.

        Note: this function is experimental and not Forward Compatible. The plan is to move this into a lock
        file later on.

        Returns:
            :class:`Optional[str]`: a Python version, e.g. ``"3.8.9"``, or ``None`` if no version was stored with this package
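
        Example (hedged; the version string shown is illustrative)::

            >>> # xdoctest: +SKIP
            >>> importer.python_version()  # e.g. '3.8.9', or None if not stored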
""" | |
python_version_path = ".data/python_version" | |
return ( | |
self.zip_reader.get_record(python_version_path).decode("utf-8").strip() | |
if self.zip_reader.has_record(python_version_path) | |
else None | |
) | |

    def _read_extern(self):
        return (
            self.zip_reader.get_record(".data/extern_modules")
            .decode("utf-8")
            .splitlines(keepends=False)
        )

    def _make_module(
        self, name: str, filename: Optional[str], is_package: bool, parent: str
    ):
        mangled_filename = self._mangler.mangle(filename) if filename else None
        spec = importlib.machinery.ModuleSpec(
            name,
            self,  # type: ignore[arg-type]
            origin="<package_importer>",
            is_package=is_package,
        )
        module = importlib.util.module_from_spec(spec)
        self.modules[name] = module
        module.__name__ = self._mangler.mangle(name)
        ns = module.__dict__
        ns["__spec__"] = spec
        ns["__loader__"] = self
        ns["__file__"] = mangled_filename
        ns["__cached__"] = None
        ns["__builtins__"] = self.patched_builtins
        ns["__torch_package__"] = True

        # Add this module to our private global registry. It should be unique due to mangling.
        assert module.__name__ not in _package_imported_modules
        _package_imported_modules[module.__name__] = module

        # pre-emptively install on the parent to prevent IMPORT_FROM from trying to
        # access sys.modules
        self._install_on_parent(parent, name, module)

        if filename is not None:
            assert mangled_filename is not None
            # pre-emptively install the source in `linecache` so that stack traces,
            # `inspect`, etc. work.
            assert filename not in linecache.cache  # type: ignore[attr-defined]
            linecache.lazycache(mangled_filename, ns)

            code = self._compile_source(filename, mangled_filename)
            exec(code, ns)

        return module

    def _load_module(self, name: str, parent: str):
        cur: _PathNode = self.root
        for atom in name.split("."):
            if not isinstance(cur, _PackageNode) or atom not in cur.children:
                if name in IMPLICIT_IMPORT_ALLOWLIST:
                    module = self.modules[name] = importlib.import_module(name)
                    return module
                raise ModuleNotFoundError(
                    f'No module named "{name}" in self-contained archive "{self.filename}"'
                    f" and the module is also not in the list of allowed external modules: {self.extern_modules}",
                    name=name,
                )
            cur = cur.children[atom]
            if isinstance(cur, _ExternNode):
                module = self.modules[name] = importlib.import_module(name)
                return module
        return self._make_module(name, cur.source_file, isinstance(cur, _PackageNode), parent)  # type: ignore[attr-defined]

    def _compile_source(self, fullpath: str, mangled_filename: str):
        source = self.zip_reader.get_record(fullpath)
        source = _normalize_line_endings(source)
        return compile(source, mangled_filename, "exec", dont_inherit=True)

    # note: named `get_source` so that linecache can find the source
    # when this is the __loader__ of a module.
    def get_source(self, module_name) -> str:
        # linecache calls `get_source` with the `module.__name__` as the argument, so we must demangle it here.
        module = self.import_module(demangle(module_name))
        return self.zip_reader.get_record(demangle(module.__file__)).decode("utf-8")

    # note: named `get_resource_reader` so that importlib.resources can find it.
    # This is otherwise considered an internal method.
    def get_resource_reader(self, fullname):
        try:
            package = self._get_package(fullname)
        except ImportError:
            return None
        if package.__loader__ is not self:
            return None
        return _PackageResourceReader(self, fullname)
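
    # Usage sketch (hypothetical names): through this hook, code interned in a
    # package can read its own resources via importlib.resources, e.g.
    #
    #     import importlib.resources
    #     data = importlib.resources.read_binary(my_pkg_module, "weights.bin")
    #
    # where `my_pkg_module` was imported with this PackageImporter.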

    def _install_on_parent(self, parent: str, name: str, module: types.ModuleType):
        if not parent:
            return
        # Set the module as an attribute on its parent.
        parent_module = self.modules[parent]
        if parent_module.__loader__ is self:
            setattr(parent_module, name.rpartition(".")[2], module)

    # note: copied from cpython's import code, with the call that creates the
    # module replaced with _make_module
    def _do_find_and_load(self, name):
        path = None
        parent = name.rpartition(".")[0]
        module_name_no_parent = name.rpartition(".")[-1]
        if parent:
            if parent not in self.modules:
                self._gcd_import(parent)
            # Crazy side-effects!
            if name in self.modules:
                return self.modules[name]
            parent_module = self.modules[parent]

            try:
                path = parent_module.__path__  # type: ignore[attr-defined]

            except AttributeError:
                # When we attempt to import a package containing only C extension
                # (pybind) files, the parent directory isn't always a package as
                # defined by Python, so we check whether the package is actually
                # there before raising the error.
                if isinstance(
                    parent_module.__loader__,
                    importlib.machinery.ExtensionFileLoader,
                ):
                    if name not in self.extern_modules:
                        msg = (
                            _ERR_MSG
                            + "; {!r} is a C extension module which was not externed."
                            " C extension modules need to be externed by the"
                            " PackageExporter in order to be used, as we do not"
                            " support interning them."
                        ).format(name, name)
                        raise ModuleNotFoundError(msg, name=name) from None
                    if not isinstance(
                        parent_module.__dict__.get(module_name_no_parent),
                        types.ModuleType,
                    ):
                        msg = (
                            _ERR_MSG
                            + "; {!r} is a C extension package which does not contain {!r}."
                        ).format(name, parent, name)
                        raise ModuleNotFoundError(msg, name=name) from None
                else:
                    msg = (_ERR_MSG + "; {!r} is not a package").format(name, parent)
                    raise ModuleNotFoundError(msg, name=name) from None

        module = self._load_module(name, parent)

        self._install_on_parent(parent, name, module)

        return module

    # note: copied from cpython's import code
    def _find_and_load(self, name):
        module = self.modules.get(name, _NEEDS_LOADING)
        if module is _NEEDS_LOADING:
            return self._do_find_and_load(name)

        if module is None:
            message = f"import of {name} halted; None in sys.modules"
            raise ModuleNotFoundError(message, name=name)

        # To handle https://github.com/pytorch/pytorch/issues/57490, where std's
        # creation of fake submodules via the hacking of sys.modules is not import
        # friendly
        if name == "os":
            self.modules["os.path"] = cast(Any, module).path
        elif name == "typing":
            self.modules["typing.io"] = cast(Any, module).io
            self.modules["typing.re"] = cast(Any, module).re

        return module

    def _gcd_import(self, name, package=None, level=0):
        """Import and return the module based on its name, the package the call is
        being made from, and the level adjustment.

        This function represents the greatest common denominator of functionality
        between import_module and __import__. This includes setting __package__ if
        the loader did not.
        """
        _sanity_check(name, package, level)
        if level > 0:
            name = _resolve_name(name, package, level)

        return self._find_and_load(name)

    # note: copied from cpython's import code
    def _handle_fromlist(self, module, fromlist, *, recursive=False):
        """Figure out what __import__ should return.

        Unlike the CPython original, which takes an ``import_`` callable to
        decouple it from a particular import implementation, this version
        always imports through this importer's ``_gcd_import``.
        """
        module_name = demangle(module.__name__)
        # The hell that is fromlist ...
        # If a package was imported, try to import stuff from fromlist.
        if hasattr(module, "__path__"):
            for x in fromlist:
                if not isinstance(x, str):
                    if recursive:
                        where = module_name + ".__all__"
                    else:
                        where = "``from list''"
                    raise TypeError(
                        f"Item in {where} must be str, not {type(x).__name__}"
                    )
                elif x == "*":
                    if not recursive and hasattr(module, "__all__"):
                        self._handle_fromlist(module, module.__all__, recursive=True)
                elif not hasattr(module, x):
                    from_name = f"{module_name}.{x}"
                    try:
                        self._gcd_import(from_name)
                    except ModuleNotFoundError as exc:
                        # Backwards-compatibility dictates we ignore failed
                        # imports triggered by fromlist for modules that don't
                        # exist.
                        if (
                            exc.name == from_name
                            and self.modules.get(from_name, _NEEDS_LOADING) is not None
                        ):
                            continue
                        raise
        return module

    def __import__(self, name, globals=None, locals=None, fromlist=(), level=0):
        if level == 0:
            module = self._gcd_import(name)
        else:
            globals_ = globals if globals is not None else {}
            package = _calc___package__(globals_)
            module = self._gcd_import(name, package, level)

        if not fromlist:
            # Return up to the first dot in 'name'. This is complicated by the fact
            # that 'name' may be relative.
            if level == 0:
                return self._gcd_import(name.partition(".")[0])
            elif not name:
                return module
            else:
                # Figure out where to slice the module's name up to the first dot
                # in 'name'.
                cut_off = len(name) - len(name.partition(".")[0])
                # Slice end needs to be positive to alleviate need to special-case
                # when ``'.' not in name``.
                module_name = demangle(module.__name__)
                return self.modules[module_name[: len(module_name) - cut_off]]
        else:
            return self._handle_fromlist(module, fromlist)

    def _get_package(self, package):
        """Take a package name or module object and return the module.

        If a name, the module is imported. If the passed or imported module
        object is not a package, raise an exception.
        """
        if hasattr(package, "__spec__"):
            if package.__spec__.submodule_search_locations is None:
                raise TypeError(f"{package.__spec__.name!r} is not a package")
            else:
                return package
        else:
            module = self.import_module(package)
            if module.__spec__.submodule_search_locations is None:
                raise TypeError(f"{package!r} is not a package")
            else:
                return module

    def _zipfile_path(self, package, resource=None):
        package = self._get_package(package)
        assert package.__loader__ is self
        name = demangle(package.__name__)
        if resource is not None:
            resource = _normalize_path(resource)
            return f"{name.replace('.', '/')}/{resource}"
        else:
            return f"{name.replace('.', '/')}"

    def _get_or_create_package(
        self, atoms: List[str]
    ) -> "Union[_PackageNode, _ExternNode]":
        cur = self.root
        for i, atom in enumerate(atoms):
            node = cur.children.get(atom, None)
            if node is None:
                node = cur.children[atom] = _PackageNode(None)
            if isinstance(node, _ExternNode):
                return node
            if isinstance(node, _ModuleNode):
                name = ".".join(atoms[: i + 1])
                raise ImportError(
                    f"inconsistent module structure. module {name} is not a package, but has submodules"
                )
            assert isinstance(node, _PackageNode)
            cur = node
        return cur

    def _add_file(self, filename: str):
        """Assembles a Python module out of the given file. Will ignore files in the .data directory.

        Args:
            filename (str): the name of the file inside of the package archive to be added
        """
        *prefix, last = filename.split("/")
        if len(prefix) > 1 and prefix[0] == ".data":
            return
        package = self._get_or_create_package(prefix)
        if isinstance(package, _ExternNode):
            raise ImportError(
                f"inconsistent module structure. package contains a module file {filename}"
                f" that is a subpackage of a module marked external."
            )
        if last == "__init__.py":
            package.source_file = filename
        elif last.endswith(".py"):
            package_name = last[: -len(".py")]
            package.children[package_name] = _ModuleNode(filename)

    def _add_extern(self, extern_name: str):
        *prefix, last = extern_name.split(".")
        package = self._get_or_create_package(prefix)
        if isinstance(package, _ExternNode):
            return  # the shorter extern covers this extern case
        package.children[last] = _ExternNode()


_NEEDS_LOADING = object()
_ERR_MSG_PREFIX = "No module named "
_ERR_MSG = _ERR_MSG_PREFIX + "{!r}"


class _PathNode:
    pass


class _PackageNode(_PathNode):
    def __init__(self, source_file: Optional[str]):
        self.source_file = source_file
        self.children: Dict[str, _PathNode] = {}


class _ModuleNode(_PathNode):
    __slots__ = ["source_file"]

    def __init__(self, source_file: str):
        self.source_file = source_file


class _ExternNode(_PathNode):
    pass


# A private global registry of all modules that have been package-imported.
_package_imported_modules: WeakValueDictionary = WeakValueDictionary()

# `inspect` by default only looks in `sys.modules` to find source files for classes.
# Patch it to check our private registry of package-imported modules as well.
_orig_getfile = inspect.getfile


def _patched_getfile(object):
    if inspect.isclass(object):
        if object.__module__ in _package_imported_modules:
            return _package_imported_modules[object.__module__].__file__
    return _orig_getfile(object)


inspect.getfile = _patched_getfile


class _PackageResourceReader:
    """Private class used to support PackageImporter.get_resource_reader().

    Conforms to the importlib.abc.ResourceReader interface. Allowed to access
    the innards of PackageImporter.
    """

    def __init__(self, importer, fullname):
        self.importer = importer
        self.fullname = fullname

    def open_resource(self, resource):
        from io import BytesIO

        return BytesIO(self.importer.load_binary(self.fullname, resource))

    def resource_path(self, resource):
        # The contract for resource_path is that it either returns a concrete
        # file system path or raises FileNotFoundError.
        if isinstance(
            self.importer.zip_reader, DirectoryReader
        ) and self.importer.zip_reader.has_record(
            os.path.join(self.fullname, resource)
        ):
            return os.path.join(
                self.importer.zip_reader.directory, self.fullname, resource
            )
        raise FileNotFoundError

    def is_resource(self, name):
        path = self.importer._zipfile_path(self.fullname, name)
        return self.importer.zip_reader.has_record(path)

    def contents(self):
        from pathlib import Path

        fullname_path = Path(self.importer._zipfile_path(self.fullname))
        files = self.importer.zip_reader.get_all_records()
        subdirs_seen = set()
        for filename in files:
            try:
                relative = Path(filename).relative_to(fullname_path)
            except ValueError:
                continue
            # If the path of the file (which is relative to the top of the zip
            # namespace), relative to the package given when the resource
            # reader was created, has a parent, then it's a name in a
            # subdirectory and thus we skip it.
            parent_name = relative.parent.name
            if len(parent_name) == 0:
                yield relative.name
            elif parent_name not in subdirs_seen:
                subdirs_seen.add(parent_name)
                yield parent_name