PTWZ's picture
Upload folder using huggingface_hub
f5f3483 verified
# Copyright 2024 The etils Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cell auto-reload."""
from __future__ import annotations
import functools
import importlib
import inspect
import sys
from etils import epath
from etils.ecolab import adhoc_imports
from etils.ecolab import ip_utils
from etils.ecolab.adhoc_lib import reload_workspace_lib
from etils.epy.adhoc_utils import module_utils
def _create_module_graph(nodes: set[str]) -> dict[str, set[str]]:
graph = {}
for source in nodes:
deps = set()
for val in sys.modules[source].__dict__.values():
if inspect.ismodule(val) and val.__name__ in nodes:
deps.add(val.__name__)
graph[source] = deps
return graph
class _ModuleSearch:
"""Graph of module dependencies that can be queried."""
def __init__(self, targets: set[str], graph: dict[str, set[str]]):
self._graph = graph
self._cache: dict[str, bool] = {}
self._targets = targets
def _reaches_targets(self, source: str) -> bool:
"""Check if a module references other modules directly or indirectly."""
queue = [source]
visited = set(queue)
while queue:
m = queue.pop(0)
if (reaches := self._cache.get(m)) is not None:
if reaches:
# If m is known to reach target -> source reaches targets
return True
else:
# Otherwise, no need to search the neighbours of this node either.
continue
if m in self._targets:
return True
for neighbour in self._graph.get(m, set()):
if neighbour not in visited:
visited.add(neighbour)
queue.append(neighbour)
return False
def reaches_targets(self, source: str) -> bool:
"""Check if a module references other modules directly or indirectly."""
ret = self._reaches_targets(source)
self._cache[source] = ret
return ret
class ModuleReloader:
"""Module reloader."""
def __init__(self, **adhoc_kwargs):
self.adhoc_kwargs = adhoc_kwargs
self._last_updates: dict[str, int | None] = {}
@functools.cached_property
def reload(self) -> tuple[str, ...]:
return tuple(self.adhoc_kwargs['reload'])
@property
def verbose(self) -> bool:
return self.adhoc_kwargs['verbose']
def register(self) -> None:
if not self.reload:
raise ValueError('`cell_autoreload=True` require to set `reload=`')
# Keep a value for each module. If a file is updated, trigger a reload.
for module in module_utils.get_module_names(self.reload):
self._last_updates[module] = _get_last_module_update(module)
# Currently, only a single auto-reload can be set at the time.
# Probably a good idea as it's unclear how to differentiate between
# registering 2 cell_autoreload and overwriting cell_autoreload params.
ip_utils.register_once(
'pre_run_cell',
# Cannot use `self.method` because bound methods do not support
# set attribute.
functools.partial(type(self)._pre_run_cell_maybe_reload, self),
'is_cell_auto_reload',
)
def _pre_run_cell_maybe_reload(
self,
*args,
) -> None:
"""Check if workspace is modified, then eventually reload modules."""
del args # Future version of IPython will have a `info` arg
# TODO(epot): This function could be unified with `reload_workspace`
# If any of the modules has been updated, trigger a reload
# Find which modules are dirty.
dirty_modules: set[str] = set()
for module in module_utils.get_module_names(self.reload):
prev_mtime = self._last_updates.get(module)
new_mtime = _get_last_module_update(module)
if prev_mtime is None or (
new_mtime is not None and new_mtime > prev_mtime
):
dirty_modules.add(module)
self._last_updates[module] = new_mtime
if not dirty_modules:
return
# Get set of all modules we could potentially reload.
reload_set = set(module_utils.get_module_names(self.reload))
graph = _create_module_graph(reload_set)
search = _ModuleSearch(dirty_modules, graph)
# Narrow it down to modules that are dirty or reference a dirty module.
modules_to_reload = [
mod for mod in reload_set if search.reaches_targets(mod)
]
# Only reload exactly the modules we know are dirty. reload_recursive
# is an undocumented flag in adhoc for now.
adhoc_kwargs = self.adhoc_kwargs | {
'reload': modules_to_reload,
'reload_recursive': False,
'collapse_prefix': f'Autoreload ({len(modules_to_reload)} modules): ',
}
with adhoc_imports.adhoc(**adhoc_kwargs):
for module in modules_to_reload:
importlib.import_module(module)
# Update globals in user namespace with reloaded modules
reload_workspace_lib.update_global_namespace(
reload=modules_to_reload,
verbose=self.verbose,
)
def _get_last_module_update(module_name: str) -> int | None:
"""Get the last update for one module."""
module = sys.modules.get(module_name, None)
if module is None:
return None
if module.__name__ == '__main__':
return None
module_file = getattr(module, '__file__', None)
if not module_file:
return None
module_file = epath.Path(module_file)
try:
return module_file.stat().mtime
except OSError:
return None