|
"""The extension manager.""" |
|
|
|
from __future__ import annotations |
|
|
|
import importlib |
|
from itertools import starmap |
|
|
|
from tornado.gen import multi |
|
from traitlets import Any, Bool, Dict, HasTraits, Instance, List, Unicode, default, observe |
|
from traitlets import validate as validate_trait |
|
from traitlets.config import LoggingConfigurable |
|
|
|
from .config import ExtensionConfigManager |
|
from .utils import ExtensionMetadataError, ExtensionModuleNotFound, get_loader, get_metadata |
|
|
|
|
|
class ExtensionPoint(HasTraits):
    """A single Jupyter Server extension point, described by a metadata dict.

    The metadata must contain a ``"module"`` key naming an importable Python
    module. It may also contain an ``"app"`` key holding a callable; when
    present, that callable is invoked during metadata validation and the
    resulting object becomes :attr:`app`, which then takes precedence over the
    module for naming, configuration, linking and loading.
    """

    # This point's own "already linked" flag; consulted only when no app is attached.
    _linked = Bool(False)
    # Object produced by calling metadata["app"]; None when the metadata has no app.
    _app = Any(None, allow_none=True)

    metadata = Dict()

    @validate_trait("metadata")
    def _valid_metadata(self, proposed):
        """Validate a proposed ``metadata`` value and import its module.

        Side effects: sets ``_module_name`` and ``_module``; if the metadata
        has an ``"app"`` entry, calls it and stores the result in ``_app``.

        Raises
        ------
        ExtensionMetadataError
            If the metadata has no ``"module"`` key.
        ExtensionModuleNotFound
            If the named module raises ImportError when imported.
        """
        candidate = proposed["value"]

        try:
            module_name = candidate["module"]
        except KeyError:
            message = "There is no 'module' key in the extension's metadata packet."
            raise ExtensionMetadataError(message) from None
        self._module_name = module_name

        try:
            self._module = importlib.import_module(module_name)
        except ImportError:
            message = (
                f"The submodule '{module_name}' could not be found. Are you "
                "sure the extension is installed?"
            )
            raise ExtensionModuleNotFound(message) from None

        if "app" in candidate:
            app_factory = candidate["app"]
            self._app = app_factory()

        return candidate

    @property
    def linked(self):
        """Whether this extension point has been linked to the server.

        When an app is attached, the app's own ``_linked`` flag is reported
        instead of this point's flag.
        """
        return self.app._linked if self.app else self._linked

    @property
    def app(self):
        """The object built from the metadata's ``"app"`` entry, or None."""
        return self._app

    @property
    def config(self):
        """Configuration contributed by this extension point.

        Delegates to the app's ``_jupyter_server_config()`` when an app is
        attached; otherwise an empty dict.
        """
        return self.app._jupyter_server_config() if self.app else {}

    @property
    def module_name(self):
        """Name of the module given by the metadata's ``"module"`` key.

        When no app is attached, this module is where the link/load hooks
        are looked up.
        """
        return self._module_name

    @property
    def name(self):
        """Name of the extension point.

        Uses the app's ``name`` when an app is attached; otherwise the
        metadata's ``"name"`` entry, falling back to :attr:`module_name`.
        """
        return self.app.name if self.app else self.metadata.get("name", self.module_name)

    @property
    def module(self):
        """The module object imported via ``importlib.import_module``."""
        return self._module

    def _get_linker(self):
        """Return the callable used to link this point to a server app.

        Prefers the app's ``_link_jupyter_server_extension``. Without an app,
        the same attribute is looked up on the module, and a do-nothing
        callable is used if the module does not define it.
        """
        if not self.app:
            return getattr(
                self.module,
                "_link_jupyter_server_extension",
                lambda serverapp: None,
            )
        return self.app._link_jupyter_server_extension

    def _get_loader(self):
        """Return the loader callable produced by ``get_loader``.

        The lookup target is the app when one is attached, else the module.
        """
        return get_loader(self.app or self.module)

    def validate(self):
        """Report whether both a linker and a loader can be resolved.

        Any exception raised while resolving either one yields False.
        """
        try:
            self._get_linker()
            self._get_loader()
        except Exception:
            return False
        return True

    def link(self, serverapp):
        """Link the extension to a Jupyter ServerApp object.

        Calls the resolved ``_link_jupyter_server_extension`` hook (from the
        ExtensionApp or the module) unless :attr:`linked` is already true.
        NOTE(review): when an app is attached, :attr:`linked` reads the app's
        flag, so presumably the app's linker is what sets it; not visible here.
        """
        if self.linked:
            return
        self._get_linker()(serverapp)
        self._linked = True

    def load(self, serverapp):
        """Load the extension into a Jupyter ServerApp object.

        Resolves the ``_load_jupyter_server_extension`` hook (from the
        ExtensionApp or the module) and returns whatever it returns.
        """
        return self._get_loader()(serverapp)
