Spaces:
Building
Building
# Copyright 2024 The etils Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Cell auto-reload.""" | |
from __future__ import annotations | |
import functools | |
import importlib | |
import inspect | |
import sys | |
from etils import epath | |
from etils.ecolab import adhoc_imports | |
from etils.ecolab import ip_utils | |
from etils.ecolab.adhoc_lib import reload_workspace_lib | |
from etils.epy.adhoc_utils import module_utils | |
def _create_module_graph(nodes: set[str]) -> dict[str, set[str]]: | |
graph = {} | |
for source in nodes: | |
deps = set() | |
for val in sys.modules[source].__dict__.values(): | |
if inspect.ismodule(val) and val.__name__ in nodes: | |
deps.add(val.__name__) | |
graph[source] = deps | |
return graph | |
class _ModuleSearch: | |
"""Graph of module dependencies that can be queried.""" | |
def __init__(self, targets: set[str], graph: dict[str, set[str]]): | |
self._graph = graph | |
self._cache: dict[str, bool] = {} | |
self._targets = targets | |
def _reaches_targets(self, source: str) -> bool: | |
"""Check if a module references other modules directly or indirectly.""" | |
queue = [source] | |
visited = set(queue) | |
while queue: | |
m = queue.pop(0) | |
if (reaches := self._cache.get(m)) is not None: | |
if reaches: | |
# If m is known to reach target -> source reaches targets | |
return True | |
else: | |
# Otherwise, no need to search the neighbours of this node either. | |
continue | |
if m in self._targets: | |
return True | |
for neighbour in self._graph.get(m, set()): | |
if neighbour not in visited: | |
visited.add(neighbour) | |
queue.append(neighbour) | |
return False | |
def reaches_targets(self, source: str) -> bool: | |
"""Check if a module references other modules directly or indirectly.""" | |
ret = self._reaches_targets(source) | |
self._cache[source] = ret | |
return ret | |
class ModuleReloader: | |
"""Module reloader.""" | |
def __init__(self, **adhoc_kwargs): | |
self.adhoc_kwargs = adhoc_kwargs | |
self._last_updates: dict[str, int | None] = {} | |
def reload(self) -> tuple[str, ...]: | |
return tuple(self.adhoc_kwargs['reload']) | |
def verbose(self) -> bool: | |
return self.adhoc_kwargs['verbose'] | |
def register(self) -> None: | |
if not self.reload: | |
raise ValueError('`cell_autoreload=True` require to set `reload=`') | |
# Keep a value for each module. If a file is updated, trigger a reload. | |
for module in module_utils.get_module_names(self.reload): | |
self._last_updates[module] = _get_last_module_update(module) | |
# Currently, only a single auto-reload can be set at the time. | |
# Probably a good idea as it's unclear how to differentiate between | |
# registering 2 cell_autoreload and overwriting cell_autoreload params. | |
ip_utils.register_once( | |
'pre_run_cell', | |
# Cannot use `self.method` because bound methods do not support | |
# set attribute. | |
functools.partial(type(self)._pre_run_cell_maybe_reload, self), | |
'is_cell_auto_reload', | |
) | |
def _pre_run_cell_maybe_reload( | |
self, | |
*args, | |
) -> None: | |
"""Check if workspace is modified, then eventually reload modules.""" | |
del args # Future version of IPython will have a `info` arg | |
# TODO(epot): This function could be unified with `reload_workspace` | |
# If any of the modules has been updated, trigger a reload | |
# Find which modules are dirty. | |
dirty_modules: set[str] = set() | |
for module in module_utils.get_module_names(self.reload): | |
prev_mtime = self._last_updates.get(module) | |
new_mtime = _get_last_module_update(module) | |
if prev_mtime is None or ( | |
new_mtime is not None and new_mtime > prev_mtime | |
): | |
dirty_modules.add(module) | |
self._last_updates[module] = new_mtime | |
if not dirty_modules: | |
return | |
# Get set of all modules we could potentially reload. | |
reload_set = set(module_utils.get_module_names(self.reload)) | |
graph = _create_module_graph(reload_set) | |
search = _ModuleSearch(dirty_modules, graph) | |
# Narrow it down to modules that are dirty or reference a dirty module. | |
modules_to_reload = [ | |
mod for mod in reload_set if search.reaches_targets(mod) | |
] | |
# Only reload exactly the modules we know are dirty. reload_recursive | |
# is an undocumented flag in adhoc for now. | |
adhoc_kwargs = self.adhoc_kwargs | { | |
'reload': modules_to_reload, | |
'reload_recursive': False, | |
'collapse_prefix': f'Autoreload ({len(modules_to_reload)} modules): ', | |
} | |
with adhoc_imports.adhoc(**adhoc_kwargs): | |
for module in modules_to_reload: | |
importlib.import_module(module) | |
# Update globals in user namespace with reloaded modules | |
reload_workspace_lib.update_global_namespace( | |
reload=modules_to_reload, | |
verbose=self.verbose, | |
) | |
def _get_last_module_update(module_name: str) -> int | None: | |
"""Get the last update for one module.""" | |
module = sys.modules.get(module_name, None) | |
if module is None: | |
return None | |
if module.__name__ == '__main__': | |
return None | |
module_file = getattr(module, '__file__', None) | |
if not module_file: | |
return None | |
module_file = epath.Path(module_file) | |
try: | |
return module_file.stat().mtime | |
except OSError: | |
return None | |