Spaces:
Running
Running
File size: 36,575 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 |
# mypy: ignore-errors
import functools
import itertools
import math
import sys
from typing import Callable, Union
import torch
import torch._custom_op
import torch._logging
from torch._ops import OpOverload
from torch._prims_common import (
elementwise_dtypes,
ELEMENTWISE_TYPE_PROMOTION_KIND,
is_boolean_dtype,
is_float_dtype,
is_integer_dtype,
)
from torch._subclasses.fake_tensor import (
DataDependentOutputException,
DynamicOutputShapeException,
FakeTensor,
in_kernel_invocation_manager,
run_fallback_kernel,
UnsupportedOperatorException,
)
from torch.fx.operator_schemas import normalize_function
from torch.utils._stats import count_label
# Short alias for PyTorch's pytree utilities (used below to flatten op args).
pytree = torch.utils._pytree

__all__ = [
    "op_implementations_checks",
    "get_fast_op_impls",
    "stride_incorrect_op",
    "has_meta",
]

# Exact OpOverload -> fake implementation; filled by
# register_op_impl(<OpOverload>).
op_implementations_dict = {}
# Ordered list of (predicate, fake implementation) pairs; filled by
# register_op_impl(<callable predicate>).
op_implementations_checks = []

aten = torch._ops.ops.aten
def ordered_set(*items):
    """Return an insertion-ordered, set-like dict mapping every item to True.

    Duplicates keep the position of their first occurrence.
    """
    return {item: True for item in items}
def is_noncontiguous_supported(device):
    """Report whether ``device``'s backend can hold non-contiguous tensors.

    Only the ``"hpu"`` device type is treated as contiguous-only; every
    other device type is assumed to support arbitrary strides.
    """
    return device.type != "hpu"
# Constructor ops that take a tensor argument ("input" after normalization).
# `constructors` uses that tensor's device as the default output device for
# every op in this set.
_like_tensor_constructors = ordered_set(
    aten.empty_like.default,
    aten.empty_like.out,
    aten.full_like.default,
    aten.full_like.out,
    aten.ones_like.default,
    aten.ones_like.out,
    aten.rand_like.default,
    aten.rand_like.out,
    aten.randn_like.default,
    aten.randn_like.out,
    aten.randint_like.default,
    aten.randint_like.out,
    aten.randint_like.low_dtype,
    aten.randint_like.low_dtype_out,
    aten.zeros_like.default,
    aten.zeros_like.out,
    aten.new_empty.default,
    aten.new_empty.out,
    aten.new_empty_strided.default,
    aten.new_empty_strided.out,
    aten.new_full.default,
    aten.new_full.out,
    aten.new_zeros.default,
    aten.new_zeros.out,
    aten.new_ones.default,
    aten.new_ones.out,
)


# Device-related ops handled specially. Presumably these take or affect a
# device other than via a plain `device=` kwarg (inferred from the name;
# confirm). Those without a dedicated impl in this file are routed to `nyi`.
_device_not_kwarg_ops = ordered_set(
    aten._resize_output_.default,
    aten._nested_tensor_from_tensor_list.default,
    aten._nested_tensor_from_tensor_list.out,
    aten.pin_memory.default,
    aten.is_pinned.default,
    aten.to.device,
    aten.to.prim_Device,
    aten._pin_memory.default,
    aten._pin_memory.out,
    aten._resize_output.default,
    aten._resize_output.out,
)

# this op is never actually used
_non_kwarg_device_constructors = (aten._list_to_tensor,)
def contains_tensor_types(type):
    """Return True if the JIT type ``type`` is, or transitively contains, a Tensor.

    Recurses through ``containedTypes()``, so container types whose elements
    are Tensors (at any depth) also count.
    """
    tensor_type = torch._C.TensorType.get()
    if type.isSubtypeOf(tensor_type):
        return True
    for inner in type.containedTypes():
        if contains_tensor_types(inner):
            return True
    return False
@functools.lru_cache(None)
def _is_tensor_constructor(func: OpOverload):
    """Return True if ``func`` returns exactly one Tensor and takes no tensor input.

    Results are memoized per overload via ``functools.lru_cache``.
    """
    assert isinstance(func, OpOverload)
    schema = func._schema
    for argument in schema.arguments:
        if contains_tensor_types(argument.type):
            return False
    # TODO: no real reason to restrict multiple outputs
    returns = schema.returns
    return len(returns) == 1 and returns[0].type is torch._C.TensorType.get()
def register_op_impl(run_impl_check: Union[Callable[[OpOverload], bool], OpOverload]):
    """Return a decorator that registers a fake-tensor op implementation.

    ``run_impl_check`` selects how the implementation is looked up:

    * an ``OpOverload``: stored in ``op_implementations_dict`` under that
      overload (registering the same overload twice trips an assertion);
    * a list or tuple: every element is registered, recursively, with the
      same implementation;
    * any other callable: appended, together with the implementation, to
      ``op_implementations_checks`` as a ``(predicate, impl)`` pair.

    The decorated function is returned unchanged, so decorators can stack.
    """

    def register(op_impl):
        target = run_impl_check
        if isinstance(target, (list, tuple)):
            for single_op in target:
                register_op_impl(single_op)(op_impl)
            return op_impl
        if isinstance(target, OpOverload):
            assert (
                target not in op_implementations_dict
            ), f"duplicate registration: {target}"
            op_implementations_dict[target] = op_impl
            return op_impl
        assert callable(target)
        op_implementations_checks.append((target, op_impl))
        return op_impl

    return register
