Spaces:
Building
Building
File size: 5,846 Bytes
f5f3483 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
# Copyright 2024 The etils Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cell auto-reload."""
from __future__ import annotations
import functools
import importlib
import inspect
import sys
from etils import epath
from etils.ecolab import adhoc_imports
from etils.ecolab import ip_utils
from etils.ecolab.adhoc_lib import reload_workspace_lib
from etils.epy.adhoc_utils import module_utils
def _create_module_graph(nodes: set[str]) -> dict[str, set[str]]:
    """Build the direct-dependency graph between the given modules.

    Module `a` is considered to depend on module `b` when a module object named
    `b` is bound to some name in `a`'s namespace (e.g. after `import b` or
    `from pkg import b`). Importing individual symbols (`from b import fn`)
    binds no module object, so such dependencies are not detected.

    Args:
      nodes: Names of the modules to inspect. Edges are only reported towards
        modules whose name is itself in `nodes`.

    Returns:
      Mapping from every name in `nodes` to the set of names in `nodes` that it
      references directly. Names without a module in `sys.modules` map to an
      empty set.
    """
    graph: dict[str, set[str]] = {}
    for source in nodes:
        module = sys.modules.get(source)
        if module is None:
            # Not (or no longer) imported: there is no namespace to inspect, but
            # keep the node so every requested name has an entry.
            graph[source] = set()
            continue
        # Iterate over a snapshot so that a namespace modified while it is
        # being scanned cannot abort the loop with a `RuntimeError`.
        namespace_values = tuple(module.__dict__.values())
        graph[source] = {
            val.__name__
            for val in namespace_values
            if inspect.ismodule(val) and val.__name__ in nodes
        }
    return graph
class _ModuleSearch:
    """Answers whether a module reaches any of a set of target modules.

    Reachability follows the edges of a dependency graph (module name -> names
    of the modules it references directly). Answers are memoized on the
    instance, so `targets` and `graph` should not be mutated while the instance
    is in use.
    """

    def __init__(self, targets: set[str], graph: dict[str, set[str]]):
        """Constructor.

        Args:
          targets: Names of the modules to look for.
          graph: Direct dependencies of each module. Modules missing from the
            mapping are treated as having no dependencies.
        """
        self._graph = graph
        # Module name -> whether it is known to reach one of the targets.
        self._cache: dict[str, bool] = {}
        self._targets = targets

    def _reaches_targets(self, source: str) -> bool:
        """Traverse the graph from `source` looking for a target.

        Args:
          source: Name of the module to start from.

        Returns:
          True if `source` is a target or references one, directly or
          indirectly.
        """
        # Only a yes/no answer is needed, so the visiting order is irrelevant:
        # a stack gives O(1) pops where `list.pop(0)` would shift the whole
        # list on every step.
        pending = [source]
        visited = {source}
        while pending:
            current = pending.pop()
            known = self._cache.get(current)
            if known:
                # `current` is known to reach a target -> so does `source`.
                return True
            if known is False:
                # Known dead end: no need to explore its neighbours again.
                continue
            if current in self._targets:
                return True
            for neighbour in self._graph.get(current, ()):
                if neighbour not in visited:
                    visited.add(neighbour)
                    pending.append(neighbour)
        # The traversal was exhausted without meeting a target. Everything
        # reachable from a visited node is also reachable from `source`, so
        # none of the visited nodes reaches a target either.
        for node in visited:
            self._cache[node] = False
        return False

    def reaches_targets(self, source: str) -> bool:
        """Check if a module is, or references (even indirectly), a target.

        The answer is stored so later queries going through `source` can stop
        early.

        Args:
          source: Name of the module to check.

        Returns:
          True if a target module is reachable from `source`.
        """
        ret = self._reaches_targets(source)
        self._cache[source] = ret
        return ret
