from __future__ import annotations

import collections
from collections import OrderedDict

import torch
from torch.nn import Module
|
|
class BufferDict(Module):
    r"""
    Holds buffers in a dictionary.

    BufferDict can be indexed like a regular Python dictionary, but the buffers it contains are properly registered
    and will be visible to all Module methods. `BufferDict` is an **ordered** dictionary that respects

    * the order of insertion, and
    * in `BufferDict.update`, the order of the merged `OrderedDict` or another `BufferDict` (the argument to
      `BufferDict.update`).

    Note that `BufferDict.update` with a plain `dict` (or any mapping other than an `OrderedDict` or a `BufferDict`)
    does not preserve the order of the merged mapping; its keys are inserted in sorted order instead.

    Args:
        buffers (iterable, *optional*):
            a mapping (dictionary) of (string : `torch.Tensor`) or an iterable of key-value pairs of type (string,
            `torch.Tensor`)

    ```python
    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.buffer_dict = BufferDict({"left": torch.randn(5, 10), "right": torch.randn(5, 10)})

        def forward(self, x, choice):
            x = self.buffer_dict[choice].mm(x)
            return x
    ```
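
    Because each value is registered via `register_buffer`, the stored tensors behave like any other buffer; as a
    small sketch (the key names and shapes are illustrative), they are listed by `named_buffers` and follow the
    module across dtypes and devices:

    ```python
    bd = BufferDict({"left": torch.randn(5, 10), "right": torch.randn(5, 10)})
    sorted(name for name, _ in bd.named_buffers())  # ["left", "right"]
    bd = bd.float()  # buffers are converted along with the module
    ```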
|
""" |
|
|
|
    def __init__(self, buffers=None, persistent: bool = False):
        r"""
        Args:
            buffers (`dict`, *optional*):
                A mapping (dictionary) from string to `torch.Tensor`, or an iterable of key-value pairs of type
                (string, `torch.Tensor`).
            persistent (`bool`, *optional*, defaults to `False`):
                Whether the registered buffers should be included in the module's `state_dict`.
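
        The flag mirrors the `persistent` argument of `register_buffer`; as a minimal sketch (the key name and
        tensor value are illustrative), only persistent buffers appear in the module's `state_dict`:

        ```python
        bd = BufferDict({"scale": torch.ones(3)}, persistent=True)
        assert "scale" in bd.state_dict()

        bd = BufferDict({"scale": torch.ones(3)}, persistent=False)
        assert "scale" not in bd.state_dict()
        ```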
|
""" |
|
        super().__init__()
        # Set the persistence flag before calling `update`, because `__setitem__` (used by `update`)
        # reads `self.persistent` when registering each buffer.
        self.persistent = persistent
        if buffers is not None:
            self.update(buffers)
|
    def __getitem__(self, key):
        return self._buffers[key]

    def __setitem__(self, key, buffer):
        self.register_buffer(key, buffer, persistent=self.persistent)

    def __delitem__(self, key):
        del self._buffers[key]

    def __len__(self):
        return len(self._buffers)

    def __iter__(self):
        return iter(self._buffers.keys())

    def __contains__(self, key):
        return key in self._buffers

    def clear(self):
        """Remove all items from the BufferDict."""
        self._buffers.clear()

    def pop(self, key):
        r"""Remove key from the BufferDict and return its buffer.

        Args:
            key (`str`):
                Key to pop from the BufferDict
        """
        v = self[key]
        del self[key]
        return v

    def keys(self):
        r"""Return an iterable of the BufferDict keys."""
        return self._buffers.keys()

    def items(self):
        r"""Return an iterable of the BufferDict key/value pairs."""
        return self._buffers.items()

    def values(self):
        r"""Return an iterable of the BufferDict values."""
        return self._buffers.values()
|
    def update(self, buffers):
        r"""
        Update the `BufferDict` with the key-value pairs from a mapping or an iterable, overwriting existing keys.

        Note:
            If `buffers` is an `OrderedDict`, a `BufferDict`, or an iterable of key-value pairs, the order of the
            new elements in it is preserved.

        Args:
            buffers (iterable):
                a mapping (dictionary) from string to `torch.Tensor`, or an iterable of key-value pairs of type
                (string, `torch.Tensor`).
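
        As a small sketch of the ordering rules (the key names and tensors are illustrative), an `OrderedDict` keeps
        its own order, while a plain `dict` is merged in sorted key order:

        ```python
        bd = BufferDict()
        bd.update(OrderedDict([("b", torch.zeros(1)), ("a", torch.zeros(1))]))
        list(bd.keys())  # ["b", "a"]

        bd = BufferDict()
        bd.update({"b": torch.zeros(1), "a": torch.zeros(1)})
        list(bd.keys())  # ["a", "b"]
        ```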
|
""" |
|
        if not isinstance(buffers, collections.abc.Iterable):
            raise TypeError(
                "BufferDict.update should be called with an "
                "iterable of key/value pairs, but got " + type(buffers).__name__
            )

        if isinstance(buffers, collections.abc.Mapping):
            if isinstance(buffers, (OrderedDict, BufferDict)):
                for key, buffer in buffers.items():
                    self[key] = buffer
            else:
                # other mappings are merged in sorted key order to keep the result deterministic
                for key, buffer in sorted(buffers.items()):
                    self[key] = buffer
        else:
            for j, p in enumerate(buffers):
                if not isinstance(p, collections.abc.Iterable):
                    raise TypeError(
                        "BufferDict update sequence element "
                        "#" + str(j) + " should be Iterable; is " + type(p).__name__
                    )
                if len(p) != 2:
                    raise ValueError(
                        "BufferDict update sequence element "
                        "#" + str(j) + " has length " + str(len(p)) + "; 2 is required"
                    )
                self[p[0]] = p[1]
|
    def extra_repr(self):
        child_lines = []
        for k, p in self._buffers.items():
            size_str = "x".join(str(size) for size in p.size())
            device_str = "" if not p.is_cuda else f" (GPU {p.get_device()})"
            parastr = f"Buffer containing: [{torch.typename(p)} of size {size_str}{device_str}]"
            child_lines.append(" (" + k + "): " + parastr)
        tmpstr = "\n".join(child_lines)
        return tmpstr
|
    def __call__(self, input):
        raise RuntimeError("BufferDict should not be called.")