@register_op_impl(op_implementations_dict.__contains__)
def dispatch_to_op_implementations_dict(fake_mode, func, *args, **kwargs):
    """Run the implementation registered for exactly ``func``.

    This is the first predicate appended to ``op_implementations_checks`` in
    this file; it matches any overload present in ``op_implementations_dict``.
    """
    impl = op_implementations_dict[func]
    return impl(fake_mode, func, *args, **kwargs)
@register_op_impl(_is_tensor_constructor)
@register_op_impl([*_like_tensor_constructors])
def constructors(fake_mode, func, *args, **kwargs):
    """Fake impl for tensor factories and ``*_like`` / ``new_*`` constructors.

    The op runs on the meta device and the result is wrapped in a FakeTensor
    whose device is the ``device`` argument when given (non-None), otherwise
    the tensor argument's device for ops in ``_like_tensor_constructors``,
    otherwise CPU.

    Raises:
        UnsupportedOperatorException: if ``names`` is passed as a keyword.
    """
    assert func not in _non_kwarg_device_constructors
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    if "names" in kwargs:
        raise UnsupportedOperatorException(
            "torch.compile doesn't support named tensors"
        )

    if func in _like_tensor_constructors:
        # TODO: file issue
        template = new_kwargs.pop("input")
        fallback_device = template.device
        positional = (template,)
    else:
        # cpu is default device if none is specified
        fallback_device = torch.device("cpu")
        positional = ()

    requested_device = new_kwargs.pop("device", None)
    out_device = fallback_device if requested_device is None else requested_device
    new_kwargs["device"] = torch.device("meta")

    # _like constructors have fake tensor inputs (maybe this causes the non-like
    # to fail? hmmm)
    with in_kernel_invocation_manager(fake_mode):
        r = func(*positional, **new_kwargs)
    return FakeTensor(fake_mode, r, out_device)
@register_op_impl(aten.to.prim_Device)
@register_op_impl(aten.to.device)
def non_kwarg_to(fake_mode, func, *args, **kwargs):
    """Fake impl for the ``to`` overloads that take the target device as an argument.

    Runs the op on meta and converts the result to a fake tensor on the
    requested device, or on the input's device when the requested one is falsy.
    """
    _, new_kwargs = normalize_function(
        func, args, kwargs, normalize_to_only_use_kwargs=True
    )
    requested_device = new_kwargs["device"]
    inp = new_kwargs.pop("input")
    out_device = requested_device if requested_device else inp.device
    new_kwargs["device"] = torch.device("meta")

    with in_kernel_invocation_manager(fake_mode):
        r = func(inp, **new_kwargs)
    # TODO: I think this does the wrong thing if r is inp
    return fake_mode.fake_tensor_converter.from_meta_and_device(
        fake_mode, r, out_device
    )
def stride_incorrect_op(op):
    """Return True for aten/prims ops with "fft" in their name, except ``_fft_c2c``.

    Their meta implementations are known to report incorrect strides.
    """
    if op.namespace in ("aten", "prims") and op is not aten._fft_c2c.default:
        return "fft" in op.name()
    return False
# These operators have meta implementations with incorrect strides
@register_op_impl(stride_incorrect_op)
def wordaround_stride_incorrect_op(fake_mode, func, *args, **kwargs):
    """Work around meta implementations that report incorrect strides.

    When ``fake_mode.allow_fallback_kernels`` is set and no argument is
    symbolic, the op is run eagerly through ``run_fallback_kernel`` so the
    real output strides are observed.

    Raises:
        UnsupportedOperatorException: when fallback kernels are disabled or
            any (possibly nested) argument is symbolic.
    """

    def is_symbolic(x):
        if isinstance(x, FakeTensor):
            return x._has_symbolic_sizes_strides
        if isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool)):
            return True
        return False

    # For static shapes, we can fall back to eager for the real strides
    if fake_mode.allow_fallback_kernels:
        # Flatten first so symbolic values nested inside list arguments
        # (e.g. a list of SymInt sizes) are detected too; checking only the
        # top-level arguments would miss them and hand symbolic values to the
        # eager fallback kernel.
        flat_args, args_spec = pytree.tree_flatten((args, kwargs))
        if not any(is_symbolic(x) for x in flat_args):
            return run_fallback_kernel(fake_mode, func, flat_args, args_spec, None)

    raise UnsupportedOperatorException(func)
# Don't use the default device handling here:
# the device of `the_template` is ignored by this op.
@register_op_impl(aten.resize_as_.default)
def resize_as_(fake_mode, func, *args, **kwargs):
    """Run ``resize_as_`` directly inside the kernel-invocation manager."""
    with in_kernel_invocation_manager(fake_mode):
        result = func(*args, **kwargs)
    return result
@register_op_impl(aten._sparse_coo_tensor_with_dims_and_tensors.default)
def _sparse_coo_tensor_with_dims_and_tensors(fake_mode, func, *args, **kwargs):
    """Handle the sparse COO factory exactly like the generic ``constructors`` impl."""
    # TODO: remove me
    result = constructors(fake_mode, func, *args, **kwargs)
    return result
# index.Tensor is data-dependent only in some conditions, so it (together with
# nonzero and repeat_interleave.Tensor, which have their own impls) is excluded.
@register_op_impl(
    lambda func: (
        torch.Tag.dynamic_output_shape in func.tags
        and func
        not in [aten.index.Tensor, aten.nonzero.default, aten.repeat_interleave.Tensor]
    )
)
def dyn_shape(fake_mode, func, *args, **kwargs):
    """Reject ops tagged ``dynamic_output_shape`` that lack a dedicated fake impl.

    Raises:
        DynamicOutputShapeException: always.
    """
    raise DynamicOutputShapeException(func)
