|
|
|
from typing import List, Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torch import SymInt, Tensor |
|
from torch._C import _add_docstr, _nested |
|
|
|
from torch.types import _device as Device, _dtype as DType |
|
|
|
# Public API of this module: the names exported by a star-import.
# Every entry is defined further down in this file.
__all__ = [
    "to_padded_tensor",
    "as_nested_tensor",
    "nested_tensor",
    "nested_tensor_from_jagged",
    "narrow",
]
|
|
|
|
|
|
|
|
|
def as_nested_tensor(
    ts: Union[Tensor, List[Tensor], Tuple[Tensor, ...]],
    dtype: Optional[DType] = None,
    device: Optional[Device] = None,
    layout=None
) -> Tensor:
    r"""
    Constructs a nested tensor preserving autograd history from a tensor or a list / tuple of
    tensors.

    If a nested tensor is passed, it will be returned directly unless the device / dtype / layout
    differ. Note that converting device / dtype will result in a copy, while converting layout
    is not currently supported by this function.

    If a non-nested tensor is passed, it is treated as a batch of constituents of consistent size.
    A copy will be incurred if the passed device / dtype differ from those of the input OR if
    the input is non-contiguous. Otherwise, the input's storage will be used directly.

    If a tensor list is provided, tensors in the list are always copied during construction of
    the nested tensor.

    Args:
        ts (Tensor or List[Tensor] or Tuple[Tensor]): a tensor to treat as a nested tensor OR a
            list / tuple of tensors with the same ndim

    Keyword arguments:
        dtype (:class:`torch.dtype`, optional): the desired type of returned nested tensor.
            Default: if None, same :class:`torch.dtype` as leftmost tensor in the list.
        device (:class:`torch.device`, optional): the desired device of returned nested tensor.
            Default: if None, same :class:`torch.device` as leftmost tensor in the list
        layout (:class:`torch.layout`, optional): the desired layout of returned nested tensor.
            Only strided and jagged layouts are supported. Default: if None, the strided layout.

    Raises:
        TypeError: if :attr:`ts` is neither a tensor nor a list / tuple of tensors.
        RuntimeError: if :attr:`ts` is a tensor with fewer than 2 dimensions, if :attr:`ts` is a
            nested tensor whose layout differs from the requested one, or if :attr:`layout` is
            neither strided nor jagged.

    Example::

        >>> a = torch.arange(3, dtype=torch.float, requires_grad=True)
        >>> b = torch.arange(5, dtype=torch.float, requires_grad=True)
        >>> nt = torch.nested.as_nested_tensor([a, b])
        >>> nt.is_leaf
        False
        >>> fake_grad = torch.nested.nested_tensor([torch.ones_like(a), torch.zeros_like(b)])
        >>> nt.backward(fake_grad)
        >>> a.grad
        tensor([1., 1., 1.])
        >>> b.grad
        tensor([0., 0., 0., 0., 0.])
        >>> c = torch.randn(3, 5, requires_grad=True)
        >>> nt2 = torch.nested.as_nested_tensor(c)
    """
    is_tensor_list = isinstance(ts, (list, tuple)) and all(isinstance(t, Tensor) for t in ts)
    if not isinstance(ts, Tensor) and not is_tensor_list:
        raise TypeError(
            "as_nested_tensor(): Expected first argument to be a tensor or a list / tuple of tensors "
        )

    # Normalize tuple -> list; the list-based constructors below expect a list.
    if is_tensor_list and not isinstance(ts, list):
        ts = list(ts)

    if isinstance(ts, Tensor) and ts.dim() < 2:
        raise RuntimeError("as_nested_tensor(): Expected tensor argument to have dim() > 1")

    # Resolve the documented default BEFORE comparing against the input's layout.
    # Otherwise a strided nested tensor passed with layout=None would compare
    # None != torch.strided and be rejected as a layout conversion.
    if layout is None:
        layout = torch.strided

    if isinstance(ts, Tensor) and ts.is_nested:
        if layout == ts.layout:
            # Same layout: hand back the input, converted to device / dtype if requested.
            return ts.to(device=device, dtype=dtype)
        else:
            raise RuntimeError(
                "as_nested_tensor(): Converting between nested tensor layouts is not supported")

    if layout == torch.strided:
        if isinstance(ts, Tensor):
            # Flatten the dense batch into a 1D buffer and view it as a nested tensor.
            # contiguous() is needed so that view(-1) is valid for any input strides.
            buffer = ts.contiguous().view(-1).to(device=device, dtype=dtype)
            nested_sizes = torch.tensor([t.shape for t in ts])
            return torch._nested_view_from_buffer(
                buffer,
                nested_sizes,
                *torch._nested_compute_contiguous_strides_offsets(nested_sizes))
        else:
            assert isinstance(ts, list)
            return torch._nested_tensor_from_tensor_list(ts, dtype, None, device, None)
    elif layout == torch.jagged:
        if isinstance(ts, Tensor):
            # Collapse (batch, seq) into a single packed dimension.
            values = ts.contiguous().flatten(0, 1).to(device=device, dtype=dtype)
            batch_size = ts.shape[0]
            seq_len = ts.shape[1]
            # Every constituent has the same length, so offsets are multiples of seq_len:
            # [0, seq_len, ..., batch_size * seq_len]. Building them as arange * seq_len
            # (instead of arange with step=seq_len) also covers seq_len == 0, where a
            # zero step would raise. The offsets are placed on values.device so they
            # sit with the values buffer even when `device` was left as None.
            offsets = torch.arange(batch_size + 1, device=values.device, dtype=torch.int64) * seq_len

            from torch.nested._internal.nested_tensor import nested_view_from_values_offsets

            return nested_view_from_values_offsets(values, offsets)
        else:
            from torch.nested._internal.nested_tensor import jagged_from_list

            assert isinstance(ts, list)
            nt, _ = jagged_from_list(ts, offsets=None, device=device, dtype=dtype)
            return nt
    else:
        raise RuntimeError(f"Specified layout is unsupported for nested tensors: {layout}")