|
|
|
|
|
class ExtensionPackage(LoggingConfigurable):
    """An API for interfacing with a Jupyter Server extension package.

    When constructed with ``enabled=True`` the package is imported and one
    :class:`ExtensionPoint` is created per metadata entry; disabled packages
    are not imported.

    Usage:

        ext_name = "my_extensions"
        extpkg = ExtensionPackage(name=ext_name)
    """

    name = Unicode(help="Name of the an importable Python package.")
    enabled = Bool(False, help="Whether the extension package is enabled.")

    # Point names that link_point() has already linked successfully.
    _linked_points = Dict()
    # Maps extension point name -> ExtensionPoint (filled by _load_metadata).
    extension_points = Dict()
    module = Any(allow_none=True, help="The module for this extension package. None if not enabled")
    metadata = List(Dict(), help="Extension metadata loaded from the extension package.")
    version = Unicode(
        help="""
        The version of this extension package, if it can be found.
        Otherwise, an empty string.
        """,
    )

    @default("version")
    def _load_version(self):
        """Default ``version``: the module's ``__version__`` when enabled, else ""."""
        if not self.enabled:
            return ""
        return getattr(self.module, "__version__", "")

    def __init__(self, **kwargs):
        """Initialize an extension package, loading its metadata if enabled."""
        super().__init__(**kwargs)
        if self.enabled:
            self._load_metadata()

    def _load_metadata(self):
        """Import the package and build its extension points.

        Only used if the extension package is enabled.

        Returns
        -------
        str
            The package name.

        Raises
        ------
        ExtensionModuleNotFound
            If ``get_metadata`` raises ImportError for this package.
        """
        name = self.name
        try:
            self.module, self.metadata = get_metadata(name, logger=self.log)
        except ImportError as e:
            msg = (
                f"The module '{name}' could not be found ({e}). Are you "
                "sure the extension is installed?"
            )
            raise ExtensionModuleNotFound(msg) from None

        for m in self.metadata:
            point = ExtensionPoint(metadata=m)
            self.extension_points[point.name] = point
        return name

    def validate(self):
        """Return True only if every extension point in this package validates."""
        return all(extension.validate() for extension in self.extension_points.values())

    def link_point(self, point_name, serverapp):
        """Link one extension point to *serverapp*, at most once.

        Raises KeyError if *point_name* is not a known extension point.
        Exceptions from the point's linker propagate and leave the point
        unrecorded, so a later call will retry it.
        """
        if self._linked_points.get(point_name, False):
            return
        point = self.extension_points[point_name]
        point.link(serverapp)
        # Record success; previously this dict was read but never written,
        # so the guard above could never skip anything.
        self._linked_points[point_name] = True

    def load_point(self, point_name, serverapp):
        """Load an extension point and return its loader's result."""
        point = self.extension_points[point_name]
        return point.load(serverapp)

    def link_all_points(self, serverapp):
        """Link every extension point in this package."""
        for point_name in self.extension_points:
            self.link_point(point_name, serverapp)

    def load_all_points(self, serverapp):
        """Load every extension point; return the loader results in point order."""
        return [self.load_point(point_name, serverapp) for point_name in self.extension_points]