@register_op_impl(aten.repeat_interleave.Tensor)
def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None):
    """Fake impl of ``repeat_interleave.Tensor``.

    A known ``output_size`` is used directly. Otherwise the length is a fresh
    unbacked, size-constrained SymInt, which requires a shape env that allows
    dynamic output shapes.

    Raises:
        DynamicOutputShapeException: if ``output_size`` is None and the shape
            env is missing or disallows dynamic output shapes.
    """
    if output_size is not None:
        # TODO: consider a memo
        return repeats.new_empty(output_size)

    shape_env = fake_mode.shape_env
    if shape_env is None or not shape_env.allow_dynamic_output_shape_ops:
        raise DynamicOutputShapeException(func)

    # Avoid importing sympy at a module level
    from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size

    unbacked_size = shape_env.create_unbacked_symint()
    _constrain_range_for_size(unbacked_size)
    # TODO: consider a memo
    return repeats.new_empty(unbacked_size)
@register_op_impl(torch.ops.aten._local_scalar_dense.default)
def local_scalar_dense(fake_mode, func, arg):
    """Fake impl of ``_local_scalar_dense`` (``.item()``).

    Returns a fresh unbacked SymFloat, SymInt or SymBool matching the
    floating, integer or boolean dtype of ``arg``.

    Raises:
        DataDependentOutputException: without a shape env allowing scalar outputs.
        NotImplementedError: for any other dtype.
    """
    shape_env = fake_mode.shape_env
    if shape_env is None or not shape_env.allow_scalar_outputs:
        # Without symints/symfloats, cannot handle this
        raise DataDependentOutputException(func)

    dtype = arg.dtype
    if is_float_dtype(dtype):
        return shape_env.create_unbacked_symfloat()
    if is_integer_dtype(dtype):
        return shape_env.create_unbacked_symint()
    if is_boolean_dtype(dtype):
        return shape_env.create_unbacked_symbool()
    raise NotImplementedError(f"local_scalar_dense/item NYI for {arg.dtype}")
@register_op_impl(torch.ops.aten.nonzero.default)
def nonzero(fake_mode, func, arg):
    """Fake impl of ``nonzero``: an int64 tensor of shape ``(nnz, arg.dim())``.

    ``nnz`` is an unbacked SymInt memoized on ``arg`` together with its version
    counter, so repeated calls on an unmodified tensor share the same size.

    Raises:
        DynamicOutputShapeException: without a shape env allowing dynamic output shapes.
    """
    shape_env = fake_mode.shape_env
    if shape_env is None or not shape_env.allow_dynamic_output_shape_ops:
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    if arg.nonzero_memo is None:
        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import (
            _constrain_range_for_size,
            has_free_symbols,
        )

        nnz = shape_env.create_unbacked_symint()
        # This is unsound, but it works well in practice
        # See https://docs.google.com/document/d/1lFRYAJo5nrfxRhwIzGnfi2pbLpU6T4ytSRSuLJ5qebI/edit#
        # TODO: Add a config knob to turn off this unsound behavior
        #
        # NB: If numel < 2, the bounds here might be COMPLETELY
        # disjoint with what can actually occur. But this is fine:
        # remember, the hypothesis is that if your later code works
        # with N >= 2, it will work with N = 1 and N = 0.
        upper_bound = sys.maxsize - 1
        numel = arg.numel()
        # Only tighten the bound for a static numel above 2: numel < 2 would
        # give an empty range, and exactly 2 would specialize the unbacked
        # SymInt to 2, which is also likely to be buggy.
        if not has_free_symbols(numel) and numel > 2:
            upper_bound = int(numel)
        _constrain_range_for_size(nnz, max=upper_bound)
        arg._nonzero_memo = nnz
        arg._nonzero_memo_vc = arg._version

    return arg.new_empty((arg.nonzero_memo, arg.dim()), dtype=torch.int64)
@register_op_impl(torch.ops.aten.masked_select.default)
def masked_select(fake_mode, func, self, mask):
    """Fake impl of ``masked_select``: a 1-D tensor with an unbacked length.

    The length ``nnz`` is a fresh unbacked SymInt, constrained to be a valid
    size and, when statically known, bounded by the number of elements of the
    broadcast of ``self`` and ``mask``.

    Raises:
        DynamicOutputShapeException: without a shape env allowing dynamic output shapes.
    """
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    nnz = fake_mode.shape_env.create_unbacked_symint()

    # see nonzero for commentary
    maxval = sys.maxsize - 1

    # Avoid importing sympy at a module level
    from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size

    # masked_select broadcasts `self` against `mask`, so the number of selected
    # elements is bounded by the numel of the *broadcast* shape, not by
    # self.numel() (e.g. self of shape (3, 1) with a (3, 4) mask can select 12).
    # Only tighten the bound when both shapes are fully static; otherwise keep
    # the conservative default.
    if all(isinstance(d, int) for d in itertools.chain(self.shape, mask.shape)):
        num_elements = math.prod(torch.broadcast_shapes(self.shape, mask.shape))
        if num_elements > 2:
            maxval = num_elements

    _constrain_range_for_size(nnz, max=maxval)

    return self.new_empty((nnz,))
# NB: this must be ordered after local_scalar_dense
@register_op_impl(lambda func: torch.Tag.data_dependent_output in func.tags)
def data_dep(fake_mode, func, *args, **kwargs):
    """Reject any op tagged ``data_dependent_output`` that reaches this check.

    Raises:
        DataDependentOutputException: always.
    """
    raise DataDependentOutputException(func)