|
|
|
|
|
|
|
|
|
|
|
# Attach documentation to the natively implemented padding routine and
# re-export it under its public name.
to_padded_tensor = _add_docstr(
    _nested.nested_to_padded_tensor,
    r"""
to_padded_tensor(input, padding, output_size=None, out=None) -> Tensor

Returns a new (non-nested) Tensor by padding the :attr:`input` nested tensor.
The leading entries will be filled with the nested data,
while the trailing entries will be padded.

.. warning::

    :func:`to_padded_tensor` always copies the underlying data,
    since the nested and the non-nested tensors differ in memory layout.

Args:
    input (Tensor): The nested tensor to pad.
    padding (float): The padding value for the trailing entries.

Keyword args:
    output_size (Tuple[int], optional): The size of the output tensor.
                              If given, it must be large enough to contain all nested data;
                              else, will infer by taking the max size of each nested sub-tensor along each dimension.
    out (Tensor, optional): the output tensor.

Example::

    >>> nt = torch.nested.nested_tensor([torch.randn((2, 5)), torch.randn((3, 4))])
    nested_tensor([
      tensor([[ 1.6862, -1.1282,  1.1031,  0.0464, -1.3276],
              [-1.9967, -1.0054,  1.8972,  0.9174, -1.4995]]),
      tensor([[-1.8546, -0.7194, -0.2918, -0.1846],
              [ 0.2773,  0.8793, -0.5183, -0.6447],
              [ 1.8009,  1.8468, -0.9832, -1.5272]])
    ])
    >>> pt_infer = torch.nested.to_padded_tensor(nt, 0.0)
    tensor([[[ 1.6862, -1.1282,  1.1031,  0.0464, -1.3276],
             [-1.9967, -1.0054,  1.8972,  0.9174, -1.4995],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],
            [[-1.8546, -0.7194, -0.2918, -0.1846,  0.0000],
             [ 0.2773,  0.8793, -0.5183, -0.6447,  0.0000],
             [ 1.8009,  1.8468, -0.9832, -1.5272,  0.0000]]])
    >>> pt_large = torch.nested.to_padded_tensor(nt, 1.0, (2, 4, 6))
    tensor([[[ 1.6862, -1.1282,  1.1031,  0.0464, -1.3276,  1.0000],
             [-1.9967, -1.0054,  1.8972,  0.9174, -1.4995,  1.0000],
             [ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000],
             [ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000]],
            [[-1.8546, -0.7194, -0.2918, -0.1846,  1.0000,  1.0000],
             [ 0.2773,  0.8793, -0.5183, -0.6447,  1.0000,  1.0000],
             [ 1.8009,  1.8468, -0.9832, -1.5272,  1.0000,  1.0000],
             [ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000]]])
    >>> pt_small = torch.nested.to_padded_tensor(nt, 2.0, (2, 2, 2))
    RuntimeError: Value in output_size is less than NestedTensor padded size. Truncation is not supported.

""",
)
|
|
|
def nested_tensor(tensor_list, *, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False) -> Tensor: |
|
r""" |
|
Constructs a nested tensor with no autograd history (also known as a "leaf tensor", see |
|
:ref:`Autograd mechanics <autograd-mechanics>`) from :attr:`tensor_list` a list of tensors. |
|
|
|
Args: |
|
tensor_list (List[array_like]): a list of tensors, or anything that can be passed to torch.tensor, |
|
where each element of the list has the same dimensionality. |
|
|
|
Keyword arguments: |
|
dtype (:class:`torch.dtype`, optional): the desired type of returned nested tensor. |
|
Default: if None, same :class:`torch.dtype` as leftmost tensor in the list. |
|
layout (:class:`torch.layout`, optional): the desired layout of returned nested tensor. |
|
Only strided and jagged layouts are supported. Default: if None, the strided layout. |
|
device (:class:`torch.device`, optional): the desired device of returned nested tensor. |
|
Default: if None, same :class:`torch.device` as leftmost tensor in the list |
|
requires_grad (bool, optional): If autograd should record operations on the |
|
returned nested tensor. Default: ``False``. |
|
pin_memory (bool, optional): If set, returned nested tensor would be allocated in |
|
the pinned memory. Works only for CPU tensors. Default: ``False``. |
|
|
|
Example:: |
|
|
|
>>> a = torch.arange(3, dtype=torch.float, requires_grad=True) |
|
>>> b = torch.arange(5, dtype=torch.float, requires_grad=True) |
|
>>> nt = torch.nested.nested_tensor([a, b], requires_grad=True) |
|
>>> nt.is_leaf |
|
True |
|
""" |
|
if layout is None: |
|
layout = torch.strided |
|
if layout == torch.strided: |
|
return _nested.nested_tensor( |
|
tensor_list, |
|
dtype=dtype, |
|
device=device, |
|
requires_grad=requires_grad, |
|
pin_memory=pin_memory) |
|
elif layout == torch.jagged: |
|
|
|
list_of_tensors = [t if isinstance(t, Tensor) else torch.as_tensor(t) for t in tensor_list] |
|
|
|
from torch.nested._internal.nested_tensor import jagged_from_list |
|
|
|
with torch.no_grad(): |
|
nt, _ = jagged_from_list(list_of_tensors, offsets=None, device=device, dtype=dtype) |
|
|
|
nt.requires_grad_(requires_grad) |
|
if pin_memory: |
|
nt = nt.pin_memory() |
|
|
|
return nt |
|
else: |
|
raise RuntimeError(f"Specified layout is unsupported for nested tensors: {layout}") |
|
|
|
|
|
def narrow(tensor: Tensor, dim: int, start: Union[int, Tensor], length: Union[int, Tensor], layout=torch.strided) -> Tensor: |
|
r""" |
|
Constructs a nested tensor (which might be a view) from :attr:`tensor`, a strided tensor. This follows |
|
similar semantics to torch.Tensor.narrow, where in the :attr:`dim`-th dimension the new nested tensor |
|
shows only the elements in the interval `[start, start+length)`. As nested representations |
|
allow for a different `start` and `length` at each 'row' of that dimension, :attr:`start` and :attr:`length` |
|
can also be tensors of shape `tensor.shape[0]`. |
|
|
|
There's some differences depending on the layout you use for the nested tensor. If using strided layout, |
|
torch.narrow will do a copy of the narrowed data into a contiguous NT with strided layout, while |
|
jagged layout narrow() will create a non-contiguous view of your original strided tensor. This particular |
|
representation is really useful for representing kv-caches in Transformer models, as specialized |
|
SDPA kernels can deal with format easily, resulting in performance improvements. |
|
|
|
|
|
Args: |
|
tensor (:class:`torch.Tensor`): a strided tensor, which will be used as the underlying data |
|
for the nested tensor if using the jagged layout or will be copied for the strided layout. |
|
dim (int): the dimension where narrow will be applied. Only `dim=1` is supported for the |
|
jagged layout, while strided supports all dim |
|
start (Union[int, :class:`torch.Tensor`]): starting element for the narrow operation |
|
length (Union[int, :class:`torch.Tensor`]): number of elements taken during the narrow op |
|
|
|
Keyword arguments: |
|
layout (:class:`torch.layout`, optional): the desired layout of returned nested tensor. |
|
Only strided and jagged layouts are supported. Default: if None, the strided layout. |
|
|
|
Example:: |
|
|
|
>>> starts = torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64) |
|
>>> lengths = torch.tensor([3, 2, 2, 1, 5], dtype=torch.int64) |
|
>>> narrow_base = torch.randn(5, 10, 20) |
|
>>> nt_narrowed = torch.nested.narrow(narrow_base, 1, starts, lengths, layout=torch.jagged) |
|
>>> nt_narrowed.is_contiguous() |
|
False |
|
""" |
|
if not isinstance(start, (int, SymInt, Tensor)): |
|
raise RuntimeError("start must be an integer or a tensor") |
|
|
|
if not isinstance(length, (int, SymInt, Tensor)): |
|
raise RuntimeError("length must be an integer or a tensor") |
|
|
|
if layout == torch.strided: |
|
if isinstance(start, Tensor) or isinstance(length, Tensor): |
|
raise RuntimeError("start and length must be integers for the strided layout NT impl") |
|
|
|
nt = as_nested_tensor(torch.unbind(tensor), layout=torch.strided).narrow(dim, start, length) |
|
elif layout == torch.jagged: |
|
if dim != 1: |
|
raise RuntimeError("jagged layout only supports dim=1") |
|
|
|
from torch.nested._internal.nested_tensor import jagged_from_tensor_and_lengths |
|
|
|
if isinstance(start, (int, SymInt)): |
|
start = torch.tensor([start], device=tensor.device, dtype=torch.int64) |
|
|
|
if isinstance(length, (int, SymInt)): |
|
length = torch.tensor([length], device=tensor.device, dtype=torch.int64) |
|
|
|
nt, _, _ = jagged_from_tensor_and_lengths(tensor, start, length) |
|
else: |
|
raise RuntimeError(f"Specified layout is unsupported for nested narrow: {layout}") |
|
|
|
return nt |
|
|
|
|
|
def nested_tensor_from_jagged(
    values: Tensor,
    offsets: Optional[Tensor] = None,
    lengths: Optional[Tensor] = None,
    jagged_dim: Optional[int] = None,
) -> Tensor:
    r"""
    Builds a jagged layout nested tensor out of its raw jagged components: a values buffer in
    which the jagged dimension has been packed into one dimension, plus offsets and / or lengths
    metadata describing how that packed dimension is divided into batch elements. The metadata
    is expected to be allocated on the same device as the values buffer.

    Expected metadata formats:
        * offsets: Indices within the packed dimension splitting it into heterogeneously-sized
          batch elements. Example: [0, 2, 3, 6] indicates that a packed jagged dim of size 6
          should be conceptually split into batch elements of length [2, 1, 3]. Note that both the
          beginning and ending offsets are required for kernel convenience (i.e. shape batch_size + 1).
        * lengths: Lengths of the individual batch elements; shape == batch_size. Example: [2, 1, 3]
          indicates that a packed jagged dim of size 6 should be conceptually split into batch
          elements of length [2, 1, 3].

    Supplying both offsets and lengths is allowed and describes a nested tensor with "holes":
    offsets give the start position of each batch item, while lengths give the number of elements
    each item actually contains (see the second example below). When only lengths is supplied,
    the offsets are derived from it as a zero-prefixed cumulative sum.

    The returned jagged layout nested tensor will be a view of the input values tensor.

    Args:
        values (:class:`torch.Tensor`): The underlying buffer in the shape of
            (sum_B(*), D_1, ..., D_N). The jagged dimension is packed into a single dimension,
            with the offsets / lengths metadata used to distinguish batch elements.
        offsets (optional :class:`torch.Tensor`): Offsets into the jagged dimension of shape B + 1.
        lengths (optional :class:`torch.Tensor`): Lengths of the batch elements of shape B.
        jagged_dim (optional int): Indicates which dimension in values is the packed jagged
            dimension. If None, this is set to dim=1 (i.e. the dimension immediately following
            the batch dimension). Default: None

    Raises:
        RuntimeError: if neither :attr:`offsets` nor :attr:`lengths` is given.

    Example::

        >>> values = torch.randn(12, 5)
        >>> offsets = torch.tensor([0, 3, 5, 6, 10, 12])
        >>> nt = nested_tensor_from_jagged(values, offsets)
        >>> # 3D shape with the middle dimension jagged
        >>> nt.shape
        torch.Size([5, j2, 5])
        >>> # Length of each item in the batch:
        >>> offsets.diff()
        tensor([3, 2, 1, 4, 2])

        >>> values = torch.randn(6, 5)
        >>> offsets = torch.tensor([0, 2, 3, 6])
        >>> lengths = torch.tensor([1, 1, 2])
        >>> # NT with holes
        >>> nt = nested_tensor_from_jagged(values, offsets, lengths)
        >>> a, b, c = nt.unbind()
        >>> # Batch item 1 consists of indices [0, 1)
        >>> torch.equal(a, values[0:1, :])
        True
        >>> # Batch item 2 consists of indices [2, 3)
        >>> torch.equal(b, values[2:3, :])
        True
        >>> # Batch item 3 consists of indices [3, 5)
        >>> torch.equal(c, values[3:5, :])
        True
    """
    # Guard: without any metadata there is no way to split the packed dimension.
    if offsets is None and lengths is None:
        raise RuntimeError(
            "nested_tensor_from_jagged(): At least one of offsets or lengths is required."
        )

    if offsets is None:
        # Only lengths were given: turn them into offsets by prefixing a zero to
        # their running sum. The lengths are then dropped, since the derived
        # offsets already describe the layout on their own.
        running_total = lengths.cumsum(0)
        offsets = F.pad(running_total, (1, 0))
        lengths = None

    ragged_idx = 1 if jagged_dim is None else jagged_dim

    from torch.nested._internal.nested_tensor import nested_view_from_values_offsets_lengths

    return nested_view_from_values_offsets_lengths(values, offsets, lengths, ragged_idx=ragged_idx)
|
|