"""Streaming module API that should be implemented by all streaming components."""

from contextlib import contextmanager
import typing as tp

from torch import nn
import torch


State = tp.Dict[str, torch.Tensor]


class StreamingModule(nn.Module):
    """Common API for streaming components.

    Each streaming component has a streaming state, which is just a
    dict[str, Tensor]. By convention, the first dim of each tensor must be
    the batch size. Don't use dots in the key names, as this would clash
    with submodules (like in state_dict).

    If `self._is_streaming` is True, the component should use and remember
    the proper state inside `self._streaming_state`.

    To set a streaming component in streaming state, use

        with module.streaming():
            ...

    This will automatically reset the streaming state when exiting the
    context manager. This also automatically propagates to all streaming
    children modules.

    Some modules might also implement the `StreamingModule.flush` method,
    although this one is trickier, as all parent modules must be
    StreamingModule and implement it as well for it to work properly.
    See `StreamingSequential` after.
    """

    def __init__(self) -> None:
        super().__init__()
        self._streaming_state: State = {}
        self._is_streaming = False

    def _apply_named_streaming(
        self, fn: tp.Callable[[str, "StreamingModule"], tp.Any]
    ) -> None:
        """Call `fn(name, module)` on every `StreamingModule` in the tree.

        The traversal uses `named_modules()`, so it includes this module
        itself (under the empty name) and skips non-streaming modules.
        """
        for name, module in self.named_modules():
            if isinstance(module, StreamingModule):
                fn(name, module)

    def _set_streaming(self, streaming: bool) -> None:
        """Set the `_is_streaming` flag on this module and all streaming children."""

        def _set_flag(name: str, module: "StreamingModule") -> None:
            module._is_streaming = streaming

        self._apply_named_streaming(_set_flag)

    @contextmanager
    def streaming(self):
        """Context manager to enter streaming mode. Reset streaming state on exit."""
        self._set_streaming(True)
        try:
            yield
        finally:
            # Always leave streaming mode and drop the state, even on error.
            self._set_streaming(False)
            self.reset_streaming()

    def reset_streaming(self) -> None:
        """Reset the streaming state of this module and all streaming children."""

        def _reset(name: str, module: "StreamingModule") -> None:
            module._streaming_state.clear()

        self._apply_named_streaming(_reset)

    def get_streaming_state(self) -> State:
        """Return the streaming state, including that of sub-modules.

        Keys of sub-modules are prefixed with the sub-module name followed
        by a dot; keys of this module itself are left unprefixed.
        """
        state: State = {}

        def _add(name: str, module: "StreamingModule") -> None:
            prefix = name + "." if name else ""
            for key, value in module._streaming_state.items():
                state[prefix + key] = value

        self._apply_named_streaming(_add)
        return state

    def set_streaming_state(self, state: State) -> None:
        """Set the streaming state, including that of sub-modules.

        The state of every streaming module in the tree is cleared first,
        then refilled from the entries of `state` addressed to it. The
        given mapping is not modified.

        Raises:
            AssertionError: if some keys of `state` do not belong to any
                streaming module. The exception argument is the list of
                those keys. The module states have already been updated
                with the valid entries when this is raised.
        """
        # Bucket the entries by owner in a single pass. A full key is
        # "<module name>.<local key>" and local keys never contain a dot,
        # so everything up to and including the last dot is the owner prefix
        # (the empty string for this module's own entries).
        buckets: tp.Dict[str, State] = {}
        for key, value in state.items():
            owner, sep, local_key = key.rpartition(".")
            buckets.setdefault(owner + sep, {})[local_key] = value

        claimed: tp.Set[str] = set()

        def _set(name: str, module: "StreamingModule") -> None:
            prefix = name + "." if name else ""
            module._streaming_state.clear()
            module._streaming_state.update(buckets.get(prefix, {}))
            claimed.add(prefix)

        self._apply_named_streaming(_set)

        leftover = [
            key for key in state
            if "".join(key.rpartition(".")[:2]) not in claimed
        ]
        if leftover:
            # Raised explicitly (rather than with `assert`) so the check is
            # still performed when Python runs with -O.
            raise AssertionError(leftover)

    def flush(self, x: tp.Optional[torch.Tensor] = None):
        """Flush any remaining outputs that were waiting for completion.

        Typically, for convolutions, this will add the final padding
        and process the last buffer.

        This should take an optional argument `x`, which will be provided
        if a module before this one in the streaming pipeline has already
        emitted a flushed buffer.

        Returns:
            None when `x` is None, otherwise the result of calling this
            module on `x`.
        """
        if x is None:
            return None
        return self(x)