# Bool Indices get Expanded as Masks
# See: IndexingUtils.h:expandTensors
def check_no_bool_index_tensors(func, self, indices):
    """Raise if any index tensor is a bool/uint8 mask.

    Mask indices make the output shape data-dependent. ``None`` entries in
    ``indices`` are skipped and ``self`` is not inspected.

    Raises:
        DynamicOutputShapeException: if any index has dtype bool or uint8.
    """
    mask_dtypes = (torch.bool, torch.uint8)
    if any(idx is not None and idx.dtype in mask_dtypes for idx in indices):
        raise DynamicOutputShapeException(func)
def run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs):
    """Run ``func`` as a meta kernel and place the result on ``input``'s device.

    In-place ops that return their input (e.g. ``copy_``) get that input back
    unchanged. Otherwise the meta result is wrapped in a new FakeTensor on the
    input's device. On backends that only support contiguous tensors (see
    ``is_noncontiguous_supported``), the result is first re-allocated with a
    contiguous layout.
    """
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs["input"]
    out_device = inp.device
    with in_kernel_invocation_manager(fake_mode):
        out = func(*args, **kwargs)

    # Check aliasing *before* any re-allocation: an in-place result must keep
    # aliasing its input, even on contiguous-only backends.
    if out is inp:
        return out  # copy_

    if not is_noncontiguous_supported(out_device):
        out = out.new_empty(out.shape)
    return FakeTensor(fake_mode, out, out_device)
_is_builtin_namespaces = ordered_set("aten", "prims", "prim")


def is_builtin(op):
    """Return True if ``op``'s namespace is one of "aten", "prims" or "prim"."""
    namespace = op.namespace
    return namespace in _is_builtin_namespaces
def has_meta(func):
    """Return True if ``func`` has a computed kernel for the Meta dispatch key."""
    qualified_name = func.name()
    return torch._C._dispatch_has_computed_kernel_for_dispatch_key(
        qualified_name, "Meta"
    )
@register_op_impl(
    lambda func: is_builtin(func) and "foreach" in func.name() and has_meta(func)
)
def foreach_run_and_map_input_device(fake_mode, func, *args, **kwargs):
    """Fake impl for builtin ``foreach`` ops that have a Meta kernel.

    Output ``i`` is placed on the common device of the ``i``-th element of
    every tensor-list argument. Falsy meta results (e.g. ``None``) are returned
    as-is. Returns ``NotImplemented`` if the meta kernel raises
    ``NotImplementedError``.
    """
    # Non-empty list/tuple arguments whose first element is a Tensor.
    tensor_lists = [
        arg
        for arg in itertools.chain(args, kwargs.values())
        if isinstance(arg, (list, tuple))
        and len(arg)
        and isinstance(arg[0], torch.Tensor)
    ]

    try:
        with in_kernel_invocation_manager(fake_mode):
            out_meta = func(*args, **kwargs)
    except NotImplementedError:
        return NotImplemented

    if not out_meta:
        return out_meta

    assert tensor_lists

    def to_fake(i, meta_t):
        device, _ = FakeTensor._find_common_device(func, [tl[i] for tl in tensor_lists])
        return fake_mode.fake_tensor_converter.from_meta_and_device(
            fake_mode, meta_t, device
        )

    return [to_fake(i, meta_t) for i, meta_t in enumerate(out_meta)]
# Don't use the default device handling here: this op can receive
# non-zero-sized CPU index tensors alongside a CUDA `self`.
@register_op_impl(aten.index.Tensor)
def index_tensor(fake_mode, func, *args, **kwargs):
    """Fake impl of ``index.Tensor`` via ``meta_index_Tensor``; result moved to ``input``'s device."""
    from torch._meta_registrations import meta_index_Tensor

    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    target_device = new_kwargs["input"].device

    # ensure nonzero call goes to fake tensor
    with fake_mode:
        result = meta_index_Tensor(*args, **kwargs)
    return result.to(target_device)
# Can take mixed meta/non-meta arguments; the meta registration
# will roughly do the right thing even when given real devices
@register_op_impl(aten._embedding_bag.default)
def embedding_bag(fake_mode, func, *args, **kwargs):
    """Fake impl of ``_embedding_bag``: run its meta registration under ``fake_mode``."""
    from torch._meta_registrations import meta_embedding_bag

    with fake_mode:
        result = meta_embedding_bag(*args, **kwargs)
    return result
# These ops accept tensors on multiple devices, so don't use the default
# device handling; the output follows `input`'s device instead.
@register_op_impl(aten._unsafe_index_put.default)
@register_op_impl(aten.copy.default)
@register_op_impl(aten.copy_.default)
@register_op_impl(aten.slice_scatter.default)
def multi_device_op_default(fake_mode, func, *args, **kwargs):
    """Delegate to ``run_and_return_new_tensor_of_input_device``."""
    result = run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)
    return result
# Like multi_device_op_default, but for the `out=` overloads: the op writes
# its result into `out`, and `out` is the tensor it returns.
@register_op_impl(aten.copy.out)
@register_op_impl(aten.slice_scatter.out)
def multi_device_op_out(fake_mode, func, *args, **kwargs):
    """Run an ``out=`` overload as a meta kernel and return the ``out`` tensor.

    ``out=`` overloads return an alias of their ``out`` argument, not of
    ``self``/``input``, so the (fake) ``out`` tensor is returned.
    """
    with in_kernel_invocation_manager(fake_mode):
        func(*args, **kwargs)

    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    return new_kwargs["out"]
@register_op_impl(aten.index_put.default)
@register_op_impl(aten.index_put_.default)
def index_put_impl(fake_mode, func, *args, **kwargs):
    """Fake impl of ``index_put`` / ``index_put_``.

    ``values`` must live on ``self``'s fake device unless it is a 0-d tensor.
    The in-place overload returns ``input`` itself; the functional overload
    returns a new tensor on ``input``'s device.
    """
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs["input"]
    values = new_kwargs["values"]
    self_device = inp.fake_device
    torch._check(
        self_device == values.fake_device or (values.ndim == 0 and values.numel() == 1),
        lambda: f"Mismatching {func} device between self ({self_device}) and values ({values.device})",
    )

    out = run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)
    return inp if func is aten.index_put_.default else out
