File size: 644 Bytes
d1ceb73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
# mypy: allow-untyped-defs
import contextlib
from typing import Callable, List, TYPE_CHECKING
if TYPE_CHECKING:
    # Only needed for the "torch.Tensor" annotation below; TYPE_CHECKING is
    # False at runtime, so torch is not imported by this module.
    import torch

# Registry of intermediate hooks, each called as ``hook(name, val)``.
# Executed in the order they're registered.
INTERMEDIATE_HOOKS: List[Callable[[str, "torch.Tensor"], None]] = []
@contextlib.contextmanager
def intermediate_hook(fn):
    """Register ``fn`` as an intermediate hook for the duration of a ``with`` block.

    ``fn`` is appended to ``INTERMEDIATE_HOOKS`` on entry, so it runs after any
    hooks that are already registered. It is removed again on exit, including
    when the body raises.

    Args:
        fn: Callable invoked as ``fn(name, val)`` by ``run_intermediate_hooks``.
    """
    # Bind the list object now. ``run_intermediate_hooks`` temporarily rebinds
    # the module global, so looking the name up again at exit could reach a
    # different list than the one ``fn`` was added to.
    hooks = INTERMEDIATE_HOOKS
    hooks.append(fn)
    try:
        yield
    finally:
        # Remove this exact registration (last occurrence, matched by identity)
        # instead of blindly popping the tail. That way an out-of-order exit
        # cannot unregister some other hook. This is a no-op if the entry was
        # already removed elsewhere.
        for idx in reversed(range(len(hooks))):
            if hooks[idx] is fn:
                del hooks[idx]
                break
def run_intermediate_hooks(name, val):
    """Invoke every registered intermediate hook as ``hook(name, val)``.

    Hooks run in registration order. While they run, the module-level
    registry is temporarily replaced by an empty list. As a result, a nested
    call to this function made from inside a hook invokes nothing, and any
    hook appended to the registry during the run goes into that throwaway
    list. The original registry is reinstated afterwards, even if a hook
    raises.
    """
    global INTERMEDIATE_HOOKS
    saved, INTERMEDIATE_HOOKS = INTERMEDIATE_HOOKS, []
    try:
        for callback in saved:
            callback(name, val)
    finally:
        # Put the real registry back, dropping the temporary list.
        INTERMEDIATE_HOOKS = saved
|