import functools
import math
import operator
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.nn.functional as F
from torch.fx.operator_schemas import normalize_function
from torch.nested._internal.sdpa import jagged_scaled_dot_product_attention

from .nested_tensor import NestedTensor

__all__: List[Any] = []

JAGGED_OPS_TABLE: Dict[Any, Any] = {}
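# JAGGED_OPS_TABLE maps an aten op (e.g. torch.ops.aten.linear.default) to a
# wrapper that validates the arguments against a small schema string and then
# calls the registered implementation. lookup_jagged() consults this table
# first and only falls back to generic pointwise handling for untabled ops.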


# Outer dims index the (B, ragged, ...) NestedTensor view; inner dims index the
# packed values tensor, where the batch and ragged dims collapse into dim 0.
def _outer_to_inner_dim(ndim, dim):
    assert dim >= 0 and dim < ndim
    return 0 if dim < 2 else dim - 1


def _wrap_jagged_dim(
    ndim, dim, op_name, convert_to_inner_dim=True, allow_batch_dim=False
):
    from torch._prims_common import canonicalize_dims

    wrapped = canonicalize_dims(ndim, dim)
    if wrapped == 1:
        raise RuntimeError(f"{op_name}(): not supported for NestedTensor on dim=1")
    elif wrapped == 0 and not allow_batch_dim:
        raise RuntimeError(f"{op_name}(): not supported for NestedTensor on dim=0")
    return _outer_to_inner_dim(ndim, wrapped) if convert_to_inner_dim else wrapped


def _wrap_jagged_dims(ndim, dims, op_name):
    # Canonicalize a list of dims; operating over the batch dim without the
    # ragged dim (or vice versa) is not supported.
    from torch._prims_common import canonicalize_dims

    wrapped_dims = [canonicalize_dims(ndim, d) for d in dims]

    zero_in_dims = 0 in wrapped_dims
    one_in_dims = 1 in wrapped_dims
    if zero_in_dims ^ one_in_dims:
        apply, not_apply = ("batch", "ragged") if zero_in_dims else ("ragged", "batch")
        raise RuntimeError(
            f"{op_name}(): applying over the {apply} dimension, but not the {not_apply}"
            " dimension is not supported for NestedTensor"
        )
    return (
        tuple(_outer_to_inner_dim(ndim, d) for d in dims if d != 0),
        zero_in_dims,
    )


# Schema strings are a lightweight arg spec: a comma-separated list of
# "name: type" entries, where type is one of "t" (dense tensor), "jt"
# (contiguous jagged layout NestedTensor), "jt_all" (any jagged layout
# NestedTensor), or "any". A trailing "?" marks an optional arg and a trailing
# "..." allows extra args, e.g. "self: jt, dim: any?, keepdim: any?".
def check_schema(schema_str: str, func, *args, **kwargs) -> None:
    named_arg_types = schema_str.split(", ")
    num_optional_args = [x.endswith("?") for x in named_arg_types].count(True)
    min_args = len(named_arg_types) - num_optional_args

    if named_arg_types[-1] == "...":
        named_arg_types = named_arg_types[:-1]
    else:
        if not (len(args) >= min_args and len(args) <= len(named_arg_types)):
            raise ValueError(
                f"NestedTensor {func.__name__}({schema_str}): expected at least {min_args} "
                f"arguments and at most {len(named_arg_types)} arguments, but got: "
                f"{len(args)} arguments"
            )

    arg_type_check_fns = {
        "t": lambda x: isinstance(x, torch.Tensor) and not isinstance(x, NestedTensor),
        "jt": lambda x: isinstance(x, NestedTensor)
        and x._lengths is None
        and x._ragged_idx == 1,
        "jt_all": lambda x: isinstance(x, NestedTensor),
        "any": lambda x: True,
    }
    for i, named_arg_type in enumerate(named_arg_types):
        name, arg_type = named_arg_type.split(": ")
        is_optional = arg_type.endswith("?")
        normalized_arg_type = arg_type[:-1] if is_optional else arg_type
        if normalized_arg_type not in arg_type_check_fns.keys():
            raise AssertionError(f"Unknown arg type: {normalized_arg_type}")

        if i >= len(args):
            if not is_optional:
                raise ValueError(
                    f"NestedTensor {func.__name__}({schema_str}) "
                    f"missing required argument: {name}"
                )
            continue

        _check_fn = arg_type_check_fns[normalized_arg_type]

        def check_fn(x, is_optional=is_optional):
            if is_optional:
                return x is None or _check_fn(x)
            else:
                return _check_fn(x)

        if not check_fn(args[i]):
            type_to_desc = {
                "t": "tensor",
                "t?": "optional tensor",
                "jt": "contiguous jagged layout NestedTensor",
                "jt_all": "jagged layout NestedTensor",
                "any": "<any type>",
            }

            raise ValueError(
                f"NestedTensor {func.__name__}({schema_str}): expected {name} to be a "
                f"{type_to_desc[arg_type]}"
            )


def check_ragged_dim_same(
    func, a: NestedTensor, a_name: str, b: NestedTensor, b_name: str
) -> None:
    # The ragged sizes are nested symints that only compare equal when the two
    # NTs share the same offsets.
    if a._size[a._ragged_idx] != b._size[b._ragged_idx]:
        raise RuntimeError(
            f"NestedTensor {func.__name__}: expected {a_name} and {b_name} to have the "
            "same exact offsets tensor."
        )


def raggedness_matches(nt, size):
    # True if `size` agrees with nt's (batch, ..., ragged) prefix, treating -1
    # as a wildcard.
    end = nt._ragged_idx + 1
    nt_ragged = nt._size[:end]
    size_ragged = size[:end]
    return len(nt_ragged) == len(size_ragged) and (
        all(ns == s or s == -1 for ns, s in zip(nt_ragged, size_ragged))
    )


def squeeze_leading_ones(t):
    # Squeeze leading ones from t, e.g. (1, 1, ?, ?) -> (?, ?).
    #
    # This is used by jagged_binary_pointwise() to broadcast a dense tensor
    # against an NT's packed values: for (B, j0, ?, ?) + (1, 1, ?, ?), the dense
    # operand is squeezed down to (?, ?) so that it broadcasts cleanly against
    # the (sum(j0), ?, ?) values buffer.
    while t.shape[0] == 1:
        t = t.squeeze(0)
    return t


def register_func(tables, aten_ops, schema_str):
    if not isinstance(aten_ops, list):
        aten_ops = [aten_ops]
    if not isinstance(tables, list):
        tables = [tables]

    def wrapper(func):
        for aten_op in aten_ops:
            # Bind aten_op via a closure factory so that each table entry
            # validates against schema_str before calling func(aten_op, ...).
            def get_inner(aten_op):
                def inner(*args, **kwargs):
                    check_schema(schema_str, func, *args, **kwargs)
                    return func(aten_op, *args, **kwargs)

                return inner

            for table in tables:
                table[aten_op] = get_inner(aten_op)
        return func

    return wrapper


register_jagged_func = functools.partial(register_func, JAGGED_OPS_TABLE)
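# Illustrative usage (hypothetical registration, mirroring the real ones below):
#
#   @register_jagged_func(torch.ops.aten.relu.default, "self: jt_all")
#   def relu_default(func, *args, **kwargs):
#       inp = args[0]
#       return NestedTensor(func(inp._values), **extract_kwargs(inp))
#
# The decorator records the wrapper in JAGGED_OPS_TABLE so that lookup_jagged()
# below can dispatch the aten op to it.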


def lookup_jagged(func, *args, **kwargs) -> Optional[Callable]:
    dispatch_func = JAGGED_OPS_TABLE.get(func, None)
    if dispatch_func is not None:
        return dispatch_func

    # Handle pointwise fallbacks: ops tagged as pointwise without an explicit
    # table entry are routed by their number of tensor args.
    if torch.Tag.pointwise in func.tags:
        num_tensor_args = sum(isinstance(x, torch.Tensor) for x in args)
        if num_tensor_args == 1:
            check_schema("self: jt_all, ...", func, *args, **kwargs)
            return functools.partial(jagged_unary_pointwise, func)
        elif num_tensor_args == 2:
            check_schema("lhs: any, rhs: any, ...", func, *args, **kwargs)
            return functools.partial(jagged_binary_pointwise, func)

    return None


def extract_kwargs(arg):
    kwargs = {
        "offsets": arg.offsets(),
        "_metadata_cache": arg._metadata_cache,
        "_ragged_idx": arg._ragged_idx,
    }
    return kwargs
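# Note: the ops below rebuild their outputs as NestedTensor(new_values,
# **extract_kwargs(inp)), so an output shares the input's offsets tensor,
# metadata cache, and ragged_idx rather than getting fresh metadata.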


def jagged_unary_pointwise(func, *args, **kwargs):
    return NestedTensor(
        func(args[0]._values, *args[1:], **kwargs), **extract_kwargs(args[0])
    )


def jagged_binary_pointwise(func, *args, **kwargs):
    a, b = args[0], args[1]
    assert isinstance(a, NestedTensor) or isinstance(b, NestedTensor)

    mismatch_error_msg = (
        "cannot call binary pointwise function {} with inputs of shapes {} and {}"
    )

    if isinstance(a, NestedTensor) and isinstance(b, NestedTensor):
        # Both inputs are NTs: if the raggedness matches, the op can be applied
        # directly to the packed values buffers.
        if raggedness_matches(a, b._size):
            return NestedTensor(
                func(a._values, b._values, *args[2:], **kwargs), **extract_kwargs(a)
            )
        raise RuntimeError(mismatch_error_msg.format(func.__name__, a._size, b._size))

    # Exactly one of the inputs is an NT; try to broadcast the dense tensor
    # against the NT's packed values.
    a_is_nt = isinstance(a, NestedTensor)
    extracted_kwargs = extract_kwargs(a) if a_is_nt else extract_kwargs(b)

    nt, t = (a, b) if a_is_nt else (b, a)

    if t.dim() > nt.dim():
        raise NotImplementedError("NYI: broadcasting NT with T with larger dim")
    t_squeezed = squeeze_leading_ones(t)
    if nt.dim() >= t_squeezed.dim() + 2:
        # The dense operand broadcasts below the (batch, ragged) dims, so the op
        # can run directly on the packed values.
        lhs, rhs = (nt._values, t_squeezed) if a_is_nt else (t_squeezed, nt._values)
        return NestedTensor(func(lhs, rhs, *args[2:], **kwargs), **extracted_kwargs)

    # Otherwise the dense operand carries a batch dim, e.g. (B, j0, D) + (B, 1, D):
    # broadcast per batch component after unbinding.
    if a.dim() == b.dim():
        if a.shape[0] != b.shape[0]:
            raise RuntimeError(
                mismatch_error_msg.format(func.__name__, a.shape, b.shape)
            )

        outputs = []
        for a_comp, b_comp in zip(a.unbind(), b.unbind()):
            outputs.append(func(a_comp, b_comp, *args[2:], **kwargs))
        new_values = torch.cat(outputs, dim=0)
        return NestedTensor(new_values, **extracted_kwargs)

    raise RuntimeError(mismatch_error_msg.format(func.__name__, a.shape, b.shape))


def jagged_torch_function(func, *args, **kwargs):
    # SDPA has dedicated handling for jagged layout NT inputs.
    if func is torch._C._nn.scaled_dot_product_attention:
        return jagged_scaled_dot_product_attention(*args, **kwargs)

    # flatten() is composite (it decomposes), so handle it here at the
    # torch_function level rather than in the aten op table.
    if func.__name__ == "flatten":

        def _flatten_sig(input, start_dim=0, end_dim=-1):
            pass

        _, new_kwargs = normalize_function(
            _flatten_sig, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
        )

        inp = new_kwargs.pop("input")

        # Keep dims in outer (NT) space; the flatten is expressed as a reshape
        # on the NT below.
        start_dim = _wrap_jagged_dim(
            inp.dim(), new_kwargs["start_dim"], "flatten", convert_to_inner_dim=False
        )
        end_dim = _wrap_jagged_dim(
            inp.dim(), new_kwargs["end_dim"], "flatten", convert_to_inner_dim=False
        )

        if start_dim == end_dim:
            return inp

        product = functools.reduce(operator.mul, inp.shape[start_dim : end_dim + 1])
        new_shape = (*inp.shape[:start_dim], product, *inp.shape[end_dim + 1 :])

        return inp.reshape(*new_shape)

    raise NotImplementedError(func)


@register_jagged_func(
    [
        torch.ops.aten.is_non_overlapping_and_dense.default,
        torch.ops.aten.sym_size.default,
        torch.ops.aten.dim.default,
        torch.ops.aten.numel.default,
        torch.ops.aten.sym_numel.default,
        torch.ops.aten.sym_stride.default,
        torch.ops.aten.sym_storage_offset.default,
    ],
    "self: jt_all",
)
def tensor_attr_supported_getter(func, *args, **kwargs):
    if func == torch.ops.aten.is_non_overlapping_and_dense.default:
        return False

    if func == torch.ops.aten.sym_size.default:
        return args[0]._size

    if func == torch.ops.aten.dim.default:
        return len(args[0]._size)

    if func in (torch.ops.aten.sym_numel.default, torch.ops.aten.numel.default):
        if args[0]._lengths is not None:
            return int(sum(args[0]._lengths) * math.prod(args[0]._size[2:]))
        return args[0]._values.numel()

    if func == torch.ops.aten.sym_stride.default:
        return args[0]._strides

    if func == torch.ops.aten.sym_storage_offset.default:
        return args[0]._values.storage_offset()


@register_jagged_func(torch.ops.prim.layout.default, "self: jt_all")
def prim_layout_default(func, *args, **kwargs):
    return torch.jagged


@register_jagged_func(
    [torch.ops.aten.size.default],
    "self: jt_all",
)
def tensor_attr_unsupported_getter(func, *args, **kwargs):
    if func == torch.ops.aten.size.default:
        raise RuntimeError(
            "NestedTensor does not support directly calling torch.ops.aten.size; "
            "please use `nested_tensor.size()` instead."
        )


@register_jagged_func(torch.ops.aten.is_contiguous.default, "self: jt_all")
def is_contiguous_general(func, *args, **kwargs):
    from torch._prims_common import is_contiguous_for_memory_format

    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")

    # Lengths (e.g. from narrow()) imply the values are not densely packed.
    if inp.lengths() is not None:
        return False

    new_kwargs["memory_format"] = new_kwargs.get(
        "memory_format", torch.contiguous_format
    )
    if new_kwargs["memory_format"] == torch.preserve_format:
        return True
    return is_contiguous_for_memory_format(inp._values, **new_kwargs)


register_jagged_func(
    torch.ops.aten.is_contiguous.memory_format, "self: jt_all, memory_format: any?"
)(is_contiguous_general)


@register_jagged_func(torch.ops.aten.linear.default, "input: jt, weight: t, bias: t?")
def linear_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(
    torch.ops.aten.linear_backward.default,
    "self: jt, grad_output: jt, weight: t, output_mask: any",
)
def linear_backward_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    grad_output = new_kwargs.pop("grad_output")
    weight = new_kwargs.pop("weight")

    check_ragged_dim_same(func, inp, "self", grad_output, "grad_output")
    ds = NestedTensor(
        torch.mm(grad_output._values, weight), **extract_kwargs(grad_output)
    )
    dw = torch.mm(grad_output._values.T, inp._values)
    db = None  # bias gradient is not computed here
    return (ds, dw, db)


@register_jagged_func(torch.ops.aten._to_copy.default, "self: jt_all")
def to_copy_default(func, *args, **kwargs):
    from .nested_tensor import _tensor_symint_registry

    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    # don't change layout
    new_kwargs.pop("layout")

    new_values = func(inp._values, **new_kwargs)
    new_offsets = inp._offsets.to(device=new_values.device)
    # Carry over the nested symint associated with the old offsets to the new
    # (possibly device-moved) offsets tensor so the ragged size stays consistent.
    _tensor_symint_registry[new_offsets] = _tensor_symint_registry[inp._offsets]
    inp_kwargs = extract_kwargs(inp)
    inp_kwargs["offsets"] = new_offsets

    return NestedTensor(new_values, **inp_kwargs)


register_jagged_func(
    [
        torch.ops.aten.empty_like.default,
        torch.ops.aten.ones_like.default,
        torch.ops.aten.zeros_like.default,
        torch.ops.aten.randn_like.default,
        torch.ops.aten.detach.default,
    ],
    "self: jt_all",
)(jagged_unary_pointwise)


register_jagged_func(
    torch.ops.aten._softmax.default, "self: jt, dim: any, half_to_float: any"
)(jagged_unary_pointwise)


@register_jagged_func(
    torch.ops.aten.native_dropout.default, "self: jt, float: any, train: any?"
)
def native_dropout_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    out1, out2 = func(inp._values, **new_kwargs)
    return (
        NestedTensor(out1, **extract_kwargs(inp)),
        NestedTensor(out2, **extract_kwargs(inp)),
    )


@register_jagged_func(
    torch.ops.aten.native_dropout_backward.default,
    "grad_output: jt, mask: jt, scale: any",
)
def native_dropout_backward_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    grad_output = new_kwargs.pop("grad_output")
    mask = new_kwargs.pop("mask")
    return NestedTensor(
        func(grad_output._values, mask._values, **new_kwargs),
        **extract_kwargs(grad_output),
    )


@register_jagged_func(torch.ops.aten.prod.dim_int, "self: jt, dim: any, keepdim: any?")
def prod_dim_int(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    # keepdim=False is not currently supported.
    if not new_kwargs["keepdim"]:
        raise RuntimeError("prod(): keepdim=True must be set for NestedTensor")
    dim = new_kwargs["dim"]
    new_kwargs["dim"] = _wrap_jagged_dim(len(inp._size), dim, "prod")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(
    torch.ops.aten.split.Tensor, "self: jt, split_size: any, dim: any"
)
def split_tensor(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    new_kwargs["dim"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim"], "split")

    return tuple(
        NestedTensor(values=x, **extract_kwargs(inp))
        for x in func(inp._values, **new_kwargs)
    )


@register_jagged_func(
    torch.ops.aten.split_with_sizes.default, "self: jt, split_sizes: any, dim: any"
)
def split_with_sizes_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    new_kwargs["dim"] = _wrap_jagged_dim(
        inp.dim(), new_kwargs["dim"], "split_with_sizes"
    )

    return [
        NestedTensor(values=x, **extract_kwargs(inp))
        for x in func(inp._values, **new_kwargs)
    ]


@register_jagged_func(torch.ops.aten.chunk.default, "self: jt, chunks: any, dim: any?")
def chunk_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    new_kwargs["dim"] = _wrap_jagged_dim(
        inp.dim(), new_kwargs["dim"], "chunk", allow_batch_dim=True
    )

    if new_kwargs["dim"] == 0:
        # Chunking over the batch dim: derive per-chunk offsets from the lengths
        # and split the packed values accordingly.
        chunks = new_kwargs["chunks"]

        lengths = inp._offsets.diff()
        chunked_lengths = lengths.chunk(chunks)
        chunked_offsets = [torch.cumsum(x, dim=0) for x in chunked_lengths]
        chunked_offsets = [F.pad(x, (1, 0), value=0) for x in chunked_offsets]
        nested_kwargs = [
            {"offsets": per_offsets, "_ragged_idx": inp._ragged_idx}
            for per_offsets in chunked_offsets
        ]

        split_sizes = [x.sum().item() for x in chunked_lengths]
        chunk_values = inp._values.split(split_sizes)

        # Note: lengths.chunk(chunks) may return fewer than `chunks` pieces, so
        # iterate over what was actually produced.
        return [
            NestedTensor(values=chunk_values[i], **(nested_kwargs[i]))
            for i in range(len(chunk_values))
        ]
    else:
        return [
            NestedTensor(values=x, **extract_kwargs(inp))
            for x in func(inp._values, **new_kwargs)
        ]


@register_jagged_func(torch.ops.aten.unbind.int, "self: jt_all, dim: any?")
def unbind_int(func, *args, **kwargs):
    # NB: unbind() returns a Python list, so its length (the batch size) is
    # specialized on.
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    dim = new_kwargs["dim"]
    if dim != 0:
        raise RuntimeError("unbind(): only supported for NestedTensor on dim=0")

    inp = new_kwargs.pop("input")
    values = inp.values()
    offsets = inp.offsets()
    lengths = inp.lengths()
    ragged_idx = inp._ragged_idx

    if lengths is None:
        # Contiguous case: components are back-to-back slices of the values buffer.
        return torch.split(values, offsets.diff().tolist(), dim=(ragged_idx - 1))

    if ragged_idx <= 0:
        raise RuntimeError(
            "unbind(): nested tensor ragged_idx out of bounds (should be >= 1)"
        )
    for i in range(lengths.shape[0]):
        if offsets[i] + lengths[i] > values.shape[ragged_idx - 1]:
            raise RuntimeError(
                "unbind(): nested tensor offsets and lengths do not match ragged_idx dimension"
            )
    # Non-contiguous (holes) case: narrow out each component individually.
    return [
        torch.narrow(values, dim=(ragged_idx - 1), start=offsets[i], length=lengths[i])
        for i in range(lengths.shape[0])
    ]


@register_jagged_func(torch.ops.aten.squeeze.dim, "self: jt, dim: any")
def squeeze_dim(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    values = inp._values

    new_kwargs["dim"] = _wrap_jagged_dim(len(inp._size), new_kwargs["dim"], "squeeze")
    return NestedTensor(func(values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(torch.ops.aten.unsqueeze.default, "self: jt, dim: any")
def unsqueeze_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    values = inp._values

    # Wrap against the output's dimensionality (ndim + 1 for unsqueeze).
    dim = new_kwargs["dim"]
    new_kwargs["dim"] = _wrap_jagged_dim(len(inp._size) + 1, dim, "unsqueeze")
    return NestedTensor(func(values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(torch.ops.aten.cat.default, "tensors: any, dim: any")
def cat_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    tensors = new_kwargs.pop("tensors")

    # Expand any non-nested tensors to match the shape of the first NT.
    nested = [t for t in tensors if t.is_nested]
    assert len(nested) > 0
    first = nested[0]
    tensors = [t if t.is_nested else t.expand_as(first) for t in tensors]

    # Account for the collapsed jagged dim when wrapping.
    dim = new_kwargs["dim"]
    new_kwargs["dim"] = _wrap_jagged_dim(len(first.shape), dim, "cat")

    return NestedTensor(
        func([t._values for t in tensors], **new_kwargs), **extract_kwargs(tensors[0])
    )


@register_jagged_func(torch.ops.aten.matmul.default, "self: jt, other: any")
def matmul_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    other = new_kwargs.pop("other")

    if inp.is_nested and not other.is_nested:
        return NestedTensor(
            func(inp._values, other, **new_kwargs), **extract_kwargs(inp)
        )
    elif inp.is_nested and other.is_nested:
        # BMM between two NTs with matching raggedness operates on the packed values.
        if inp.dim() > 3 and other.dim() > 3 and raggedness_matches(inp, other._size):
            return NestedTensor(func(inp._values, other._values), **extract_kwargs(inp))

    raise RuntimeError(
        f"matmul(): not supported between inputs of shapes {inp._size} and {other.shape}"
    )


@register_jagged_func(
    torch.ops.aten.expand.default, "self: jt, size: any, implicit: any?"
)
def expand_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    size = new_kwargs["size"]

    assert ("implicit" not in new_kwargs) or (not new_kwargs.pop("implicit"))
    if not raggedness_matches(inp, size):
        raise RuntimeError(f"expand(): cannot expand shape {inp._size} -> {size}")

    # -1 keeps the packed (batch * ragged) dim of the values buffer as-is.
    expand_arg = [-1, *size[2:]]
    return NestedTensor(func(inp._values, expand_arg), **extract_kwargs(inp))


@register_jagged_func(torch.ops.aten.expand_as.default, "self: t, other: jt")
def expand_as_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    other = new_kwargs.pop("other")

    return NestedTensor(func(inp, other._values), **extract_kwargs(other))


@register_jagged_func(torch.ops.aten.where.self, "condition: jt, self: jt, other: jt")
def where_self(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    condition = new_kwargs.pop("condition")
    inp = new_kwargs.pop("input")
    other = new_kwargs.pop("other")

    assert condition._size == other._size == inp._size

    return NestedTensor(
        func(condition._values, inp._values, other._values, **new_kwargs),
        **extract_kwargs(condition),
    )


@register_jagged_func(torch.ops.aten._pin_memory.default, "self: jt, device: any?")
def _pin_memory_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(torch.ops.aten.is_pinned.default, "self: jt, device: any?")
def is_pinned_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    return func(inp._values, **new_kwargs)


@register_jagged_func(
    torch.ops.aten.is_same_size.default, "self: jt_all, other: jt_all"
)
def is_same_size_default(func, *args, **kwargs):
    return args[0]._size == args[1]._size


@register_jagged_func(
    torch.ops.aten.sum.dim_IntList, "self: jt, dim: any?, keepdim: any?, dtype: any?"
)
def sum_dim_IntList(func, *args, **kwargs):
    # sum.dim_IntList can produce either an NT or a dense tensor depending on
    # whether the ragged dim is reduced away.
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")
    assert inp._ragged_idx == 1
    new_kwargs["dim"], ragged_reduced_away = _wrap_jagged_dims(
        inp.dim(), new_kwargs["dim"], "sum"
    )

    if not ragged_reduced_away:
        return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
    else:
        # Don't wrap as an NT since the raggedness was reduced away; if keepdim,
        # re-add a leading batch dim.
        out = func(inp._values, **new_kwargs)
        if new_kwargs["keepdim"]:
            out = out.unsqueeze(0)
        return out


@register_jagged_func(
    torch.ops.aten.transpose.int, "self: jt_all, dim0: any, dim1: any"
)
def transpose_int(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    from torch._prims_common import canonicalize_dims

    inp = new_kwargs.pop("input")
    dim0, dim1 = canonicalize_dims(inp.dim(), (new_kwargs["dim0"], new_kwargs["dim1"]))

    if inp._lengths is not None:
        raise ValueError(
            "transpose(): not supported on jagged layout nested tensor with holes"
        )

    # If the transpose involves the ragged dim, express it by moving _ragged_idx
    # to the other dim and transposing the corresponding inner dims of the values
    # buffer; transposing with the batch dim is not supported.
    if dim0 == inp._ragged_idx or dim1 == inp._ragged_idx:
        if dim0 == 0 or dim1 == 0:
            raise ValueError(
                "Transpose is not supported on the batch dimension for jagged NT"
            )
        if dim0 == inp._ragged_idx:
            to_dim = dim1
        else:
            to_dim = dim0
        inp_kwargs = extract_kwargs(inp)
        inp_kwargs["_ragged_idx"] = to_dim
        return NestedTensor(
            inp.values().transpose(
                _outer_to_inner_dim(len(inp._size), dim0),
                _outer_to_inner_dim(len(inp._size), dim1),
            ),
            **inp_kwargs,
        )

    new_kwargs["dim0"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim0"], "transpose")
    new_kwargs["dim1"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim1"], "transpose")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(
    [torch.ops.aten.view.default, torch.ops.aten._unsafe_view.default],
    "self: jt_all, size: any",
)
def view_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    size = new_kwargs.pop("size")

    if inp._ragged_idx != 1 and tuple(inp._size) != tuple(size):
        raise RuntimeError(
            f"view(): does not support ragged_idx != 1 except when inp._size == size. "
            f"inp._size is ({inp._size}) and size is ({size})."
        )

    # Ensure the specified size still includes the batch and ragged dims.
    if len(size) < 3 or not raggedness_matches(inp, size):
        raise RuntimeError(f"view(): cannot view shape {inp._size} as {size}")

    # outer size: the size of the NT, e.g. [3, j0, 10]
    # inner size: the size of the values, e.g. [8, 10] (for offsets = [0, 3, 5, 8])
    # This function returns inner_size[inner_idx] given outer size; the ragged
    # dim maps to the current packed dim of the values buffer.
    def get_inner_size(inner_idx):
        nonlocal inp, size
        if inner_idx == inp._ragged_idx - 1:
            return inp._values.size(inner_idx)
        else:
            return size[inner_idx + 1]

    inner_size = [get_inner_size(i) for i in range(len(size) - 1)]

    return NestedTensor(func(inp._values, inner_size), **extract_kwargs(inp))


@register_jagged_func(
    torch.ops.aten.native_layer_norm.default,
    "input: jt, normalized_shape: any, weight: any?, bias: any?, eps: any",
)
def native_layer_norm_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    normalized_shape = new_kwargs["normalized_shape"]

    # Ensure we're not trying to normalize over the ragged dim.
    if inp.dim() < 3 or (inp.dim() - len(normalized_shape)) < 2:
        raise RuntimeError(
            "layer_norm(): normalizing over ragged dim not supported for nested tensors"
        )

    output, mean, std = func(inp._values, **new_kwargs)
    return (NestedTensor(output, **extract_kwargs(inp)), mean, std)


@register_jagged_func(
    torch.ops.aten.native_layer_norm_backward.default,
    "grad_out: jt, input: jt, normalized_shape: any, mean: any, rstd: any, weight: any?, bias: any?, output_mask: any",
)
def native_layer_norm_backward_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    grad_out = new_kwargs.pop("grad_out")
    inp = new_kwargs.pop("input")
    d_input, d_gamma, d_beta = func(grad_out._values, inp._values, **new_kwargs)
    if d_input is None:
        return (None, d_gamma, d_beta)

    return (NestedTensor(d_input, **extract_kwargs(inp)), d_gamma, d_beta)


@register_jagged_func(torch.ops.aten.select.int, "self: jt, dim: any, index: any")
def select_int(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    new_kwargs["dim"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim"], "select")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(
    torch.ops.aten.slice.Tensor,
    "self: jt, dim: any?, start: any?, end: any?, step: any?",
)
def slice_tensor(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    new_kwargs["dim"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim"], "slice")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(
    torch.ops.aten.convolution.default,
    "input: jt, weight: t, bias: t?, stride: any, padding: any, "
    "dilation: any, transposed: any, output_padding: any, groups: any",
)
def convolution_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(
    torch.ops.aten.mean.dim, "self: jt, dim: any?, keepdim: any, dtype: any?"
)
def mean_dim(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    # Only a single non-batch, non-ragged dim is supported for now.
    new_kwargs["dim"] = [_wrap_jagged_dim(inp.dim(), new_kwargs["dim"][0], "mean")]

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(torch.ops.aten.stack.default, "tensors: any, dim: any")
def stack_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    # All inputs must be NTs with matching dim and nested structure.
    tensors = new_kwargs.pop("tensors")
    for t in tensors:
        if not isinstance(t, NestedTensor):
            raise RuntimeError("stack(): expected all inputs to be nested tensors")

        if t.dim() != tensors[0].dim():
            raise RuntimeError(
                "stack(): expected all nested tensors to have the same dim"
            )

        if not raggedness_matches(t, tensors[0].shape):
            raise RuntimeError(
                "stack(): expected all nested tensors to have the same nested structure"
            )

    new_kwargs["dim"] = _wrap_jagged_dim(
        tensors[0].dim() + 1, new_kwargs["dim"], "stack"
    )

    return NestedTensor(
        func([t._values for t in tensors], **new_kwargs), **extract_kwargs(tensors[0])
    )


@register_jagged_func(
    torch.ops.aten.embedding.default,
    "weight: t, indices: jt, padding_idx: any?, scale_grad_by_freq: any?, sparse: any?",
)
def embedding_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    indices = new_kwargs.pop("indices")
    weight = new_kwargs.pop("weight")

    return NestedTensor(
        func(weight, indices._values, **new_kwargs), **extract_kwargs(indices)
    )


@register_jagged_func(
    [
        torch.ops.aten.values.default,
        torch.ops.aten._nested_get_values.default,
    ],
    "self: jt_all",
)
def values_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    # Return a detached view of the underlying values buffer.
    return inp._values.detach()


@register_jagged_func(
    torch.ops.aten._nested_view_from_jagged.default,
    "values: t, offsets: t, dummy: jt_all, lengths: t?, ragged_idx: any?",
)
def _nested_view_from_jagged_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    values, offsets, lengths = (
        new_kwargs["input"],
        new_kwargs["offsets"],
        new_kwargs["lengths"],
    )
    ragged_idx = new_kwargs["ragged_idx"]

    return NestedTensor(values, offsets, lengths=lengths, _ragged_idx=ragged_idx)


@register_jagged_func(torch.ops.aten._nested_get_offsets.default, "self: jt_all")
def _nested_get_offsets(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    return inp._offsets


@register_jagged_func(torch.ops.aten._nested_get_lengths.default, "self: jt_all")
def _nested_get_lengths(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    return inp._lengths


@register_jagged_func(torch.ops.aten._nested_get_ragged_idx.default, "self: jt_all")
def _nested_get_ragged_idx(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    return inp._ragged_idx


@register_jagged_func(torch.ops.aten._nested_get_jagged_dummy.default, "self: any")
def _nested_get_jagged_dummy(func, *args, **kwargs):
    from torch.nested._internal.nested_tensor import _nt_view_dummy

    return _nt_view_dummy()


# Also register the dummy getter as the kernel for CPU / CUDA / Meta.
with torch.library._scoped_library("aten", "IMPL") as aten:
    aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "CPU")
    aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "CUDA")
    aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "Meta")