@register_op_impl(aten._nested_tensor_from_tensor_list.default)
@register_op_impl(aten._nested_tensor_from_tensor_list.out)
def nested_tensors_unsupported(fake_mode, func, *args, **kwargs):
    """Reject construction of strided NestedTensors.

    Raises:
        UnsupportedOperatorException: always.
    """
    message = "torch.compile does not support strided NestedTensor"
    raise UnsupportedOperatorException(message)
@register_op_impl(
    [
        x
        for x in _device_not_kwarg_ops
        if x
        not in (
            # these are already registered elsewhere
            aten.to.device,
            aten.to.prim_Device,
            aten._nested_tensor_from_tensor_list.default,
            aten._nested_tensor_from_tensor_list.out,
        )
    ]
)
def nyi(fake_mode, func, *args, **kwargs):
    """Placeholder for ops in ``_device_not_kwarg_ops`` with no fake impl yet.

    Raises:
        AssertionError: ``"NYI: <func>"`` whenever ``func`` is in
            ``_device_not_kwarg_ops`` (true for every op this is registered for).
    """
    # Raise explicitly rather than via `assert`: under `python -O` asserts are
    # stripped and this impl would silently return None as the op's result.
    if func in _device_not_kwarg_ops:
        raise AssertionError(f"NYI: {func}")
@register_op_impl([aten.convolution.default, aten.convolution_backward.default])
def conv(fake_mode, func, *args, **kwargs):
    """Fake impl for ``convolution`` and ``convolution_backward``.

    Runs the op as a meta kernel and wraps each output as a FakeTensor on
    ``input``'s fake device. When the batch size has a hint (and the 3-dim
    weight special case does not apply), outputs are converted to the memory
    format reported by ``torch._C._conv_determine_backend_memory_format`` for
    the selected conv backend. The third backward output is never converted.
    """
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    device = kwargs["input"].fake_device
    # need to re-enable mode so the tensors report fake device
    with fake_mode:
        # if the input is unsqueezed is done in Convolution.cpp we get segfault
        k = kwargs["weight"].ndim
        batch = kwargs["input"].shape[0]

        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import has_hint

        if not has_hint(batch):
            # TODO: We can make this a little more faithful with best effort
            # channels last detection (but only if it's statically obvious!)
            mem_fmt = None
        elif k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu:
            # 3-dim weight (presumably a 1-D convolution -- confirm) on a
            # non-mkldnn, non-xpu input: leave the output layout alone.
            mem_fmt = None
        else:
            if func is aten.convolution.default:
                conv_backend = torch._C._select_conv_backend(**kwargs)
            else:
                # convolution_backward has no bias tensor, only bias_sizes,
                # so the backend-selection arguments are spelled out here.
                conv_backend = torch._C._select_conv_backend(
                    kwargs["input"],
                    kwargs["weight"],
                    bias=None,
                    stride=kwargs["stride"],
                    padding=kwargs["padding"],
                    dilation=kwargs["dilation"],
                    transposed=kwargs["transposed"],
                    output_padding=kwargs["output_padding"],
                    groups=kwargs["groups"],
                    bias_sizes=kwargs["bias_sizes"],
                )
            mem_fmt = torch._C._conv_determine_backend_memory_format(
                kwargs["input"], kwargs["weight"], conv_backend
            )

    def convert(t, mem_fmt):
        # Pass None outputs through; otherwise optionally change the memory
        # format and wrap as a FakeTensor on `device`.
        if t is None:
            return t
        if mem_fmt is not None:
            t = t.to(memory_format=mem_fmt)
        return FakeTensor(fake_mode, t, device)

    with in_kernel_invocation_manager(fake_mode):
        out = func(**kwargs)

        if func is aten.convolution.default:
            return convert(out, mem_fmt)
        else:
            return (
                convert(out[0], mem_fmt),
                convert(out[1], mem_fmt),
                convert(out[2], None),
            )
