File size: 5,846 Bytes
f5f3483
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
# Copyright 2024 The etils Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Cell auto-reload."""

from __future__ import annotations

import functools
import importlib
import inspect
import sys

from etils import epath
from etils.ecolab import adhoc_imports
from etils.ecolab import ip_utils
from etils.ecolab.adhoc_lib import reload_workspace_lib
from etils.epy.adhoc_utils import module_utils


def _create_module_graph(nodes: set[str]) -> dict[str, set[str]]:
  """Maps each module of `nodes` to the modules of `nodes` it references.

  Module `A` is considered to reference module `B` when one of `A`'s global
  attributes is the module object `B` itself (e.g. after `import B`). Module
  objects whose name is not in `nodes` are ignored.

  Args:
    nodes: Names of the modules to include. Each name is looked up with
      `sys.modules[name]`, so a `KeyError` is raised for unloaded modules.

  Returns:
    `{module_name: names of the `nodes` modules it references}`, with one
    entry per name in `nodes`.
  """

  def _referenced_modules(name: str) -> set[str]:
    # Scan the module's global namespace for module objects that are part of
    # the graph.
    namespace = vars(sys.modules[name])
    return {
        value.__name__
        for value in namespace.values()
        if inspect.ismodule(value) and value.__name__ in nodes
    }

  return {name: _referenced_modules(name) for name in nodes}


class _ModuleSearch:
  """Answers "is this module a target, or does it reference one transitively?".

  `graph` maps a module name to the names of the modules it references (as
  built by `_create_module_graph`). Answers are memoized across queries, so
  repeatedly querying many modules of the same graph stays cheap.
  """

  def __init__(self, targets: set[str], graph: dict[str, set[str]]):
    self._graph = graph
    # Module name -> whether it reaches `targets`. Only exact answers are ever
    # stored, so a cached value stays valid for every later query.
    self._cache: dict[str, bool] = {}
    self._targets = targets

  def _reaches_targets(self, source: str) -> bool:
    """Returns whether `source` is, or transitively references, a target."""
    # Reachability does not depend on the visit order, so a LIFO stack is used
    # (`list.pop()` is O(1), unlike the `list.pop(0)` a FIFO queue needs).
    stack = [source]
    visited = {source}

    while stack:
      m = stack.pop()

      reaches = self._cache.get(m)
      if reaches is not None:
        if reaches:
          # `m` is known to reach a target -> so does `source`.
          return True
        # `m` is known not to reach any target: neither do its neighbours, no
        # need to explore them.
        continue

      if m in self._targets:
        return True

      for neighbour in self._graph.get(m, ()):
        if neighbour not in visited:
          visited.add(neighbour)
          stack.append(neighbour)

    # Everything reachable from `source` was explored (or pruned through a
    # cached negative answer) without finding a target, so none of the visited
    # modules reaches one either. Remember it so later queries skip them.
    for m in visited:
      self._cache[m] = False
    return False

  def reaches_targets(self, source: str) -> bool:
    """Returns whether `source` is, or transitively references, a target.

    Args:
      source: Module name. Names absent from the graph have no neighbours.

    Returns:
      True if `source` is a target or a target is reachable from it.
    """
    ret = self._reaches_targets(source)
    self._cache[source] = ret
    return ret


class ModuleReloader:
  """Module reloader."""

  def __init__(self, **adhoc_kwargs):
    self.adhoc_kwargs = adhoc_kwargs
    self._last_updates: dict[str, int | None] = {}

  @functools.cached_property
  def reload(self) -> tuple[str, ...]:
    return tuple(self.adhoc_kwargs['reload'])

  @property
  def verbose(self) -> bool:
    return self.adhoc_kwargs['verbose']

  def register(self) -> None:
    if not self.reload:
      raise ValueError('`cell_autoreload=True` require to set `reload=`')

    # Keep a value for each module. If a file is updated, trigger a reload.
    for module in module_utils.get_module_names(self.reload):
      self._last_updates[module] = _get_last_module_update(module)

    # Currently, only a single auto-reload can be set at the time.
    # Probably a good idea as it's unclear how to differentiate between
    # registering 2 cell_autoreload and overwriting cell_autoreload params.
    ip_utils.register_once(
        'pre_run_cell',
        # Cannot use `self.method` because bound methods do not support
        # set attribute.
        functools.partial(type(self)._pre_run_cell_maybe_reload, self),
        'is_cell_auto_reload',
    )

  def _pre_run_cell_maybe_reload(
      self,
      *args,
  ) -> None:
    """Check if workspace is modified, then eventually reload modules."""
    del args  # Future version of IPython will have a `info` arg

    # TODO(epot): This function could be unified with `reload_workspace`

    # If any of the modules has been updated, trigger a reload

    # Find which modules are dirty.
    dirty_modules: set[str] = set()
    for module in module_utils.get_module_names(self.reload):
      prev_mtime = self._last_updates.get(module)
      new_mtime = _get_last_module_update(module)
      if prev_mtime is None or (
          new_mtime is not None and new_mtime > prev_mtime
      ):
        dirty_modules.add(module)
      self._last_updates[module] = new_mtime

    if not dirty_modules:
      return

    # Get set of all modules we could potentially reload.
    reload_set = set(module_utils.get_module_names(self.reload))
    graph = _create_module_graph(reload_set)
    search = _ModuleSearch(dirty_modules, graph)

    # Narrow it down to modules that are dirty or reference a dirty module.
    modules_to_reload = [
        mod for mod in reload_set if search.reaches_targets(mod)
    ]

    # Only reload exactly the modules we know are dirty. reload_recursive
    # is an undocumented flag in adhoc for now.
    adhoc_kwargs = self.adhoc_kwargs | {
        'reload': modules_to_reload,
        'reload_recursive': False,
        'collapse_prefix': f'Autoreload ({len(modules_to_reload)} modules): ',
    }
    with adhoc_imports.adhoc(**adhoc_kwargs):
      for module in modules_to_reload:
        importlib.import_module(module)

      # Update globals in user namespace with reloaded modules
      reload_workspace_lib.update_global_namespace(
          reload=modules_to_reload,
          verbose=self.verbose,
      )


def _get_last_module_update(module_name: str) -> int | None:
  """Get the last update for one module."""
  module = sys.modules.get(module_name, None)
  if module is None:
    return None
  if module.__name__ == '__main__':
    return None

  module_file = getattr(module, '__file__', None)
  if not module_file:
    return None

  module_file = epath.Path(module_file)

  try:
    return module_file.stat().mtime
  except OSError:
    return None