Spaces:
Running
on
L40S
Running
on
L40S
File size: 3,909 Bytes
258fd02 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
"""
Streaming module API that should be implemented by all Streaming components.
"""
from contextlib import contextmanager
import typing as tp
from torch import nn
import torch
# Streaming state of a single module: state name -> tensor. By convention
# (see StreamingModule), the first dim of each tensor is the batch size.
State = tp.Dict[str, torch.Tensor]


class StreamingModule(nn.Module):
    """Common API for streaming components.

    Each streaming component has a streaming state, which is just a
    ``dict[str, Tensor]``. By convention, the first dim of each tensor must be
    the batch size. Don't use dots in the key names, as this would clash with
    submodules (like in ``state_dict``).

    If ``self._is_streaming`` is True, the component should use and remember
    the proper state inside ``self._streaming_state``.

    To put a streaming component in streaming mode, use::

        with module.streaming():
            ...

    This automatically resets the streaming state when exiting the context
    manager, and propagates to all streaming child modules.

    Some modules might also implement the ``StreamingModule.flush`` method,
    although this is trickier, as all parent modules must be StreamingModule
    and implement it as well for it to work properly. See
    ``StreamingSequential``.
    """

    def __init__(self) -> None:
        super().__init__()
        self._streaming_state: State = {}
        self._is_streaming = False

    def _apply_named_streaming(self, fn: tp.Callable[[str, "StreamingModule"], None]):
        """Call ``fn(name, module)`` on every StreamingModule in this module tree.

        Traversal follows ``self.named_modules()``, so ``self`` comes first with
        ``name == ""`` and children get their dotted path relative to ``self``.
        Submodules that are not StreamingModule are skipped.
        """
        for name, module in self.named_modules():
            if isinstance(module, StreamingModule):
                fn(name, module)

    def _set_streaming(self, streaming: bool):
        """Set ``_is_streaming`` on this module and all StreamingModule children."""
        def _set(name: str, module: "StreamingModule"):
            module._is_streaming = streaming
        self._apply_named_streaming(_set)

    @contextmanager
    def streaming(self):
        """Context manager to enter streaming mode. Reset streaming state on exit.

        Streaming mode is switched off and all streaming states in the tree are
        cleared even if the body raises.
        """
        self._set_streaming(True)
        try:
            yield
        finally:
            self._set_streaming(False)
            self.reset_streaming()

    def reset_streaming(self):
        """Clear the streaming state of this module and all StreamingModule children."""
        def _reset(name: str, module: "StreamingModule"):
            module._streaming_state.clear()
        self._apply_named_streaming(_reset)

    def get_streaming_state(self) -> State:
        """Return the streaming state, including that of sub-modules.

        Keys of a child's state are prefixed with the child's dotted module path
        plus a dot (e.g. ``"a.b.key"``); keys of ``self`` are left unprefixed.
        The returned dict is new, but the tensors are the stored objects (no copy).
        """
        state: State = {}

        def _add(name: str, module: "StreamingModule"):
            prefix = f"{name}." if name else ""
            for key, value in module._streaming_state.items():
                state[prefix + key] = value

        self._apply_named_streaming(_add)
        return state

    def set_streaming_state(self, state: State):
        """Set the streaming state, including that of sub-modules.

        ``state`` uses the layout produced by ``get_streaming_state``: a key is
        routed to the module whose dotted path is everything before the key's
        last dot (``self`` for keys without any dot), and stored there under the
        part after that dot. Every StreamingModule's previous state is cleared
        first. The input mapping is not modified.

        Raises:
            AssertionError: if some keys match no StreamingModule in this tree;
                the exception argument is the list of unmatched keys. Matching
                modules have already been updated when this is raised.
        """
        # Group keys once by their prefix up to and including the last dot, so
        # each module fetches its own entries directly instead of every module
        # scanning all keys (previously O(num_modules * num_keys)).
        grouped: tp.Dict[str, State] = {}
        for key, value in state.items():
            prefix = key[:key.rfind(".") + 1]  # "" when the key has no dot
            grouped.setdefault(prefix, {})[key[len(prefix):]] = value

        def _set(name: str, module: "StreamingModule"):
            module._streaming_state.clear()
            module._streaming_state.update(grouped.pop(f"{name}." if name else "", {}))

        self._apply_named_streaming(_set)

        # Whatever is still grouped matched no module. Raise explicitly (instead
        # of `assert`) so the check is not stripped under `python -O`; the
        # exception type and payload are unchanged for existing callers.
        remaining = [key for key in state if key[:key.rfind(".") + 1] in grouped]
        if remaining:
            raise AssertionError(remaining)

    def flush(self, x: tp.Optional[torch.Tensor] = None):
        """Flush any remaining outputs that were waiting for completion.

        Subclasses may override this, typically convolutions, which add the
        final padding and process their last buffer. ``x`` is provided when a
        module before this one in the streaming pipeline has already emitted a
        flushed buffer. This default implementation buffers nothing: it returns
        ``None`` when ``x`` is None, and ``self(x)`` otherwise.
        """
        if x is None:
            return None
        return self(x)