@register_op_impl(aten._scaled_dot_product_flash_attention.default)
def meta__scaled_dot_product_flash(fake_mode, func, *args, **kwargs):
    """Fake impl of ``_scaled_dot_product_flash_attention``.

    All tensor outputs are built on meta and wrapped as FakeTensors on
    ``query.device``. Returns the 9-tuple
    ``(attention, logsumexp, None, None, max_seqlen_batch_q,
    max_seqlen_batch_k, seed, offset, debug_mask)``.
    """
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    query = kwargs["query"]
    key = kwargs["key"]
    return_debug_mask = kwargs["return_debug_mask"]
    # unused: value, dropout_p, is_causal, scale

    def convert_tensor(t, device):
        # Wrap a meta tensor as a FakeTensor reporting `device`.
        return FakeTensor(fake_mode, t, device)

    # Dim roles as named below: query read as (batch, heads, seq_q, head_dim)
    # and key's dim 2 as seq_k. NOTE(review): inferred from the names; confirm.
    batch_size = query.size(0)
    num_heads = query.size(1)
    max_seqlen_batch_q = query.size(2)
    head_dim = query.size(3)
    max_seqlen_batch_k = key.size(2)

    query_t = query.transpose(1, 2)
    # empty_like already returns a fake tensor so we don't need to convert it
    attention = torch.empty_like(query_t).transpose(1, 2)
    # float32 logsumexp of shape (batch_size, num_heads, max_seqlen_batch_q)
    logsumexp = convert_tensor(
        torch.empty(
            (batch_size, num_heads, max_seqlen_batch_q),
            dtype=torch.float,
            device="meta",
        ),
        device=query.device,
    )

    if return_debug_mask:
        blocksize_c = 128 if head_dim > 64 else 256
        # NOTE(review): this divides max_seqlen_batch_q (not _k) and does not
        # multiply back by blocksize_c; the value is only overwritten below
        # when max_seqlen_batch_k <= 256. Confirm the debug-mask width against
        # the CUDA flash-attention kernel.
        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
        if max_seqlen_batch_k <= 128:
            max_seqlen_k = 128
        elif max_seqlen_batch_k <= 256:
            max_seqlen_k = 256
        debug_mask = convert_tensor(
            torch.empty(
                (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
                dtype=query.dtype,
                device="meta",
            ),
            device=query.device,
        )
    else:
        # Zero-element placeholder when no debug mask is requested.
        debug_mask = convert_tensor(
            torch.empty(0, dtype=query.dtype, device="meta"),
            query.device,
        )

    # Note [Seed and Offset]: device for seed and offset below depends on whether we are
    # capturing or not, but at the time of tracing we don't know if we
    # are going to use cudagraphs or not, so we return meta tensors here
    # it's possible we'll need to have some special handling in inductor for sdpa

    return (
        attention,
        logsumexp,
        None,
        None,
        max_seqlen_batch_q,
        max_seqlen_batch_k,
        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
        debug_mask,
    )
@register_op_impl(aten._scaled_dot_product_efficient_attention.default)
def meta__scaled_dot_product_efficient(fake_mode, func, *args, **kwargs):
    """Fake impl of ``_scaled_dot_product_efficient_attention``.

    Returns ``(output, logsumexp, seed, offset)``, all FakeTensors on
    ``query.device``:

    * ``output``: query's dtype, shape ``(B, num_heads, M, Kv)``;
    * ``logsumexp``: float32 ``(B, num_heads, ceil(M / 32) * 32)`` when
      ``compute_log_sumexp`` is set, else ``(B, num_heads, 0)``;
    * ``seed`` / ``offset``: 0-d int64 (see Note [Seed and Offset]).

    The sizes are read from query/value after swapping their dims 1 and 2:
    ``B = dim 0``, ``M = dim 1``, ``num_heads = dim -2``, and ``Kv`` is
    value's last dim.
    """
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    query = kwargs["query"]
    value = kwargs["value"]
    compute_log_sumexp = kwargs["compute_log_sumexp"]
    # unused: key, attn_bias, dropout_p, is_causal, scale

    def convert_tensor(t, device):
        return FakeTensor(fake_mode, t, device)

    query = query.transpose(1, 2)
    value = value.transpose(1, 2)

    B = query.size(0)
    M = query.size(1)
    num_heads = query.size(-2)
    Kv = value.size(-1)

    res = convert_tensor(
        torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"),
        query.device,
    )

    # The logsumexp sequence dim is padded up to a multiple of 32.
    logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
    logsum_exp = convert_tensor(
        torch.empty(
            (B, num_heads, logsumexp_dim),
            dtype=torch.float,
            device="meta",
        ),
        query.device,
    )

    res = res.transpose(1, 2)

    # See Note [Seed and Offset]:
    seed = convert_tensor(
        torch.empty((), dtype=torch.long, device="meta"), query.device
    )
    offset = convert_tensor(
        torch.empty((), dtype=torch.long, device="meta"), query.device
    )

    return res, logsum_exp, seed, offset
@register_op_impl(aten._flash_attention_forward.default)
def meta__flash_attention_forward(fake_mode, func, *args, **kwargs):
    """Fake impl of ``_flash_attention_forward``.

    Tensor outputs are built on meta and wrapped as FakeTensors on
    ``query.device``. Returns ``(attention, logsumexp, seed, offset,
    debug_mask)``.
    """
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    query = kwargs["query"]
    key = kwargs["key"]
    cum_seq_q = kwargs["cum_seq_q"]
    cum_seq_k = kwargs["cum_seq_k"]
    max_q = kwargs["max_q"]
    max_k = kwargs["max_k"]
    return_debug_mask = kwargs["return_debug_mask"]
    # unused: value, dropout_p, is_causal, scale

    def convert_tensor(t, device):
        # Wrap a meta tensor as a FakeTensor reporting `device`.
        return FakeTensor(fake_mode, t, device)

    # NB: there are two underlying paths:
    # 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim)
    # 2. varseqlen path; expect 3D inputs of shape (total, num_heads, head_dim) where total
    #    includes all batch item sequences. cum_seq_q / cum_seq_k contain offsets into total
    batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1
    max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q
    max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k
    num_heads = query.size(-2)
    head_dim = query.size(-1)

    # Cuda Path
    # note: empty_like already returns a fake tensor, we don't need to wrap it
    attention = torch.empty_like(query)
    # float32 logsumexp of shape (batch_size, num_heads, max_seqlen_batch_q)
    logsumexp = convert_tensor(
        torch.empty(
            (batch_size, num_heads, max_seqlen_batch_q),
            dtype=torch.float,
            device="meta",
        ),
        device=query.device,
    )

    if return_debug_mask:
        blocksize_c = 128 if head_dim > 64 else 256
        # NOTE(review): this divides max_seqlen_batch_q (not _k) and does not
        # multiply back by blocksize_c; the value is only overwritten below
        # when max_seqlen_batch_k <= 256. Confirm the debug-mask width against
        # the CUDA flash-attention kernel.
        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
        if max_seqlen_batch_k <= 128:
            max_seqlen_k = 128
        elif max_seqlen_batch_k <= 256:
            max_seqlen_k = 256
        debug_mask = convert_tensor(
            torch.empty(
                (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
                dtype=query.dtype,
                device="meta",
            ),
            query.device,
        )
    else:
        # Zero-element placeholder when no debug mask is requested.
        debug_mask = convert_tensor(
            torch.empty(0, dtype=query.dtype, device="meta"),
            query.device,
        )

    # See Note [Seed and Offset]:
    return (
        attention,
        logsumexp,
        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
        debug_mask,
    )
@register_op_impl(aten._efficient_attention_forward.default)
def meta__efficient_attention_forward(fake_mode, func, *args, **kwargs):
    """Fake impl of ``_efficient_attention_forward``.

    Returns ``(output, logsumexp, seed, offset, max_seqlen_q, max_seqlen_k)``.
    Tensor outputs are FakeTensors on ``query.device``:

    * ``output``: query's dtype, shape ``(B, M, num_heads, Kv)``;
    * ``logsumexp``: float32 ``(batch, num_heads, ceil(max_seqlen_q / 32) * 32)``
      when ``compute_log_sumexp`` is set (last dim 0 otherwise). ``batch`` is
      ``len(cu_seqlens_q) - 1`` when ``cu_seqlens_q`` is given, else ``B``;
    * ``seed`` / ``offset``: 0-d int64 (see Note [Seed and Offset]).

    ``max_seqlen_q`` defaults to query's dim 1 and must be provided when
    ``cu_seqlens_q`` is given; ``max_seqlen_k`` defaults to key's dim 1.
    """
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    query = kwargs["query"]
    key = kwargs["key"]
    value = kwargs["value"]
    cu_seqlens_q = kwargs["cu_seqlens_q"]
    max_seqlen_q = kwargs["max_seqlen_q"]
    max_seqlen_k = kwargs["max_seqlen_k"]
    compute_log_sumexp = kwargs["compute_log_sumexp"]
    # unused: bias, cu_seqlens_k, dropout_p, custom_mask_type, scale, causal_diagonal, seqlen_k

    def convert_tensor(t, device):
        return FakeTensor(fake_mode, t, device)

    B = query.size(0)
    M = query.size(1)
    N = key.size(1)
    num_heads = query.size(-2)
    Kv = value.size(-1)

    res = convert_tensor(
        torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"),
        query.device,
    )

    logsumexp_batch_dim = cu_seqlens_q.size(0) - 1 if (cu_seqlens_q is not None) else B
    actual_max_seqlen_q = M
    if cu_seqlens_q is not None:
        assert max_seqlen_q is not None
        actual_max_seqlen_q = max_seqlen_q
    actual_max_seqlen_k = max_seqlen_k if max_seqlen_k is not None else N
    # The logsumexp sequence dim is padded up to a multiple of 32.
    logsumexp_dim = (
        math.ceil(actual_max_seqlen_q / 32) * 32 if compute_log_sumexp else 0
    )
    logsum_exp = convert_tensor(
        torch.empty(
            (logsumexp_batch_dim, num_heads, logsumexp_dim),
            dtype=torch.float,
            device="meta",
        ),
        query.device,
    )

    # See Note [Seed and Offset]:
    seed = convert_tensor(
        torch.empty((), dtype=torch.long, device="meta"), query.device
    )
    offset = convert_tensor(
        torch.empty((), dtype=torch.long, device="meta"), query.device
    )

    return res, logsum_exp, seed, offset, actual_max_seqlen_q, actual_max_seqlen_k
# OpOverload -> fast implementation, filled by register_fast_op_impl.
FAST_OP_IMPLEMENTATIONS = {}


def register_fast_op_impl(func: OpOverload):
    """Return a decorator recording ``op_impl`` as the fast impl for ``func``.

    Unlike ``register_op_impl``, there is no predicate scan: entries are keyed
    directly by overload and run BEFORE decompositions. A later registration
    for the same ``func`` replaces the earlier one. The decorated function is
    returned unchanged.
    """

    def register(op_impl):
        FAST_OP_IMPLEMENTATIONS[func] = op_impl
        return op_impl

    return register
# infer_size_impl in ExpandUtils
def infer_size(a, b):
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
dimsA = len(a)
dimsB = len(b)
ndim = max(dimsA, dimsB)
expandedSizes = [0] * ndim
for i in range(ndim - 1, -1, -1):
offset = ndim - 1 - i
dimA = dimsA - 1 - offset
dimB = dimsB - 1 - offset
sizeA = a[dimA] if dimA >= 0 else 1
sizeB = b[dimB] if dimB >= 0 else 1
# NB: It is very important to test for broadcasting, before testing
# sizeA == sizeB. This is because the broadcasting tests are likely
# to be statically known (in particular, if sizeA/sizeB is unbacked
# but size-like, we will unsoundly assume they never equal 1), but
# the sizeA == sizeB test may not be statically known. However, once
# we have established that no broadcasting is happening, the
# sizeA == sizeB is now expect_true and we can defer it as a runtime
# assert (this works because Python will return the terminal
# expression of an or statement as-is, without bool()'ing it; if this
# were not the case, we'd need to write this using torch.sym_or() or
# something like that).
torch._check(
guard_size_oblivious(sizeA == 1)
or guard_size_oblivious(sizeB == 1)
or sizeA == sizeB,
lambda: f"The size of tensor a ({sizeA}) "
f"must match the size of tensor b ({sizeB}) "
f"at non-singleton dimension {i})",
)
expandedSizes[i] = sizeB if guard_size_oblivious(sizeA == 1) else sizeA
return tuple(expandedSizes)
def make_fast_binary_impl(
    slow_ref, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
):
    """Create a fast fake-tensor implementation of a binary elementwise op.

    The returned ``fast_binary_impl(mode, *args, **kwargs)`` computes the
    broadcast output shape, device and dtype directly. It returns an empty
    FakeTensor when the output layout is obvious (contiguous or
    channels-last); otherwise it falls back to running ``slow_ref`` under
    ``mode``.

    Args:
        slow_ref: reference implementation used for the slow path.
        type_promotion_kind: the op's elementwise type-promotion rule.
            With a non-DEFAULT kind (e.g. INT_TO_FLOAT for true division),
            the result dtype can differ from the inputs' shared dtype, so
            it is always computed with ``elementwise_dtypes``.

    Only positional ``args`` are inspected on the fast path; ``kwargs`` are
    forwarded to ``slow_ref`` only.
    """

    def fast_binary_impl(mode, *args, **kwargs):
        def slow(msg):
            count_label(f"slow {msg}")
            with mode:
                return slow_ref(*args, **kwargs)

        count_label("attempt fast")

        # Fast path (based off of TensorIterator fast path).
        # Unfortunately, there is no way to easily deduplicate
        # this with either the TensorIterator C++ implementation
        # (which we don't want to SymIntify, and also the algorithm
        # here is slightly different from TensorIterator to allow
        # for broadcasting), nor the PrimTorch implementation
        # (which does not actually implement a fast path.)

        operands = args

        # compute_shape
        final_shape = None
        for op in operands:
            # Non-tensor operands broadcast like 0-d tensors.
            shape = op.shape if isinstance(op, torch.Tensor) else ()
            if final_shape is None:
                final_shape = shape
            # TODO: Minor optimization: track if the shapes
            # were equal so you can skip the equality check
            # below if unnecessary
            final_shape = infer_size(final_shape, shape)
        assert final_shape is not None

        # Do some extra safety checks to see if the output
        # stride is obvious: some tensor operand must already
        # have exactly the output shape.
        for op in operands:
            if (
                isinstance(op, torch.Tensor)
                and len(op.shape) == len(final_shape)
                and op.shape == final_shape
            ):
                break
        else:
            return slow("both tensors nontrivially broadcast")

        # compute_types
        cpu = torch.device("cpu")
        common_device = cpu
        common_dtype = None
        has_different_input_dtypes = False
        for op in operands:
            if not isinstance(op, torch.Tensor):
                # Use elementwise_dtypes for the tricky case
                has_different_input_dtypes = True
                continue
            if common_device == cpu and not op.device.type == "cpu":
                common_device = op.device
            # Slightly simplified here as target_dtype cannot vary
            if common_dtype is None:
                common_dtype = op.dtype
            elif common_dtype != op.dtype:
                has_different_input_dtypes = True

        if (
            has_different_input_dtypes
            or type_promotion_kind != ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
        ):
            # compute promotion; a non-DEFAULT kind can change the result
            # dtype even when every input shares one dtype (e.g. int / int).
            # TODO: we don't need the compute type
            _, common_dtype = elementwise_dtypes(
                *operands, type_promotion_kind=type_promotion_kind
            )

        # check all tensors on same device
        # cpu scalars are assumed allow
        current_cpu_scalars_on_non_cpu = 0
        max_cpu_scalars_on_non_cpu = 1  # hard coded atm
        for op in operands:
            if not isinstance(op, torch.Tensor):
                continue
            if common_device != cpu and op.dim() == 0 and op.device == cpu:
                if current_cpu_scalars_on_non_cpu >= max_cpu_scalars_on_non_cpu:
                    return slow("error")
                current_cpu_scalars_on_non_cpu += 1
            elif op.device != common_device:
                return slow("error")

        # compute_fast_setup_type
        is_contiguous = True
        is_channels_last = True
        # TODO: is_non-overlapping_and_dense (not bound from Python
        # no inplace, no out, everything defined

        # Contiguous-only backends always get a contiguous output.
        if is_noncontiguous_supported(common_device):
            for op in operands:
                if not isinstance(op, torch.Tensor):
                    continue
                is_contiguous = is_contiguous and op.is_contiguous(
                    memory_format=torch.contiguous_format
                )
                is_channels_last = is_channels_last and op.is_contiguous(
                    memory_format=torch.channels_last
                )
        if is_contiguous:
            # do contiguous
            count_label("fast is_contiguous")
            return FakeTensor(
                mode,
                torch.empty(
                    final_shape,
                    dtype=common_dtype,
                    device="meta",
                    memory_format=torch.contiguous_format,
                ),
                device=common_device,
            )
        if is_channels_last:
            count_label("fast channels_last")
            # do channels last
            return FakeTensor(
                mode,
                torch.empty(
                    final_shape,
                    dtype=common_dtype,
                    device="meta",
                    memory_format=torch.channels_last,
                ),
                device=common_device,
            )

        return slow("no contiguity match")

    return fast_binary_impl


@functools.lru_cache(None)
def get_fast_op_impls():
    """Register the fast binary-op implementations (once) and return the table."""
    import torch._refs

    register_fast_op_impl(torch.ops.aten.add.Tensor)(
        make_fast_binary_impl(torch._refs.add)
    )
    register_fast_op_impl(torch.ops.aten.sub.Tensor)(
        make_fast_binary_impl(torch._refs.sub)
    )
    register_fast_op_impl(torch.ops.aten.mul.Tensor)(make_fast_binary_impl(torch._refs.mul))  # type: ignore[has-type]
    # aten.div.Tensor is true division: integer/bool inputs produce the
    # default floating dtype, so it needs INT_TO_FLOAT promotion.
    register_fast_op_impl(torch.ops.aten.div.Tensor)(
        make_fast_binary_impl(
            torch._refs.div,
            type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        )
    )
    return FAST_OP_IMPLEMENTATIONS
|