|
from __future__ import annotations |
|
|
|
from functools import partial |
|
from importlib.metadata import entry_points |
|
from typing import TYPE_CHECKING, Any, Callable, Generic, cast |
|
from typing_extensions import TypeAliasType, TypeIs, TypeVar |
|
|
|
from altair.utils.deprecation import deprecated_warn |
|
|
|
if TYPE_CHECKING: |
|
from types import TracebackType |
|
|
|
# Generic placeholder for an arbitrary type (used by ``_is_type``).
T = TypeVar("T")
# Return type produced by calling a plugin.
R = TypeVar("R")
# A plugin is any callable, with arbitrary parameters, that returns ``R``.
Plugin = TypeAliasType("Plugin", Callable[..., R], type_params=(R,))
# The concrete plugin type stored by a registry; bounded by ``Plugin[Any]``.
PluginT = TypeVar("PluginT", bound=Plugin[Any])
# Signature of a type-narrowing guard deciding whether an object is a plugin.
IsPlugin = Callable[[object], TypeIs[Plugin[Any]]]
|
|
|
|
|
def _is_type(tp: type[T], /) -> Callable[[object], TypeIs[type[T]]]: |
|
""" |
|
Converts a type to guard function. |
|
|
|
Added for compatibility with original `PluginRegistry` default. |
|
""" |
|
|
|
def func(obj: object, /) -> TypeIs[type[T]]: |
|
return isinstance(obj, tp) |
|
|
|
return func |
|
|
|
|
|
class NoSuchEntryPoint(Exception):
    """Error describing an entry point ``name`` that is missing from ``group``."""

    def __init__(self, group, name):
        """Remember the entry point group and the name that was looked up."""
        self.group, self.name = group, name

    def __str__(self):
        """Return a human-readable description of the missing entry point."""
        return f"No {self.name!r} entry point found in group {self.group!r}"
|
|
|
|
|
class PluginEnabler:
    """
    Handle returned by ``enable()`` that also works as a context manager.

    The plugin is enabled as soon as this object is constructed. When the
    object is used in a ``with`` statement, the registry is restored to the
    state it had before construction once the block exits::

        with plugins.enable("name"):
            do_something()  # 'name' plugin temporarily enabled
        # plugins back to original state
    """

    def __init__(self, registry: PluginRegistry, name: str, **options):
        """
        Snapshot the registry state, then enable ``name`` with ``options``.

        Parameters
        ----------
        registry
            The registry whose active plugin is being switched.
        name
            Name of the plugin to enable.
        **options
            Keyword arguments forwarded to ``registry._enable``.
        """
        self.registry: PluginRegistry = registry
        self.name: str = name
        self.options: dict[str, Any] = options
        # The snapshot must be taken before ``_enable`` runs so that
        # ``__exit__`` can undo whatever ``_enable`` changes.
        snapshot: dict[str, Any] = registry._get_state()
        self.original_state: dict[str, Any] = snapshot
        registry._enable(name, **options)

    def __enter__(self) -> PluginEnabler:
        """Return this enabler; the plugin was already enabled in ``__init__``."""
        return self

    def __exit__(self, typ: type, value: Exception, traceback: TracebackType) -> None:
        """Restore the registry state captured at construction time."""
        previous = self.original_state
        self.registry._set_state(previous)

    def __repr__(self) -> str:
        """Show the equivalent ``enable`` call on the registry's class."""
        registry_cls = self.registry.__class__.__name__
        return f"{registry_cls}.enable({self.name!r})"