class ModuleReloader:
    """Reloads the requested modules before a cell runs if their files changed.

    The instance remembers the last modification time observed for each
    tracked module. `register()` installs a callback for the `'pre_run_cell'`
    event through `ip_utils.register_once`. That callback re-imports every
    tracked module that changed, as well as the tracked modules referencing it.
    """

    def __init__(self, **adhoc_kwargs):
        """Constructor.

        Args:
          **adhoc_kwargs: Keyword arguments forwarded to `adhoc_imports.adhoc`
            when a reload happens. The `reload` and `verbose` entries are also
            read by this class.
        """
        self.adhoc_kwargs = adhoc_kwargs
        # Module name -> last modification time observed (`None` if unknown).
        self._last_updates: dict[str, int | None] = {}

    @functools.cached_property
    def reload(self) -> tuple[str, ...]:
        """Modules requested through `reload=`, normalized to a tuple.

        A missing or `None` value gives an empty tuple, so that `register()`
        can report the problem with its own error. A single string is kept as
        one entry instead of being split into characters.
        """
        requested = self.adhoc_kwargs.get('reload')
        if requested is None:
            return ()
        if isinstance(requested, str):
            return (requested,) if requested else ()
        return tuple(requested)

    @property
    def verbose(self) -> bool:
        """Value of the `verbose` adhoc option (`False` when not provided)."""
        return self.adhoc_kwargs.get('verbose', False)

    def register(self) -> None:
        """Record the current state of the modules and install the callback.

        Raises:
          ValueError: If no module to reload was provided.
        """
        if not self.reload:
            raise ValueError('`cell_autoreload=True` require to set `reload=`')

        # Keep a value for each module. If a file is updated, trigger a reload.
        for module in module_utils.get_module_names(self.reload):
            self._last_updates[module] = _get_last_module_update(module)

        # Currently, only a single auto-reload can be set at the time.
        # Probably a good idea as it's unclear how to differentiate between
        # registering 2 cell_autoreload and overwriting cell_autoreload params.
        ip_utils.register_once(
            'pre_run_cell',
            # Cannot use `self.method` because bound methods do not support
            # set attribute.
            functools.partial(type(self)._pre_run_cell_maybe_reload, self),
            'is_cell_auto_reload',
        )

    def _pre_run_cell_maybe_reload(
        self,
        *args,
    ) -> None:
        """Check if the workspace was modified and, if so, reload modules."""
        del args  # Future version of IPython will have a `info` arg
        # TODO(epot): This function could be unified with `reload_workspace`

        # Resolve the candidate modules once: the same names are used both to
        # detect changes and to build the dependency graph below.
        candidates = list(module_utils.get_module_names(self.reload))

        # If any of the modules has been updated, trigger a reload.
        # Find which modules are dirty.
        dirty_modules: set[str] = set()
        for module in candidates:
            prev_mtime = self._last_updates.get(module)
            new_mtime = _get_last_module_update(module)
            # NOTE(review): a module with no recorded time (never seen, or
            # whose time could not be read) is always treated as dirty, so a
            # module whose time stays unknown is dirty on every cell. Confirm
            # this is intended.
            if prev_mtime is None or (
                new_mtime is not None and new_mtime > prev_mtime
            ):
                dirty_modules.add(module)
                self._last_updates[module] = new_mtime
        if not dirty_modules:
            return

        # Get set of all modules we could potentially reload.
        reload_set = set(candidates)
        graph = _create_module_graph(reload_set)
        search = _ModuleSearch(dirty_modules, graph)
        # Narrow it down to modules that are dirty or reference a dirty module.
        # Sorted so that the reload order does not depend on set iteration.
        modules_to_reload = [
            mod for mod in sorted(reload_set) if search.reaches_targets(mod)
        ]

        # Only reload exactly the modules we know are dirty. reload_recursive
        # is an undocumented flag in adhoc for now.
        adhoc_kwargs = self.adhoc_kwargs | {
            'reload': modules_to_reload,
            'reload_recursive': False,
            'collapse_prefix': f'Autoreload ({len(modules_to_reload)} modules): ',
        }
        with adhoc_imports.adhoc(**adhoc_kwargs):
            for module in modules_to_reload:
                importlib.import_module(module)

        # Update globals in user namespace with reloaded modules
        reload_workspace_lib.update_global_namespace(
            reload=modules_to_reload,
            verbose=self.verbose,
        )
def _get_last_module_update(module_name: str) -> int | None:
    """Return the modification time of the file backing a loaded module.

    Args:
      module_name: Name under which the module is looked up in `sys.modules`.

    Returns:
      The `mtime` reported by `stat()` for the module's `__file__`, or `None`
      when the module is not loaded, is `__main__`, has no `__file__`, or
      `stat()` raises `OSError`.
    """
    loaded = sys.modules.get(module_name)
    if loaded is None or loaded.__name__ == '__main__':
        return None

    source_file = getattr(loaded, '__file__', None)
    if not source_file:
        return None

    source_path = epath.Path(source_file)
    try:
        stat_result = source_path.stat()
        return stat_result.mtime
    except OSError:
        return None
|