|
|
|
|
|
class ExtensionManager(LoggingConfigurable):
    """High level interface for finding, validating,
    linking, loading, and managing Jupyter Server extensions.

    Usage:
        m = ExtensionManager(config_manager=...)
    """

    config_manager = Instance(ExtensionConfigManager, allow_none=True)

    # Server application that extensions are linked/loaded into; its
    # ``reraise_server_extension_failures`` flag controls error handling below.
    serverapp = Any()

    @default("config_manager")
    def _load_default_config_manager(self):
        """Create a default ExtensionConfigManager and register its extensions."""
        config_manager = ExtensionConfigManager()
        self._load_config_manager(config_manager)
        return config_manager

    @observe("config_manager")
    def _config_manager_changed(self, change):
        """Register extensions from a newly assigned, non-empty config manager."""
        if change.new:
            self._load_config_manager(change.new)

    extensions = Dict(
        help="""
        Dictionary with extension package names as keys
        and ExtensionPackage objects as values.
        """
    )

    @property
    def sorted_extensions(self):
        """Returns an extensions dictionary, sorted alphabetically."""
        return dict(sorted(self.extensions.items()))

    linked_extensions = Dict(
        help="""
        Dictionary with extension names as keys

        values are True if the extension is linked, False if not.
        """
    )

    @property
    def extension_apps(self):
        """Return mapping of extension names and sets of ExtensionApp objects."""
        return {
            name: {point.app for point in extension.extension_points.values() if point.app}
            for name, extension in self.extensions.items()
        }

    @property
    def extension_points(self):
        """Return mapping of extension point names and ExtensionPoint objects.

        If two packages define a point with the same name, the one from the
        package iterated later wins.
        """
        return {
            name: point
            for value in self.extensions.values()
            for name, point in value.extension_points.items()
        }

    def from_config_manager(self, config_manager):
        """Add extensions found by an ExtensionConfigManager.

        Assigning the trait triggers ``_config_manager_changed``, which does
        the actual registration.
        """
        self.config_manager = config_manager

    def _load_config_manager(self, config_manager):
        """Register every extension listed by *config_manager*."""
        jpserver_extensions = config_manager.get_jpserver_extensions()
        self.from_jpserver_extensions(jpserver_extensions)

    def from_jpserver_extensions(self, jpserver_extensions):
        """Add extensions from a 'jpserver_extensions'-like {name: enabled} dictionary."""
        for name, enabled in jpserver_extensions.items():
            self.add_extension(name, enabled=enabled)

    def add_extension(self, extension_name, enabled=False):
        """Try to add extension to manager, return True if successful.

        Otherwise, return False, unless the serverapp requests that extension
        failures be re-raised, in which case the exception propagates.
        """
        try:
            extpkg = ExtensionPackage(name=extension_name, enabled=enabled)
            self.extensions[extension_name] = extpkg
            return True
        except Exception as e:
            if self.serverapp and self.serverapp.reraise_server_extension_failures:
                raise
            self.log.warning(
                "%s | error adding extension (enabled: %s): %s",
                extension_name,
                enabled,
                e,
                exc_info=True,
            )
            return False

    def link_extension(self, name):
        """Link an enabled extension by name, at most once.

        Raises KeyError if *name* was never added. Link failures are logged,
        or re-raised when the serverapp requests it.
        """
        linked = self.linked_extensions.get(name, False)
        extension = self.extensions[name]
        if not linked and extension.enabled:
            try:
                extension.link_all_points(self.serverapp)
                self.linked_extensions[name] = True
                self.log.info("%s | extension was successfully linked.", name)
            except Exception as e:
                if self.serverapp and self.serverapp.reraise_server_extension_failures:
                    raise
                self.log.warning("%s | error linking extension: %s", name, e, exc_info=True)

    def load_extension(self, name):
        """Load an enabled extension by name; unknown names are ignored.

        Load failures are logged, or re-raised when the serverapp requests it.
        """
        extension = self.extensions.get(name)

        if extension and extension.enabled:
            try:
                extension.load_all_points(self.serverapp)
            except Exception as e:
                if self.serverapp and self.serverapp.reraise_server_extension_failures:
                    raise
                self.log.warning(
                    "%s | extension failed loading with message: %r", name, e, exc_info=True
                )
            else:
                self.log.info("%s | extension was successfully loaded.", name)

    async def stop_extension(self, name, apps):
        """Call the shutdown hooks in the specified apps, one after another."""
        for app in apps:
            self.log.debug("%s | extension app %r stopping", name, app.name)
            await app.stop_extension()
            self.log.debug("%s | extension app %r stopped", name, app.name)

    def link_all_extensions(self):
        """Link all enabled extensions
        to an instance of ServerApp, in alphabetical order.
        """
        for name in self.sorted_extensions:
            self.link_extension(name)

    def load_all_extensions(self):
        """Load all enabled extensions and append them to
        the parent ServerApp, in alphabetical order.
        """
        for name in self.sorted_extensions:
            self.load_extension(name)

    async def stop_all_extensions(self):
        """Call the shutdown hooks in all extensions, awaiting them together."""
        # extension_apps already builds a fresh dict, so no defensive copy is needed.
        await multi(list(starmap(self.stop_extension, sorted(self.extension_apps.items()))))

    def any_activity(self):
        """Check for any activity currently happening across all extension applications.

        Apps are polled in extension-name order and polling stops at the first
        app whose ``current_activity()`` is truthy.

        Returns
        -------
        bool
            True if any app reports activity, otherwise False (previously the
            method fell through and returned None).
        """
        return any(
            app.current_activity()
            for _, apps in sorted(self.extension_apps.items())
            for app in apps
        )
|
|