|
|
|
|
|
class PluginRegistry(Generic[PluginT, R]):
    """
    A registry for plugins.

    This is a plugin registry that allows plugins to be loaded/registered
    in two ways:

    1. Through an explicit call to ``.register(name, value)``.
    2. By looking for other Python packages that are installed and provide
       a setuptools entry point group.

    When you create an instance of this class, provide the name of the
    entry point group to use::

        reg = PluginRegistry("my_entrypoint_group")

    """

    # Maps a plugin name to a custom error message. ``_enable`` raises
    # ``ValueError`` with that message when the name cannot be resolved to
    # exactly one entry point.
    entrypoint_err_messages: dict[str, str] = {}

    # Class-level defaults for settings that live on the registry itself.
    # Options passed to ``enable`` whose keys appear here are stored in the
    # instance's ``_global_settings`` instead of being passed to the plugin.
    _global_settings: dict[str, Any] = {}

    def __init__(
        self, entry_point_group: str = "", plugin_type: IsPlugin = callable
    ) -> None:
        """
        Create a PluginRegistry for a named entry point group.

        Parameters
        ----------
        entry_point_group: str
            The name of the entry point group.
        plugin_type
            A type narrowing function that will optionally be used for runtime
            type checking loaded plugins.

        References
        ----------
        https://typing.readthedocs.io/en/latest/spec/narrowing.html
        """
        self.entry_point_group: str = entry_point_group
        self.plugin_type: IsPlugin
        if plugin_type is not callable and isinstance(plugin_type, type):
            # Deprecated usage: a bare type was passed instead of a guard
            # function. Warn, then adapt it so the rest of the class can
            # always call ``self.plugin_type(obj)``.
            msg = (
                f"Pass a callable `TypeIs` function to `plugin_type` instead.\n"
                f"{type(self).__name__!r}(plugin_type)\n\n"
                f"See also:\n"
                f"https://typing.readthedocs.io/en/latest/spec/narrowing.html\n"
                f"https://docs.astral.sh/ruff/rules/assert/"
            )
            deprecated_warn(msg, version="5.4.0")
            self.plugin_type = cast(IsPlugin, _is_type(plugin_type))
        else:
            self.plugin_type = plugin_type
        self._active: Plugin[R] | None = None
        self._active_name: str = ""
        self._plugins: dict[str, PluginT] = {}
        self._options: dict[str, Any] = {}
        # Per-instance copy so that changes never leak into the class default.
        self._global_settings: dict[str, Any] = self.__class__._global_settings.copy()

    def register(self, name: str, value: PluginT | None) -> PluginT | None:
        """
        Register a plugin by name and value.

        This method is used for explicit registration of a plugin and shouldn't be
        used to manage entry point managed plugins, which are auto-loaded.

        Parameters
        ----------
        name: str
            The name of the plugin.
        value: PluginT or None
            The actual plugin object to register or None to unregister that plugin.

        Returns
        -------
        plugin: PluginT or None
            The plugin that was registered or unregistered. ``None`` when asked
            to unregister a name that was not registered.

        Raises
        ------
        TypeError
            If ``value`` is rejected by ``self.plugin_type``.
        """
        if value is None:
            return self._plugins.pop(name, None)
        if self.plugin_type(value):
            self._plugins[name] = value
            return value
        msg = f"{type(value).__name__!r} is not compatible with {type(self).__name__!r}"
        raise TypeError(msg)

    def names(self) -> list[str]:
        """List the names of the registered and entry points plugins."""
        entry_point_names = (
            ep.name for ep in importlib_metadata_get(self.entry_point_group)
        )
        # A set removes names that are both registered and entry points.
        return sorted({*self._plugins, *entry_point_names})

    def _get_state(self) -> dict[str, Any]:
        """Return a dictionary representing the current state of the registry."""
        return {
            "_active": self._active,
            "_active_name": self._active_name,
            "_plugins": self._plugins.copy(),
            "_options": self._options.copy(),
            "_global_settings": self._global_settings.copy(),
        }

    def _set_state(self, state: dict[str, Any]) -> None:
        """Reset the state of the registry from a ``_get_state`` snapshot."""
        # Internal invariant: the snapshot must have exactly these keys.
        assert set(state.keys()) == {
            "_active",
            "_active_name",
            "_plugins",
            "_options",
            "_global_settings",
        }
        for key, val in state.items():
            setattr(self, key, val)

    def _enable(self, name: str, **options) -> None:
        """
        Make ``name`` the active plugin, loading it from an entry point if needed.

        Raises
        ------
        ValueError
            If ``name`` is not registered, does not resolve to exactly one
            entry point, and has a message in ``entrypoint_err_messages``.
        NoSuchEntryPoint
            Same situation, but without a custom message for ``name``.
        """
        if name not in self._plugins:
            try:
                # Single-element unpacking raises ``ValueError`` unless exactly
                # one entry point carries this name.
                (ep,) = (
                    ep
                    for ep in importlib_metadata_get(self.entry_point_group)
                    if ep.name == name
                )
            except ValueError as err:
                if name in self.entrypoint_err_messages:
                    raise ValueError(self.entrypoint_err_messages[name]) from err
                raise NoSuchEntryPoint(self.entry_point_group, name) from err
            value = cast(PluginT, ep.load())
            self.register(name, value)
        self._active_name = name
        self._active = self._plugins[name]
        # ``&`` on key views builds a new set, so popping from ``options``
        # inside the loop is safe.
        for key in options.keys() & self._global_settings.keys():
            self._global_settings[key] = options.pop(key)
        self._options = options

    def enable(self, name: str | None = None, **options) -> PluginEnabler:
        """
        Enable a plugin by name.

        This can be either called directly, or used as a context manager.

        Parameters
        ----------
        name : string (optional)
            The name of the plugin to enable. If not specified, then use the
            current active name.
        **options :
            Any additional parameters will be passed to the plugin as keyword
            arguments

        Returns
        -------
        PluginEnabler:
            An object that allows enable() to be used as a context manager
        """
        if name is None:
            name = self.active
        return PluginEnabler(self, name, **options)

    @property
    def active(self) -> str:
        """Return the name of the currently active plugin."""
        return self._active_name

    @property
    def options(self) -> dict[str, Any]:
        """Return the current options dictionary."""
        return self._options

    def get(self) -> partial[R] | Plugin[R] | None:
        """
        Return the currently active plugin.

        Returns
        -------
        ``None`` when no plugin is active. Otherwise the active plugin itself,
        or a ``functools.partial`` binding the current options when there are any.

        Raises
        ------
        TypeError
            If the active plugin is rejected by ``self.plugin_type``.
        """
        func = self._active
        # Compare with ``None`` explicitly: a valid plugin may be falsy
        # (for example a callable object defining ``__len__`` or ``__bool__``)
        # and must not be mistaken for "no active plugin" or rejected.
        if func is None:
            return None
        if not self.plugin_type(func):
            msg = (
                f"{type(self).__name__!r} requires all plugins to be callable objects, "
                f"but {type(func).__name__!r} is not callable."
            )
            raise TypeError(msg)
        return partial(func, **self._options) if self._options else func

    def __repr__(self) -> str:
        """Show the active plugin name and all known plugin names."""
        return f"{type(self).__name__}(active={self.active!r}, registered={self.names()!r})"
|
|
|
|
|
def importlib_metadata_get(group):
    """
    Return the installed entry points that belong to ``group``.

    The object returned by ``entry_points()`` is probed for a ``select``
    method rather than checking the interpreter version. If ``select`` is
    missing, the mapping-style ``get`` lookup is used, with an empty list as
    the default.
    """
    discovered = entry_points()
    if not hasattr(discovered, "select"):
        # No selection API available: treat the result as a mapping of groups.
        return discovered.get(group, [])
    return discovered.select(group=group)
|
|