File size: 102,935 Bytes
d1ceb73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 
1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816 1817 1818 1819 1820 1821 
1822 1823 1824 1825 1826 1827 1828 1829 1830 1831 1832 1833 1834 1835 1836 1837 1838 1839 1840 1841 1842 1843 1844 1845 1846 1847 1848 1849 1850 1851 1852 1853 1854 1855 1856 1857 1858 1859 1860 1861 1862 1863 1864 1865 1866 1867 1868 1869 1870 1871 1872 1873 1874 1875 1876 1877 1878 1879 1880 1881 1882 1883 1884 1885 1886 1887 1888 1889 1890 1891 1892 1893 1894 1895 1896 1897 1898 1899 1900 1901 1902 1903 1904 1905 1906 1907 1908 1909 1910 1911 1912 1913 1914 1915 1916 1917 1918 1919 1920 1921 1922 1923 1924 1925 1926 1927 1928 1929 1930 1931 1932 1933 1934 1935 1936 1937 1938 1939 1940 1941 1942 1943 1944 1945 1946 1947 1948 1949 1950 1951 1952 1953 1954 1955 1956 1957 1958 1959 1960 1961 1962 1963 1964 1965 1966 1967 1968 1969 1970 1971 1972 1973 1974 1975 1976 1977 1978 1979 1980 1981 1982 1983 1984 1985 1986 1987 1988 1989 1990 1991 1992 1993 1994 1995 1996 1997 1998 1999 2000 2001 2002 2003 2004 2005 2006 2007 2008 2009 2010 2011 2012 2013 2014 2015 2016 2017 2018 2019 2020 2021 2022 2023 2024 2025 2026 2027 2028 2029 2030 2031 2032 2033 2034 2035 2036 2037 2038 2039 2040 2041 2042 2043 2044 2045 2046 2047 2048 2049 2050 2051 2052 2053 2054 2055 2056 2057 2058 2059 2060 2061 2062 2063 2064 2065 2066 2067 2068 2069 2070 2071 2072 2073 2074 2075 2076 2077 2078 2079 2080 2081 2082 2083 2084 2085 2086 2087 2088 2089 2090 2091 2092 2093 2094 2095 2096 2097 2098 2099 2100 2101 2102 2103 2104 2105 2106 2107 2108 2109 2110 2111 2112 2113 2114 2115 2116 2117 2118 2119 2120 2121 2122 2123 2124 2125 2126 2127 2128 2129 2130 2131 2132 2133 2134 2135 2136 2137 2138 2139 2140 2141 2142 2143 2144 2145 2146 2147 2148 2149 2150 2151 2152 2153 2154 2155 2156 2157 2158 2159 2160 2161 2162 2163 2164 2165 2166 2167 2168 2169 2170 2171 2172 2173 2174 2175 2176 2177 2178 2179 2180 2181 2182 2183 2184 2185 2186 2187 2188 2189 2190 2191 2192 2193 2194 2195 2196 2197 2198 2199 2200 2201 2202 2203 2204 2205 2206 2207 2208 2209 2210 2211 2212 2213 2214 2215 2216 2217 2218 2219 2220 2221 
2222 2223 2224 2225 2226 2227 2228 2229 2230 2231 2232 2233 2234 2235 2236 2237 2238 2239 2240 2241 2242 2243 2244 2245 2246 2247 2248 2249 2250 2251 2252 2253 2254 2255 2256 2257 2258 2259 2260 2261 2262 2263 2264 2265 2266 2267 2268 2269 2270 2271 2272 2273 2274 2275 2276 2277 2278 2279 2280 2281 2282 2283 2284 2285 2286 2287 2288 2289 2290 2291 2292 2293 2294 2295 2296 2297 2298 2299 2300 2301 2302 2303 2304 2305 2306 2307 2308 2309 2310 2311 2312 2313 2314 2315 2316 2317 2318 2319 2320 2321 2322 2323 2324 2325 2326 2327 2328 2329 2330 2331 2332 2333 2334 2335 2336 2337 2338 2339 2340 2341 2342 2343 2344 2345 2346 2347 2348 2349 2350 2351 2352 2353 2354 2355 2356 2357 2358 2359 2360 2361 2362 2363 2364 2365 2366 2367 2368 2369 2370 2371 2372 2373 2374 2375 2376 2377 2378 2379 2380 2381 2382 2383 2384 2385 2386 2387 2388 2389 2390 2391 2392 2393 2394 2395 2396 2397 2398 2399 2400 2401 2402 2403 2404 2405 2406 2407 2408 2409 2410 2411 2412 2413 2414 2415 2416 2417 2418 2419 2420 2421 2422 2423 2424 2425 2426 2427 2428 2429 2430 2431 2432 2433 2434 2435 2436 2437 2438 2439 |
# mypy: ignore-errors
import abc
import collections
import contextlib
import dataclasses
import enum
import functools
import inspect
import itertools
import logging
import math
import operator
import re
import sys
import types
from typing import Any, List, NamedTuple, Optional, Union
from torch.utils._sympy.value_ranges import ValueRanges
try:
import numpy as np
except ModuleNotFoundError:
np = None
import torch
from torch import SymInt
from torch._guards import GuardSource, TracingContext
from torch._higher_order_ops.torchbind import call_torchbind
from torch._ops import HigherOrderOperator
from torch._streambase import _EventBase, _StreamBase
from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
from torch._subclasses.meta_utils import is_sparse_any
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.symbolic_shapes import (
_constrain_range_for_size,
DimDynamic,
RelaxedUnspecConstraint,
StatefulSymbolicContext,
SubclassSymbolicContext,
SymbolicContext,
)
from torch.fx.immutable_collections import immutable_dict, immutable_list
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils.weak import TensorWeakRef
from .. import config, mutation_guard, replay_record, trace_rules
from ..device_interface import get_registered_device_interfaces
from ..exc import InternalTorchDynamoError, unimplemented
from ..guards import GuardBuilder, install_guard, make_dupe_guard
from ..side_effects import SideEffects
from ..source import (
AttrSource,
CallMethodItemSource,
ConstantSource,
ConstDictKeySource,
ConvertIntSource,
FloatTensorSource,
GetItemSource,
GradSource,
is_cell_contents,
is_constant_source,
is_from_defaults,
is_from_optimizer_source,
LocalSource,
NumpyTensorSource,
OptimizerSource,
RandomValueSource,
Source,
TupleIteratorGetItemSource,
)
from ..trace_rules import (
is_callable_allowed,
is_numpy,
is_numpy_dtype,
is_numpy_type_info,
)
from ..utils import (
build_checkpoint_variable,
clone_input,
common_constant_types,
get_fake_value,
get_locals_to_steal,
get_static_address_type,
is_function_or_wrapper,
is_namedtuple,
is_typing,
is_utils_checkpoint,
istype,
odict_values,
proxy_args_kwargs,
set_example_value,
tensor_always_has_static_shape,
tuple_iterator,
tuple_iterator_getitem,
tuple_iterator_len,
unwrap_with_attr_name_if_wrapper,
wrap_fake_exception,
)
from .base import MutableLocal, typestr, VariableTracker, VariableTrackerMeta
from .constant import ConstantVariable, EnumVariable
from .ctx_manager import (
AutocastModeVariable,
EventVariable,
NullContextVariable,
PreserveVersionContextVariable,
StreamContextVariable,
StreamVariable,
)
from .dicts import (
ConstDictVariable,
DataClassVariable,
DefaultDictVariable,
HFPretrainedConfigVariable,
PythonSysModulesVariable,
SetVariable,
)
from .distributed import (
DeviceMeshVariable,
PlacementClassVariable,
PlacementVariable,
ProcessGroupVariable,
WorldMetaClassVariable,
)
from .functions import (
CollectiveFunctionRewriteVariable,
FunctoolsPartialVariable,
TritonKernelVariable,
UserMethodVariable,
)
from .higher_order_ops import TorchHigherOrderOperatorVariable
from .iter import ItertoolsVariable
from .lazy import LazyVariableTracker
from .lists import (
BaseListVariable,
ListVariable,
NamedTupleVariable,
RangeVariable,
RestrictedListSubclassVariable,
SizeVariable,
SliceVariable,
TupleIteratorVariable,
TupleVariable,
)
from .misc import (
AutogradFunctionContextVariable,
AutogradFunctionVariable,
ComptimeVariable,
DebuggingVariable,
DelayGraphBreakVariable,
GetAttrVariable,
GetSetDescriptorVariable,
InspectSignatureVariable,
LambdaVariable,
LoggingLoggerVariable,
MethodWrapperVariable,
NumpyDTypeVariable,
NumpyTypeInfoVariable,
NumpyVariable,
PythonModuleVariable,
RegexPatternVariable,
SavedTensorBox,
TorchVersionVariable,
TypingVariable,
)
from .nn_module import FSDPManagedNNModuleVariable, UnspecializedNNModuleVariable
from .optimizer import OptimizerVariable
from .script_object import TorchScriptObjectVariable
from .sdpa import SDPAParamsVariable
from .tensor import (
NumpyNdarrayVariable,
SymNodeVariable,
TensorSubclassVariable,
TensorVariable,
UnspecializedPythonVariable,
)
from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable
from .torch_function import build_torch_function_fn, TensorWithTFOverrideVariable
from .user_defined import (
KeyedJaggedTensorVariable,
SourcelessGraphModuleVariable,
UserDefinedClassVariable,
UserDefinedObjectVariable,
)
# Module-level logger, named after this module.
log: logging.Logger = logging.getLogger(__name__)
# Alias of typing.List (presumably used to annotate lists of dimensions).
DimList = List
class _missing:
    """Empty placeholder class with no attributes or behavior of its own.

    NOTE(review): presumably used as a "no value" sentinel elsewhere in the
    file; confirm at the use sites.
    """
@dataclasses.dataclass
class GraphArg:
source: Source
# TODO: storing a SymInt here but not a FakeTensor is a pretty strange
# thing to do. Probably should have example (which stores an int) and
# fake_example
_example: Union[TensorWeakRef, torch.SymInt]
# When True, this indicates that this GraphArg is a Python quantity (e.g.,
# a float or int) which we pass to the FX graph as a Tensor. This
# controls how we codegen calls into the Dynamo graph: we will call
# torch.as_tensor on the quantity before passing it in.
#
# Note that we typically do not pass dynamic integers as tensors, because
# they will most frequently just be used for size computation. But this
# is a policy decision that we can change our mind on; in particular, when
# an int comes from a random number generator (e.g., random.randint), we
# DO pass it as a tensor.
#
# It's also worth noting that our current tracing rules for
# pass_arg_as_tensor as subtly broken: we just pun the variable as a
# 0d scalar Tensor and pray that the semantics are the same. Which they
# often are, but not necessarily. ezyang(May 2024) plans to fix this
# soon.
pass_arg_as_tensor: bool
fake_tensor: Optional[torch._subclasses.fake_tensor.FakeTensor]
# UnspecializedPythonVariable often masquerades as a tensor.
# We MUST NOT generate shape guard code
# that actually tries to access tensor properties on these values.
# is_tensor lets us tell if this graph arg actually is a tensor
# or not.
is_tensor: bool = True
# Sometimes, the Tensor we pass to example is freshly allocated (smh).
# Then we cannot only keep a weak reference to it. This lets you
# stash a strong reference too.
example_strong_ref: Optional[torch.Tensor] = None
@property
def example(self):
if isinstance(self._example, TensorWeakRef):
r = self._example()
assert r is not None
return r
else:
return self._example
def __post_init__(self):
if isinstance(self._example, torch.Tensor):
self._example = TensorWeakRef(self._example)
assert is_fake(self.fake_tensor)
def reconstruct(self, codegen):
self.source.reconstruct(codegen)
def erase(self):
self._example = None
self.example_strong_ref = None
def __eq__(self, other):
return self.source.name() == other.source.name()
class BackwardStateGraphArg(GraphArg):
    """GraphArg standing for a BackwardState object passed into the graph.

    It has no Source (``source=None``) and is not a tensor
    (``is_tensor=False``). ``reconstruct`` builds a fresh BackwardState
    instead of reloading an existing value from a source.
    """

    def __init__(self):
        # The example is a freshly constructed BackwardState instance.
        super().__init__(
            source=None,
            _example=BackwardState(),
            pass_arg_as_tensor=False,
            fake_tensor=None,
            is_tensor=False,
        )

    def reconstruct(self, codegen):
        # Emit, in this order: load the BackwardState class, call it with
        # zero arguments, duplicate the result, and store one copy into the
        # output's backward_state_var. Presumably the other copy stays on the
        # stack as the graph argument (confirm against the codegen helpers).
        # NOTE(review): the meaning of call_function's second argument (True)
        # is defined by the codegen helper and is not visible here.
        assert codegen.tx.output.backward_state_var
        codegen.load_import_from(BackwardState.__module__, "BackwardState")
        codegen.call_function(0, True)
        codegen.dup_top()
        codegen.store(codegen.tx.output.backward_state_var)
@dataclasses.dataclass
class FrameStateSizeEntry:
    """Pair of an optional scalar and an optional list of ints.

    Either field may be None. NOTE(review): presumably records values or
    sizes observed for a frame input; confirm against the users of this
    class.
    """

    # Optional integer value.
    scalar: Optional[int]
    # Optional list of integer sizes.
    size: Optional[List[int]]
class VariableBuilder:
"""Wrap a python value in a VariableTracker() instance"""
def __init__(
self,
tx,
source: Source,
):
assert (
source is not None
), "Consider SourcelessBuilder for ephemeral objects, usually objects created locally."
assert TracingContext.try_get() is not None, "Expected active TracingContext"
super().__init__()
self.tx = tx
self.source = source
self.name = source.name()
def __call__(self, value):
if value in self.tx.output.side_effects:
side_effect_result = self.tx.output.side_effects[value]
dup_guard = make_dupe_guard(self.source, side_effect_result.source)
if dup_guard:
self.install_guards(dup_guard)
return side_effect_result
cached_vt = self.tx.output.variable_tracker_cache.lookup(value, self.source)
if cached_vt:
return cached_vt
vt = self._wrap(value)
vt.source = self.source
if self._can_lift_attrs_to_inputs(vt):
vt = self.tx.output.side_effects.track_object_existing(value, vt)
self.tx.output.variable_tracker_cache.add(value, self.source, vt)
return vt
def _can_lift_attrs_to_inputs(self, vt):
if type(vt) in [
TensorVariable,
TensorWithTFOverrideVariable,
UserDefinedObjectVariable,
NumpyNdarrayVariable,
]:
return True
return False
@staticmethod
@functools.lru_cache(None)
def _common_constants():
return {
# We zero-one specialize shapes, so specialize these constants
# too
0,
1,
# NB: There used to be more constants here, but honestly it was
# pretty confusing. Note we specialize floats by default, and
# DON'T specialize ints by default. This all only matters with
# dynamic_shapes
}
def get_source(self):
return self.source
def install_guards(self, *guards):
source = self.get_source()
if (
isinstance(source, ConstantSource)
or source.guard_source() == GuardSource.CONSTANT
):
return None
install_guard(*[source.make_guard(guard) for guard in guards], skip=1)
return {}
def set_source_and_track_mutable(self, value, var):
assert isinstance(var, VariableTracker)
var.source = self.source
return self.tx.output.side_effects.track_mutable(value, var)
@classmethod
@functools.lru_cache(None)
def _type_dispatch(cls):
    """Return the (memoized) map from exact ``type(value)`` to wrap method.

    Each key is a type; each value is an unbound ``wrap_*`` method of
    ``cls`` called as ``handler(builder, value)``. Registering the same type
    twice is an AssertionError.
    """
    # NB: Careful not to close over self to avoid ref cycle from lru_cache
    dispatch = {}

    def register(handler, *types_):
        for typ in types_:
            assert typ not in dispatch
            dispatch[typ] = handler

    register(
        cls.wrap_tensor,
        torch.Tensor,
        torch.nn.Parameter,
        torch._subclasses.FakeTensor,
        torch._subclasses.functional_tensor.FunctionalTensor,
    )
    register(
        cls.wrap_listlike,
        tuple,
        list,
        odict_values,
        collections.deque,
        torch.Size,
    )
    register(cls.wrap_tuple_iterator, tuple_iterator)
    register(cls.wrap_slice_range, slice, range)
    register(cls.wrap_literal, *tuple(common_constant_types))
    register(cls.wrap_regex_pattern, re.Pattern)
    # ndarray wrapping is only available with numpy tracing enabled.
    if config.trace_numpy and np:
        register(cls.wrap_numpy_ndarray, np.ndarray)
    return dispatch
def wrap_regex_pattern(self, value: re.Pattern):
# TODO(jansel): something like a REPR_MATCH might be more robust here
self.install_guards(GuardBuilder.ID_MATCH)
return RegexPatternVariable(value)
@classmethod
@functools.lru_cache(None)
def _id_dispatch(cls):
    """Return the (memoized) map from ``id(obj)`` to a wrap handler.

    Handles a few specific objects (``inspect.signature``, ``comptime``,
    ``dataclasses.fields``, ``torch.__version__``). Each handler is called
    as ``handler(builder, value)``.

    Raises:
        AssertionError: if the same object is registered twice.
    """
    from ..comptime import comptime

    entries = [
        (
            inspect.signature,
            lambda self, value: LambdaVariable(
                InspectSignatureVariable.create,
                source=self.source,
                # `or {}` keeps the splat valid even if install_guards
                # returns None (it did so for constant sources).
                **(self.install_guards(GuardBuilder.CLOSURE_MATCH) or {}),
            ),
        ),
        (comptime, lambda self, value: ComptimeVariable()),
        (
            dataclasses.fields,
            lambda self, value: LambdaVariable(
                _dataclasses_fields_lambda,
                source=self.source,
                **(self.install_guards(GuardBuilder.FUNCTION_MATCH) or {}),
            ),
        ),
        (torch.__version__, lambda self, value: TorchVersionVariable()),
    ]
    result = {}
    for ts, fn in entries:
        for t in ts if isinstance(ts, (tuple, list)) else (ts,):
            # Keys are ids, so the duplicate check must use the id as well.
            # `t not in result` compared the object against int keys and
            # could never fire.
            assert id(t) not in result
            result[id(t)] = fn
    return result
def _wrap(self, value):
# import here to avoid circular dependencies
from torch.utils._triton import has_triton
if has_triton():
from triton.runtime.autotuner import Autotuner
from triton.runtime.jit import JITFunction
else:
class JITFunction:
pass
class Autotuner:
pass
# Handle exact type() match
type_dispatch = self._type_dispatch().get(type(value))
if type_dispatch is not None:
return type_dispatch(self, value)
# Handle exact id() match
id_dispatch = self._id_dispatch().get(id(value))
if id_dispatch is not None:
return id_dispatch(self, value)
# Note - There are some nested values where types mismatch!
# We want to get those out and wrap those.
value = inspect.getattr_static(value, "_torchdynamo_inline", value)
# Everything else (NB: order matters!)
if is_traceable_wrapper_subclass(value) or istype(
value, config.traceable_tensor_subclasses
):
return self.wrap_tensor(value)
elif is_namedtuple(value):
return self.wrap_listlike(value)
elif value is torch.utils._pytree.SUPPORTED_NODES:
# For SUPPORTED_NODES, we guard on the dictionary version (PEP509)
# under the assumption that the values themselves don't change.
self.install_guards(GuardBuilder.DICT_VERSION)
# The keys on the SUPPORTED_NODES can be arbitrary, so save on the
# key order.
self.tx.output.guard_on_key_order.add(self.source.name())
result = {
ConstantVariable.create(k): UserDefinedObjectVariable(
v,
source=GetItemSource(
self.get_source(), ConstDictKeySource(self.get_source(), i)
),
)
for i, (k, v) in enumerate(value.items())
}
return ConstDictVariable(result, type(value))
elif value is sys.modules:
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return PythonSysModulesVariable(source=self.source)
elif istype(value, (dict, collections.defaultdict, collections.OrderedDict)):
if not value and self.get_source().is_nn_module():
# It is faster to guard on 'false' property than to guard
# on actual dict keys, but we can't do this fast guard in general because
# it omits a crucial type check that ensures the value is actually still a dict at runtime.
# Why is this OK for (specialized) nnmodules? We set up a setattr hook
# to check for module property mutations, which does a reasonable,
# but not completely secure job ensuring a property wasn't changed.
self.install_guards(GuardBuilder.BOOL_FALSE)
else:
self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
# Optimisation for the common case strings, ints, etc
all_const = all(ConstantVariable.is_literal(k) for k in value.keys())
if all_const:
# TODO(anijain2305) - Do we have to guard on all the keys? Can
# keys be guarded lazily, similar to values?
self.install_guards(GuardBuilder.DICT_CONST_KEYS)
else:
# Guard on the key order
# This is not ideal, i.e., there is no need to guard on the key
# order. But we guard on the key order because of the complexity
#
# 1) For non-constant objects, we can't save the key in the
# guard context because it can be memory heavy. We can add
# weakrefs but this complicates the accesses.
#
# 2) For non-constant objects, we also have to guard on the keys
# (like TENSOR_MATCH on tensor). We might also have guards on
# the attributes of the keys (like tensor.grad). To make this
# work in tree strucutre is complicated.
#
# So, instead we guard on the key order. While guarding on key
# order, we just save the indices and use it to access keys and
# values. Indices are cheap to save.
self.tx.output.guard_on_key_order.add(self.source.name())
# We need all the keys to be hashable. We do this within the
# _HashableTracker class in dicts.py
def build_key_value(i, k, v):
if all_const:
key = ConstantVariable.create(k)
source_key = k
else:
source_key = ConstDictKeySource(self.get_source(), i)
key = LazyVariableTracker.create(k, source_key)
source_value = GetItemSource(self.get_source(), source_key)
value = LazyVariableTracker.create(v, source_value)
return key, value
result = dict(
build_key_value(i, k, v) for i, (k, v) in enumerate(value.items())
)
if istype(value, collections.defaultdict):
factory_source = AttrSource(self.source, "default_factory")
result = DefaultDictVariable(
result,
type(value),
default_factory=VariableBuilder(self.tx, factory_source)(
value.default_factory
),
source=self.source,
)
else:
result = ConstDictVariable(result, type(value), source=self.source)
return self.set_source_and_track_mutable(value, result)
elif isinstance(value, torch.nn.Module):
return self.wrap_module(value)
elif ConstantVariable.is_literal(value): # non-atomic literals
return self.wrap_literal(value)
elif istype(value, frozenset) and (
ConstantVariable.is_literal(x) for x in value
):
# For frozenset, we can guard by object ID instead of value
# equality, this allows us to handle non-literal values
self.install_guards(GuardBuilder.ID_MATCH)
return ConstantVariable.create(value=value, source=self.source)
elif isinstance(value, enum.Enum):
self.install_guards(GuardBuilder.ID_MATCH)
return EnumVariable(value=value, source=self.source)
elif DebuggingVariable.is_reorderable_logging_function(value):
# Put this above builtin_callable so that print() can be handled
# along with other builtin debugging functions
self.install_guards(GuardBuilder.BUILTIN_MATCH)
return DebuggingVariable(value, source=self.source)
elif isinstance(value, logging.Logger):
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return LoggingLoggerVariable(value, source=self.source)
elif is_utils_checkpoint(value):
return build_checkpoint_variable(source=self.source)
elif isinstance(value, functools.partial):
func_src = AttrSource(self.get_source(), "func")
func_obj = VariableBuilder(self.tx, func_src)(value.func)
args = []
args_source = AttrSource(self.get_source(), "args")
for i, arg in enumerate(value.args):
args.append(
VariableBuilder(self.tx, GetItemSource(args_source, i))(arg)
)
keywords = {}
keywords_source = AttrSource(self.get_source(), "keywords")
for k, v in value.keywords.items():
if not ConstantVariable.is_literal(k):
unimplemented("functools.partial with non-literal keyword")
keywords[k] = VariableBuilder(
self.tx, GetItemSource(keywords_source, k)
)(v)
install_guard(
self.get_source().make_guard(GuardBuilder.TYPE_MATCH),
keywords_source.make_guard(GuardBuilder.DICT_KEYS),
args_source.make_guard(GuardBuilder.SEQUENCE_LENGTH),
)
return FunctoolsPartialVariable(func_obj, args, keywords)
elif is_typing(value):
# typing.List, typing.Mapping, etc.
self.install_guards(GuardBuilder.ID_MATCH)
return TypingVariable(
value,
source=self.source,
)
elif np is not None and isinstance(value, np.generic):
# numpy array scalars: convert to 0D arrays
return self.wrap_numpy_ndarray(np.asarray(value))
elif is_numpy(value):
assert np
self.install_guards(
GuardBuilder.FUNCTION_MATCH
if callable(value)
else GuardBuilder.TYPE_MATCH
)
return NumpyVariable(value, source=self.source)
elif is_numpy_dtype(value):
self.install_guards(GuardBuilder.ID_MATCH)
return NumpyDTypeVariable(value, source=self.source)
elif is_numpy_type_info(value):
if isinstance(value, np.iinfo):
self.install_guards(GuardBuilder.TYPE_MATCH)
dt_source = AttrSource(self.source, "dtype")
install_guard(dt_source.make_guard(GuardBuilder.ID_MATCH))
else:
self.install_guards(GuardBuilder.ID_MATCH)
return NumpyTypeInfoVariable(value, source=self.source)
# NB: These can't be put in type_dispatch, they have to run later
elif CollectiveFunctionRewriteVariable.can_rewrite(value):
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return CollectiveFunctionRewriteVariable.create(
self.tx,
value,
source=self.source,
)
elif istype(value, torch.autograd.function.FunctionMeta):
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return AutogradFunctionVariable(
value,
source=self.source,
)
elif isinstance(value, torch.autograd.function.FunctionCtx):
actual_saved_tensors = None
try:
actual_saved_tensors = value.saved_tensors
except RuntimeError:
pass
saved_tensors = []
guards = [self.source.make_guard(GuardBuilder.TYPE_MATCH)]
if isinstance(actual_saved_tensors, tuple):
saved_tensors_source = AttrSource(self.source, "saved_tensors")
guards.append(
saved_tensors_source.make_guard(GuardBuilder.SEQUENCE_LENGTH)
)
for i, v in enumerate(actual_saved_tensors):
saved_tensors.append(
VariableBuilder(
self.tx, GetItemSource(saved_tensors_source, i)
)(v)
)
install_guard(*guards)
return self.tx.output.side_effects.track_object_existing(
value,
AutogradFunctionContextVariable(
value,
source=self.source,
saved_tensors=SavedTensorBox(saved_tensors),
),
)
elif (
isinstance(value, types.MethodType)
and istype(
getattr(value, "__self__", None), torch.autograd.function.FunctionMeta
)
and getattr(value, "__name__", "") == "apply"
and value == getattr(value.__self__, "apply", None)
):
# handle aliased autograd function `apply` calls
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return GetAttrVariable(
AutogradFunctionVariable(
value.__self__, source=AttrSource(self.source, member="__self__")
),
"apply",
)
elif callable(value) and trace_rules.lookup_callable(value) is not None:
if is_callable_allowed(value):
self.tx.output.has_user_defined_allowed_in_graph = True
return trace_rules.lookup_callable(value).create_with_source(
value, source=self.source
)
elif np and isinstance(value, np.number):
return self.wrap_unspecialized_primitive(value)
elif DataClassVariable.is_matching_object(value):
self.install_guards(GuardBuilder.TYPE_MATCH)
return DataClassVariable.wrap(self, value)
elif HFPretrainedConfigVariable.is_matching_object(value):
self.install_guards(GuardBuilder.TYPE_MATCH)
return HFPretrainedConfigVariable(value)
elif isinstance(value, HigherOrderOperator):
self.install_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.NAME_MATCH)
return TorchHigherOrderOperatorVariable.make(value, source=self.source)
elif isinstance(value, torch.cuda.StreamContext):
self.install_guards(GuardBuilder.ID_MATCH)
stream_source = AttrSource(self.source, "stream")
stream_var = VariableBuilder(self.tx, stream_source)(value.stream)
return StreamContextVariable.create(self.tx, stream_var)
elif isinstance(value, _StreamBase):
self.install_guards(GuardBuilder.ID_MATCH)
stream_proxy = self.tx.output.create_proxy(
"call_function",
torch.cuda.Stream,
(),
{
"stream_id": value.stream_id,
"device_index": value.device_index,
"device_type": value.device_type,
},
)
set_example_value(stream_proxy.node, value)
return StreamVariable(
stream_proxy,
value,
value.device,
source=self.source,
)
elif isinstance(value, (torch._C._SDPAParams)):
self.install_guards(GuardBuilder.TYPE_MATCH)
return SDPAParamsVariable.create(self.tx, value, self.source)
elif isinstance(value, _EventBase):
self.install_guards(GuardBuilder.ID_MATCH)
return EventVariable(
None,
value,
source=self.source,
)
elif (
isinstance(value, torch._C._TensorMeta)
and value in config.traceable_tensor_subclasses
):
return TensorSubclassVariable(value, source=self.source)
elif (
istype(value, contextlib.nullcontext)
and inspect.getattr_static(value, "enter_result", None) is None
):
self.install_guards(GuardBuilder.TYPE_MATCH)
return NullContextVariable(source=self.source)
elif KeyedJaggedTensorVariable.is_matching_object(value):
self.install_guards(GuardBuilder.TYPE_MATCH)
result = KeyedJaggedTensorVariable(value, source=self.source)
# TODO: this doing it manually is bad
return self.tx.output.side_effects.track_object_existing(value, result)
elif isinstance(value, torch.optim.Optimizer):
self.install_guards(GuardBuilder.ID_MATCH)
self.source = OptimizerSource(self.source)
return OptimizerVariable(value, source=self.source)
elif WorldMetaClassVariable.is_group_member_type(value):
return WorldMetaClassVariable(value, source=self.source)
elif ProcessGroupVariable.is_process_group(value):
self.install_guards(GuardBuilder.ID_MATCH)
return ProcessGroupVariable(value, source=self.source)
elif DeviceMeshVariable.is_device_mesh(value):
# TODO: see if we need to add custom guard instead of a simple ID_MATCH
self.install_guards(GuardBuilder.ID_MATCH)
return DeviceMeshVariable(value, source=self.source)
elif PlacementClassVariable.is_placement_type(value):
# TODO: see if we need to add custom guard instead of a simple ID_MATCH
self.install_guards(GuardBuilder.ID_MATCH)
return PlacementClassVariable(value, source=self.source)
elif PlacementVariable.is_placement(value):
# TODO: see if we need to add custom guard instead of a simple ID_MATCH
self.install_guards(GuardBuilder.ID_MATCH)
return PlacementVariable(
value,
source=self.source,
)
elif istype(value, type) and value in itertools.__dict__.values():
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return ItertoolsVariable(value, source=self.source)
elif isinstance(value, torch.SymBool):
# Note: the idea here is to re-use the infra we've built for SymInt by simulating the
# user provided SymBool with a SymInt in dynamo.
# Concretely,
# 1. We create a SymInt in dynamo's shape_env, whose source is constructed as ConvertIntSource(self.source).
# so that guards on the SymInts can be effectively applied on the original SymBool in user program.
# 2. We create a SymBool based on the SymInt in dynamo's ShapeEnv. Because the original user program
# depends on the value being a SymBool. This allows dynamo to interpret the user's program correctly.
value_hint = value.node.require_hint()
new_source = ConvertIntSource(self.source)
new_symint = self.tx.output.shape_env.create_unspecified_symint_and_symbol(
int(value_hint),
new_source,
dynamic_dim=DimDynamic.DYNAMIC,
)
sym_node_proxy = self.tx.output.root_tracer.create_graph_input(
re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
type(new_symint),
source=new_source,
)
sym_node_proxy.node.meta["grapharg"] = GraphArg(
new_source,
new_symint,
False,
None,
is_tensor=False,
example_strong_ref=new_symint,
)
self.tx.output.bound_symbols.add(new_symint.node.expr)
self.tx.output.tracked_fakes.append(
TrackedFake(new_symint, new_source, None)
)
return SymNodeVariable(
sym_node_proxy,
new_symint == 1,
)
elif isinstance(value, (JITFunction, Autotuner)):
self.install_guards(GuardBuilder.ID_MATCH)
return TritonKernelVariable(
value,
None, # No kernel idx provided
None, # No grid provided
source=self.source,
)
elif isinstance(value, torch.amp.autocast_mode.autocast):
self.install_guards(GuardBuilder.ID_MATCH)
return AutocastModeVariable(
target_values=[
value.device,
value.fast_dtype,
value._enabled,
value._cache_enabled,
],
source=self.source,
)
elif TorchCtxManagerClassVariable.is_matching_cls(value):
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return TorchCtxManagerClassVariable(value, source=self.source)
elif is_function_or_wrapper(value):
value, attr_name = unwrap_with_attr_name_if_wrapper(value)
# For these wrappers, Dynamo points to the wrapped function,
# so source needs to be updated as well.
if attr_name is not None:
self.source = AttrSource(self.source, attr_name)
return trace_rules.lookup(value).create_with_source(
value, source=self.source
)
# Don't use istype, since some python modules are not subclasses of types.ModuleType directly.
# E.g, type(torch.ops) -> <class 'torch._ops._Ops'>,
# type(torch.backends.cudnn) -> <class 'torch.backends.cudnn.CudnnModule'>
elif isinstance(value, (types.ModuleType, replay_record.DummyModule)):
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return PythonModuleVariable(
value,
source=self.source,
)
elif isinstance(value, types.MethodType) and isinstance(
value.__self__, (torch.nn.Module, torch.utils._pytree.TreeSpec)
):
# don't let MethodTypes fall through to UserDefinedObject,
# which doesn't support 'CALL_FUNCTION'
# TODO(whc): Why do we limit this to methods on NNModules?
# I don't have a good reason for this, but it preserves the existing behavior
# for MBartForConditionalGeneration, which generates many graph breaks and OOMs otherwise.
# I suspect we probably want to relax this check and dig deeper there.
# In order to construct a MethodVariable in Dynamo, we start with an actual method obj from python,
# but need to separately wrap its underlying `__func__` and its `self` argument. We wrap `self` here
# and then `__func__` gets wrapped inside UserMethodVariable.
self_obj = VariableBuilder(
self.tx, source=AttrSource(self.source, "__self__")
)(value.__self__)
assert self_obj and isinstance(
self_obj, VariableTracker
), "Failed to produce a valid self obj"
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return UserMethodVariable(
value.__func__,
self_obj,
source=self.source,
)
elif isinstance(value, types.GetSetDescriptorType):
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return GetSetDescriptorVariable(value)
elif isinstance(value, types.MethodWrapperType):
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return MethodWrapperVariable(value)
elif issubclass(type(value), type):
if value in (torch.utils.hooks.BackwardHook, torch.nn.Parameter):
# TODO(jansel): combine this case with the one above
return trace_rules.lookup(value).create_with_source(
value, source=self.source
)
if value is torch.autograd._unsafe_preserve_version_counter:
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return PreserveVersionContextVariable.constructor(self.tx)
# This is a userdefined class, so install an ID_MATCH even if its a
# global variable.
self.install_guards(GuardBuilder.ID_MATCH)
return UserDefinedClassVariable(
value,
source=self.source,
)
elif RestrictedListSubclassVariable.is_matching_cls(type(value)):
self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
return self.set_source_and_track_mutable(
value,
RestrictedListSubclassVariable(
[
LazyVariableTracker.create(
value=value[i], source=GetItemSource(self.source, i)
)
for i in range(len(value))
],
user_cls=type(value),
user_cls_source=AttrSource(self.source, "__class__"),
),
)
elif TorchScriptObjectVariable.is_matching_cls(type(value)):
from ..source import (
FlattenScriptObjectSource,
ScriptObjectQualifiedNameSource,
)
# This exists to allow a smoother transition.
# The implications are:
# The script objects won't be tracked as proxies.
# Methods on these objects won't show up in the graph.
# The original script object might be mutated.
if not hasattr(value, "__obj_flatten__"):
return self.wrap_user_defined(value)
# Install the guards on the fully qualified name of the script object
LazyVariableTracker.realize_all(
VariableBuilder(self.tx, ScriptObjectQualifiedNameSource(self.source))(
value._type().qualified_name() # type: ignore[attr-defined]
)
)
# Install the guards on the content of the script object by setting the source
# to be FlattenScriptObjectSource, which calls __obj_flatten__() to get the contents.
LazyVariableTracker.realize_all(
VariableBuilder(self.tx, FlattenScriptObjectSource(self.source))(
value.__obj_flatten__()
)
)
fake_script_obj = torch._library.fake_class_registry.to_fake_obj(
self.tx.output.fake_mode, value
)
proxy = self.tx.output.root_tracer.create_graph_input(
re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
type(value),
source=self.source,
)
# setting is_unspecialized=False to not insert a as_tensor call in reconstruct by default
# seting example to be real value because these example values will be used
# as example_inputs for user compiler.
proxy.node.meta["grapharg"] = GraphArg(
self.source, value, False, None, False, fake_script_obj
)
return TorchScriptObjectVariable.create(
proxy,
fake_script_obj,
source=self.source,
)
else:
return self.wrap_user_defined(value)
def wrap_user_defined(self, value: Any):
    """Wrap an arbitrary Python object as a ``UserDefinedObjectVariable``.

    A TYPE_MATCH guard is installed on the current source. If the object's
    class supports mutation side effects, the variable is registered with
    the side-effects tracker as a pre-existing object; otherwise it is
    returned untracked.
    """
    self.install_guards(GuardBuilder.TYPE_MATCH)
    wrapped = UserDefinedObjectVariable(value, source=self.source)
    if SideEffects.cls_supports_mutation_side_effects(type(value)):
        return self.tx.output.side_effects.track_object_existing(value, wrapped)
    # don't allow STORE_ATTR mutation with custom __setattr__
    return wrapped
def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]):
    """Wrap a list/tuple-like container, lazily wrapping each element.

    When ``config.specialize_int`` is set, a ``torch.Size`` is returned as a
    CONSTANT_MATCH-guarded ``ConstantVariable``. Otherwise a SEQUENCE_LENGTH
    guard is installed and each element becomes a ``LazyVariableTracker``
    sourced at ``GetItemSource(source, i)``.

    If the container is a ``LocalSource`` whose name is returned by
    ``get_locals_to_steal`` (the compiled-autograd input-list case described
    below), the whole list is additionally lifted as one graph input and its
    items are registered in ``input_source_to_var`` with TENSOR_MATCH guards.

    Results for plain ``list`` values go through
    ``set_source_and_track_mutable``; other sequence types are returned as-is.
    Graph-breaks (``unimplemented``) on self-referential containers.
    """
    if config.specialize_int and type(value) is torch.Size:
        self.install_guards(GuardBuilder.CONSTANT_MATCH)
        return ConstantVariable.create(value=value)
    # One can index a tensor with a list/tuple. Therefore, we need to
    # have a stricter match.
    self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
    # A container that contains itself (e.g. after ``l.append(l)``) is
    # not supported.
    for item in value:
        if item is value:
            unimplemented("list elements are pointing to the list itself")
    output = [
        LazyVariableTracker.create(item, source=GetItemSource(self.get_source(), i))
        for i, item in enumerate(value)
    ]
    # ``self`` from the frame's locals (None if absent) decides which locals
    # may be stolen.
    maybe_gm = self.tx.output.local_scope.get("self")
    if isinstance(
        self.source, LocalSource
    ) and self.source.local_name in get_locals_to_steal(maybe_gm):
        # The input tensor list to dynamo from compiled autograd may contain activations
        # which are freed as they are used in inductor. Dynamo's default behavior is to
        # lift all tensors to the graph inputs, but this will cause dynamo to hold an
        # extra reference to the activation tensors and increase peak memory usage.
        # To allow freeing ASAP, we keep the list as graph argument to the dynamo output
        # graph, and unpack it locally.
        # e.g. instead of `def forward(self, L_inputs_0_, L_inputs_1_, ...):`, we have
        # `def forward(self, L_inputs_):`
        source = self.source
        assert isinstance(value, list)
        tensor_list_proxy = self.tx.output.root_tracer.create_graph_input(
            re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(value), source=source
        )
        tensor_list_proxy.node.meta["steal_arg"] = True
        # NOTE(review): with a list example_value this is expected to yield a
        # list-like variable exposing ``.items`` — confirm in wrap_fx_proxy_cls.
        list_variable = wrap_fx_proxy_cls(
            target_cls=TensorVariable,
            tx=self.tx,
            proxy=tensor_list_proxy,
            example_value=value,
            subclass_type=None,
            source=source,
        )
        guards = []
        for i, tensor_variable in enumerate(list_variable.items):
            source_i = GetItemSource(base=source, index=i, index_is_slice=False)
            # access unpacked tensor from this list instead of from a lifted arg
            self.tx.output.input_source_to_var[source_i] = tensor_variable
            guard = functools.partial(
                GuardBuilder.TENSOR_MATCH, value=TensorWeakRef(value[i])
            )
            guards.append(source_i.make_guard(guard))
        install_guard(*guards, skip=1)
        grapharg = GraphArg(
            source,
            value,
            pass_arg_as_tensor=False,
            fake_tensor=None,
            is_tensor=False,
        )
        tensor_list_proxy.node.meta["grapharg"] = grapharg
    result = BaseListVariable.cls_for_instance(value)(
        output, mutable_local=MutableLocal()
    )
    if istype(value, list):
        return self.set_source_and_track_mutable(value, result)
    return result
def wrap_tuple_iterator(self, value: tuple_iterator):
    """Wrap a tuple iterator as a ``TupleIteratorVariable``.

    Installs a TUPLE_ITERATOR_LEN guard, then eagerly wraps each remaining
    element through a ``TupleIteratorGetItemSource`` so every element is
    individually sourced. The result is tracked as a mutable object.
    """
    self.install_guards(GuardBuilder.TUPLE_ITERATOR_LEN)
    remaining = tuple_iterator_len(value)
    elements = []
    for idx in range(remaining):
        item_source = TupleIteratorGetItemSource(self.get_source(), idx)
        elements.append(
            VariableBuilder(self.tx, item_source)(tuple_iterator_getitem(value, idx))
        )
    iterator_var = TupleIteratorVariable(
        elements, mutable_local=MutableLocal(), source=self.source
    )
    return self.set_source_and_track_mutable(value, iterator_var)
def wrap_slice_range(self, value: Union[slice, range]):
    """Wrap a ``slice`` or ``range`` as a SliceVariable / RangeVariable.

    Each of ``start``, ``stop`` and ``step`` is wrapped through an
    ``AttrSource`` on the current source, then a TYPE_MATCH guard is
    installed on the container itself.
    """
    items = []
    for attr in ("start", "stop", "step"):
        builder = VariableBuilder(self.tx, AttrSource(self.get_source(), attr))
        items.append(builder(getattr(value, attr)))
    self.install_guards(GuardBuilder.TYPE_MATCH)
    variable_cls = SliceVariable if isinstance(value, slice) else RangeVariable
    return variable_cls(items, source=self.source)
def wrap_module(self, value: torch.nn.Module):
    """Wrap an ``nn.Module`` according to how Dynamo should trace it.

    Dispatch order:
      * a module with an empty ``__dict__`` (uninitialized) graph-breaks;
      * an ``OptimizedModule`` is unwrapped to ``_orig_mod`` (with the source
        re-rooted there), unless its ``forward`` is marked
        ``_torchdynamo_disable``, in which case a ``DelayGraphBreakVariable``
        is returned;
      * RNN/GRU/LSTM graph-break unless ``config.allow_rnn`` is set;
      * dynamically created modules become ``UnspecializedNNModuleVariable``
        (tracked for side effects when the class supports it);
      * ``DistributedDataParallel`` becomes ``UnspecializedNNModuleVariable``;
      * FSDP-managed modules become ``FSDPManagedNNModuleVariable``;
      * everything else is registered via ``register_attr_or_module``.
    """
    from ..eval_frame import OptimizedModule

    if len(value.__dict__) == 0:
        unimplemented(f"uninitialized nn.Module: {typestr(value)}")
    if istype(value, OptimizedModule):
        # Check if the optimized module was disabled
        if inspect.getattr_static(value.forward, "_torchdynamo_disable", False):
            # This bytecode is mostly of kind LOAD_ATTR or LOAD_METHOD. If
            # we graph break here, Dynamo does not know how to create
            # continuation functions for such bytecodes. So, we delay the
            # graph break to CALL_FUNCTION.
            return DelayGraphBreakVariable(source=self.source)
        # Wrap the underlying module instead, with the source pointing at it.
        self.install_guards(GuardBuilder.TYPE_MATCH)
        self.source = AttrSource(self.source, "_orig_mod")
        return self.wrap_module(value._orig_mod)
    if (
        isinstance(value, (torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM))
        and not config.allow_rnn
    ):
        unimplemented("TorchDynamo purposely graph breaks on RNN, GRU, LSTMs")
    if mutation_guard.is_dynamic_nn_module(value, self.tx.export):
        # created dynamically, don't specialize on it
        self.install_guards(GuardBuilder.TYPE_MATCH)
        result = UnspecializedNNModuleVariable(value, source=self.source)
        if not SideEffects.cls_supports_mutation_side_effects(type(value)):
            # don't allow STORE_ATTR mutation with custom __setattr__
            return result
        return self.tx.output.side_effects.track_object_existing(value, result)
    elif issubclass(
        value.__class__, torch.nn.parallel.distributed.DistributedDataParallel
    ):
        self.install_guards(GuardBuilder.TYPE_MATCH)
        # Attach the source, as every other branch here does. It was
        # previously omitted, so attribute accesses on the DDP module
        # had no source to derive guards/reconstruction from.
        return UnspecializedNNModuleVariable(value, source=self.get_source())
    elif getattr(value, "_is_fsdp_managed_module", False):
        # See note [Dynamo treats FSDP wrapped modules as UnspecializedNNModule]
        # in fully_sharded_data_parallel.py for more information
        # we can't do this assert inside FSDP constructor,
        # since we don't know yet whether dynamo will be used
        assert getattr(
            value, "_fsdp_use_orig_params", False
        ), "Dynamo only supports FSDP with use_orig_params=True"
        # Note on FSDP guarding
        # 1. We expect FSDP wrapping mutates an nn module irreversibly (no way to de-wrap).
        # 2. Eager FSDP already assumes (requires, but without enforcement) that users don't mutate their
        #    model parameters/structure after FSDP wrapping, because FSDP wouldn't notice or update its FlatParams.
        #
        # Due to (1), once we enter this path we expect not to go back nor have to guard on type
        # or _is_fsdp_managed_module.
        #
        # TODO(whc) We could add a guard on the opposite case, where a user compiled/ran
        # pre-FSDP-wrapped model, then wrapped, to ensure that we recompile with the FSDP handling.
        #
        # Due to (2), we skip guards on inner contents of fsdp_managed modules, by using FSDPNNModuleSource as the
        # guard source. This behavior is gated on config.skip_fsdp_guards.
        #
        # ID_MATCH is required to disambiguate cases as simple as a unit test that constructs 2 models and wraps
        # them differently with different FSDP configs. (test_dynamo_distributed.py -k test_fsdp_aot_eager)
        self.install_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.ID_MATCH)
        return FSDPManagedNNModuleVariable(value, source=self.get_source())
    else:
        return self.tx.output.register_attr_or_module(
            value,
            self.name,
            source=self.get_source(),
            # Guards are added inside register_attr_or_module
        )
def wrap_literal(self, value):
    """Wrap a Python literal, choosing between specialization and symbolic wrapping.

    When ints are unspecialized (``config.specialize_int`` false), an exact
    ``int`` is still turned into a CONSTANT_MATCH-guarded constant if
    ``force_unspec_int_unbacked_size_like`` is off and the value is one of
    ``self._common_constants()`` or comes from a non-local guard source, a
    default argument, or a cell's contents. Otherwise it goes through
    ``wrap_symint``. An exact ``float`` goes through ``wrap_symfloat`` when
    ``config.specialize_float`` is false. Every other value becomes a
    CONSTANT_MATCH-guarded ``ConstantVariable``.
    """
    if not config.specialize_int and type(value) is int:
        # unspecializing int by default, but still
        # specialize for the following conditions
        keep_static = not TracingContext.get().force_unspec_int_unbacked_size_like and (
            value in self._common_constants()
            # Assume integers from global variables want to be specialized
            or not self.source.guard_source().is_local()
            or is_from_defaults(self.source)
            or is_cell_contents(self.source)
        )
        if not keep_static:
            return self.wrap_symint(value)
        self.install_guards(GuardBuilder.CONSTANT_MATCH)
        return ConstantVariable.create(value=value, source=self.source)
    if not config.specialize_float and type(value) is float:
        return self.wrap_symfloat(value)
    self.install_guards(GuardBuilder.CONSTANT_MATCH)
    return ConstantVariable.create(value=value)
def assert_not_wrapped_by_this_graph(self, value: torch.Tensor):
    """Raise if ``value`` is a fake tensor belonging to this tracer's fake mode.

    A tensor that is already fake under ``self.tx.fake_mode`` has already
    been wrapped by this Dynamo instance and must not be wrapped again.

    Raises:
        InternalTorchDynamoError: if ``value`` is fake and its fake mode is
            ``self.tx.fake_mode``.
    """
    if is_fake(value) and maybe_get_fake_mode(value) is self.tx.fake_mode:
        # Pass one message string. Two positional arguments would make the
        # error render as a tuple instead of a sentence.
        raise InternalTorchDynamoError(
            "Cannot wrap a Tensor that has already been "
            "wrapped by this instance of Dynamo"
        )
def wrap_tensor(self, value: torch.Tensor):
    """Wrap a real tensor for tracing, usually as a graph-input TensorVariable.

    Tensors whose source is an nn-module source (but not an FSDP module) or
    that have a static address type, and tensors from constant sources, are
    handed to ``register_attr_or_module`` instead of becoming graph inputs.
    A tensor already seen from the same source is returned from
    ``input_source_to_var`` without re-wrapping.

    Otherwise the tensor is lifted as a root-tracer graph input, wrapped with
    ``wrap_fx_proxy``, guarded (TENSOR_MATCH, or NOT_NONE_MATCH for a grad
    coming from an optimizer source), recorded in ``input_source_to_var``,
    and given a ``GraphArg`` whose symbols are bound via
    ``add_symbol_bindings``.

    Graph-breaks (``unimplemented``) on strided NestedTensors and on sparse
    layouts other than COO.

    Raises:
        InternalTorchDynamoError: if the tensor was already wrapped by this
            graph, or the resulting example value is not this graph's fake.
    """
    source = self.get_source()
    # We cannot already be tracking the tensor, which implies
    # it would have already been wrapped
    assert value not in self.tx.output.side_effects
    if (
        source.guard_source().is_nn_module()
        or get_static_address_type(value) is not None
    ) and not source.guard_source().is_fsdp_module():
        self.assert_not_wrapped_by_this_graph(value)
        return self.tx.output.register_attr_or_module(
            value, self.name, source=source
        )
    if is_constant_source(source):
        self.assert_not_wrapped_by_this_graph(value)
        return self.tx.output.register_attr_or_module(
            value,
            re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
            source=source,
            # Guards are added inside register_attr_or_module
        )
    if type(value) in config.traceable_tensor_subclasses:
        # Ordinarily, we would fakeify a tensor so that it can get dynamic
        # shapes and be computed on without triggering actual operations.
        # However, how can we fakeify a tensor subclass? Ordinary
        # inheritance (nor multiple inheritance) won't work work.
        #
        # Instead, our plan is to *manually simulate* the tensor subclass
        # inheriting from a fake tensor with dynamo. This means our
        # data representation for a tensor subclass will be a fake tensor
        # + tensor subclass type + any extra data the subclass may have
        # been storing on the tensor. Because all Python accesses are
        # mediated through TensorWithTFOverrideVariable, we can ensure
        # that we dispatch differently, e.g., according to
        # __torch_function__
        #
        # To simplify things for now, the __dict__ tracking bits haven't
        # been implemented yet, but they can be added into this design at
        # a later point in time.
        subclass_type = type(value)
    else:
        # Only plain tensors, parameters, fake/functional tensors and
        # traceable wrapper subclasses are accepted on this path.
        assert type(value) in (
            torch.Tensor,
            torch.nn.Parameter,
            torch._subclasses.fake_tensor.FakeTensor,
            torch._subclasses.functional_tensor.FunctionalTensor,
        ) or is_traceable_wrapper_subclass(value), type(value)
        subclass_type = None
    # NB: this just says we accessed a tensor from the same source again
    # (e.g., a tensor lives in a global foo, and we LOAD_GLOBAL it twice).
    # This is distinct from two distinct sources mapping to the same
    # Tensor (per id())! No guard is necessary here. See below for the
    # other case.
    is_duplicate_tensor = source in self.tx.output.input_source_to_var
    if is_duplicate_tensor:
        return self.tx.output.input_source_to_var[source]
    # By this point, we should have deduplicated all tensors
    self.assert_not_wrapped_by_this_graph(value)
    # tx.output has multiple tracers if we're introspecting HigherOrderOperator.
    # When we've discovered an untracked tensor, then we actually need
    # to get Dynamo to track the tensor (which is what this function does)
    # and put it as a graph input on the root tracer. Later on,
    # if the input is actually used in the body of the HigherOrderOperator,
    # then the relevant SubgraphTracer will lift it to being an input of
    # the subgraph.
    # See NOTE [HigherOrderOperator tracing design] for more details.
    tensor_proxy = self.tx.output.root_tracer.create_graph_input(
        re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(value), source=source
    )
    options = {}
    if type(value) in config.traceable_tensor_subclasses:
        # Traceable subclasses carry their __torch_function__ along and are
        # guarded on their exact type.
        options["torch_function_fn"] = build_torch_function_fn(
            self.tx, value, self.source
        )
        self.install_guards(GuardBuilder.TYPE_MATCH)
    if (
        isinstance(value, torch.Tensor)
        and value.is_nested
        and not isinstance(value, torch.nested._internal.nested_tensor.NestedTensor)
    ):
        unimplemented("torch.compile does not support strided NestedTensor")
    # Reject sparse, but not coo.
    # TODO: remove this altogether when non-coo sparsity propagation is ready
    if is_sparse_any(value) and not value.is_sparse:
        unimplemented(
            f"torch.compile does not support sparse Tensor with {value.layout} layout"
        )
    tensor_variable = wrap_fx_proxy(
        tx=self.tx,
        proxy=tensor_proxy,
        example_value=value,
        subclass_type=subclass_type,
        source=source,
        **options,
    )
    # Optimizer-owned grads are only guarded on being non-None rather than
    # on full tensor properties.
    guard_type = GuardBuilder.TENSOR_MATCH
    if isinstance(source, GradSource) and is_from_optimizer_source(source):
        guard_type = GuardBuilder.NOT_NONE_MATCH
    self.install_guards(
        functools.partial(
            guard_type,
            value=value
            if isinstance(source, NumpyTensorSource)
            else TensorWeakRef(value),
        )
    )
    # We install TYPE_MATCH guards for traceable wrapper subclass object,
    # and recursively install corresponding guard for each inner attribute.
    if is_traceable_wrapper_subclass(value):
        self.install_guards(GuardBuilder.TYPE_MATCH)
        attrs, _ = value.__tensor_flatten__()
        for attr in attrs:
            inner_value = getattr(value, attr)
            inner_source = AttrSource(self.source, attr)
            LazyVariableTracker.realize_all(
                VariableBuilder(self.tx, inner_source)(inner_value)
            )
    # Remember this source so later accesses reuse the same variable.
    self.tx.output.input_source_to_var[source] = tensor_variable
    assert "tensor_dict" not in tensor_proxy.node.meta
    tensor_proxy.node.meta["tensor_dict"] = value.__dict__.copy()
    # Note: this information is conveyed via subclass_type now
    fake_tensor_value = tensor_variable.proxy.node.meta["example_value"]
    if maybe_get_fake_mode(fake_tensor_value) is not self.tx.fake_mode:
        raise InternalTorchDynamoError("Wrapped Tensor must be this graph's fake")
    grapharg = GraphArg(source, value, False, fake_tensor_value)
    tensor_proxy.node.meta["grapharg"] = grapharg
    self.tx.output.add_symbol_bindings(grapharg)
    return tensor_variable
def wrap_numpy_ndarray(self, value):
    """Wrap a NumPy ``ndarray`` as a ``NumpyNdarrayVariable`` graph input.

    The array is converted to a tensor via ``torch._numpy._util`` and wrapped
    once under a ``NumpyTensorSource`` so that it is guarded like a tensor.
    It is then lifted as a graph input whose ``GraphArg`` has
    ``pass_arg_as_tensor=True``, so the array is converted to a tensor when
    passed to the graph. A ``NotImplementedError`` during conversion becomes
    a graph break via ``unimplemented``.
    """
    assert np is not None
    assert isinstance(value, np.ndarray)
    source = NumpyTensorSource(self.get_source())
    from torch._numpy import _util
    readonly = not value.flags.writeable
    if readonly:
        # Make the array writeable before converting it; the converted
        # tensor is cloned below when the array was read-only.
        # NOTE(review): the flag is never set back to False, so the caller's
        # array stays writeable after tracing — confirm this is intended.
        try:
            value.flags.writeable = True
        except ValueError:
            # One can not easily make nditer elements writable,
            # but warning is not the end of the world
            assert isinstance(value.base, np.nditer)
            pass
    try:
        tensor_value = _util._try_convert_to_tensor(value)
        if readonly:
            from torch._prims_common import clone_preserve_strides
            tensor_value = clone_preserve_strides(tensor_value)
    except NotImplementedError as e:
        # failed to convert to tensor, graph break
        unimplemented(str(e))
    # We do this because we want the full behavior of guarding the numpy ndarray as if it were
    # a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here
    # that there's not another great way to do this atm.
    # This creates the right graphargs, as well as registration for guards in tensor names and shape env.
    LazyVariableTracker.realize_all(VariableBuilder(self.tx, source)(tensor_value))
    proxy = self.tx.output.root_tracer.create_graph_input(
        re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(tensor_value), source=source
    )
    options = {"source": source}
    numpy_ndarray_variable = wrap_fx_proxy_cls(
        target_cls=NumpyNdarrayVariable,
        tx=self.tx,
        proxy=proxy,
        example_value=tensor_value,
        **options,
    )
    self.tx.output.input_source_to_var[source] = numpy_ndarray_variable
    example_value = numpy_ndarray_variable.proxy.node.meta["example_value"]
    # pass_arg_as_tensor should be true because we are wrapping a np.ndarray as argument input, and it needs to be
    # converted to a tensor.
    grapharg = GraphArg(
        source,
        tensor_value,
        pass_arg_as_tensor=True,
        fake_tensor=example_value,
        is_tensor=True,
        example_strong_ref=tensor_value,
    )
    proxy.node.meta["grapharg"] = grapharg
    return numpy_ndarray_variable
def wrap_symint(self, value):
    """Wrap a Python ``int`` as a ``SymNodeVariable`` backed by a graph input.

    Returns a CONSTANT_MATCH-guarded ``ConstantVariable`` instead when
    ``config.specialize_int`` is set, when the source is a constant source,
    or when the automatic-dynamic bookkeeping in ``frame_state`` leaves the
    int static. Symbolic results are cached per ``self.name`` in
    ``unspec_variable_map``.

    Under ``force_unspec_int_unbacked_size_like`` an unbacked SymInt
    constrained to be size-like is created instead of a backed one.

    Raises:
        AssertionError: during export, when the int would become a new graph
            input whose source is not a ``LocalSource``.
    """
    assert type(value) is int
    if self.name in self.tx.output.unspec_variable_map:
        return self.tx.output.unspec_variable_map[self.name]
    shape_env = self.tx.output.shape_env
    if TracingContext.get().force_unspec_int_unbacked_size_like:
        wrapped_value = shape_env.create_unbacked_symint()
        _constrain_range_for_size(wrapped_value)
        self.tx.output.bound_symbols.add(wrapped_value.node.expr)
        self.tx.output.tracked_fakes.append(
            TrackedFake(wrapped_value, self.source, None)
        )
    # NB: We do not do float. For motivation, see
    # https://docs.google.com/document/d/1INSCdYu1PxXcr43HrD82OudeEuS-qxQe1yZmLg2wy6A/edit
    # but the general idea is that we generate kernels that can
    # take unspecialized floats and use them in sizevar computation
    elif not is_constant_source(self.get_source()):
        if torch._dynamo.config.specialize_int:
            # If specialize_int is True, also return
            # a constant (but this should have been handled
            # in the caller, TBH)
            self.install_guards(GuardBuilder.CONSTANT_MATCH)
            return ConstantVariable.create(value=value, source=self.source)
        name = self.source.name()
        if name not in self.tx.output.frame_state:
            # Note - this essentially means that if this name gets reused as a tensor,
            # it will start fully dynamic. That should always be a safe option, and not awfully inefficient.
            # Alternatively, if we want to improve perf here, we can add a third state of unset, but I am not
            # sure that is necessary for now.
            frame_state_entry = FrameStateSizeEntry(scalar=value, size=None)
        else:
            frame_state_entry = self.tx.output.frame_state[name]
            if frame_state_entry.scalar != value:
                log.debug(
                    "automatic dynamic int %s val %s != %s",
                    name,
                    value,
                    frame_state_entry.scalar,
                )
                # A different value was seen for this name: stop treating
                # it as a fixed scalar.
                frame_state_entry.scalar = None
        self.tx.output.frame_state[name] = frame_state_entry
        # TODO: This should be dynamic, as we in general do not
        # know if bare integers are actually going to be sizevars
        # and it is inappropriate to eagerly duck size them with
        # real sizevars
        if (
            config.automatic_dynamic_shapes and frame_state_entry.scalar is None
        ) or not config.assume_static_by_default:
            dynamic_dim = DimDynamic.DYNAMIC
        else:  # assume_static_by_default
            # TODO: dynamic_dim = DimDynamic.STATIC should work but
            # for some reason it doesn't
            self.install_guards(GuardBuilder.CONSTANT_MATCH)
            return ConstantVariable.create(value=value)
        wrapped_value = shape_env.create_unspecified_symint_and_symbol(
            value,
            source=self.source,
            dynamic_dim=dynamic_dim,
        )
        self.tx.output.bound_symbols.add(wrapped_value.node.expr)
        self.tx.output.tracked_fakes.append(
            TrackedFake(wrapped_value, self.source, None)
        )
    else:
        assert is_constant_source(self.get_source())
        # TODO: Do I actually need guard for constant source?
        self.install_guards(GuardBuilder.CONSTANT_MATCH)
        return ConstantVariable.create(value=value, source=self.source)
    assert not isinstance(self.get_source(), RandomValueSource)
    install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH))
    options = {"source": self.get_source()}
    proxy = self.tx.output.root_tracer.create_graph_input(
        re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
        type(wrapped_value),
        source=self.get_source(),
    )
    set_example_value(proxy.node, wrapped_value)
    unspec_var = SymNodeVariable(proxy, wrapped_value, **options)
    self.tx.output.unspec_variable_map[self.name] = unspec_var
    if not is_constant_source(self.get_source()):
        if self.tx.export and not isinstance(self.get_source(), LocalSource):
            raise AssertionError(
                f"Dynamo attempts to add additional input during export: value={wrapped_value}, source={self.get_source()}"
            )
        proxy.node.meta["grapharg"] = GraphArg(
            self.get_source(),
            wrapped_value,
            pass_arg_as_tensor=False,
            fake_tensor=None,
            is_tensor=False,
            example_strong_ref=wrapped_value,
        )
    return unspec_var
def wrap_symfloat(self, value):
    """Wrap a Python ``float`` as a SymFloat traced from a float64 graph input.

    Returns a CONSTANT_MATCH-guarded ``ConstantVariable`` when
    ``config.specialize_float`` is set, the source is a constant source, or
    the value is NaN. Otherwise the float is lifted as a float64 tensor graph
    input (an ``UnspecializedPythonVariable``, cached per ``self.name``) and
    ``.item()`` is called on it to produce the returned SymFloat variable.

    Raises:
        AssertionError: during export, when the float would become a new
            graph input whose source is not a ``LocalSource``.
    """
    # SymFloat wrapping is special. We first wrap it in the same way we
    # do an unspecialized primitive, and then we item() it into a
    # SymFloat. Removal of the item() call is left to a later FX pass,
    # mostly because that pass is more easily done after we have lowered
    # to ATen ops. (Dynamo doesn't do decomposition right now).
    if self.name in self.tx.output.unspec_variable_map:
        return self.tx.output.unspec_variable_map[self.name]
    # NB: we specialize on nan input, because our guard modeling in
    # ShapeEnv cannot deal with nan
    if (
        torch._dynamo.config.specialize_float
        or is_constant_source(self.get_source())
        or math.isnan(value)
    ):
        self.install_guards(GuardBuilder.CONSTANT_MATCH)
        return ConstantVariable.create(value=value, source=self.source)
    # NB: At the point we've gotten here, we don't assume static by
    # default. Since we have a guard mechanism, there isn't really any
    # downside to trying to be dynamic for float all the time. Unlike
    # ints, this won't make codegen perf worse. Modest cost to compile
    # time.
    wrapped_value = torch.tensor(value, dtype=torch.float64)
    # TODO: Switch RandomValueSource over to use this, this is more
    # accurate
    assert not isinstance(self.get_source(), RandomValueSource)
    install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH))
    # The FloatTensorSource here is just for pedantic correctness: if you
    # guard against an UnspecializedPythonVariable, you need to guard
    # against the tensor-ified version of the local, otherwise it's not a
    # Tensor. However, we never let the UnspecializedPythonVariable escape
    # here, so there should never actually be any guards against this
    # source.
    options = {"source": FloatTensorSource(self.get_source()), "raw_value": value}
    # TODO: Maybe the tensor-ification should be built into the source,
    # rather than by special pattern match
    proxy = self.tx.output.root_tracer.create_graph_input(
        re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
        type(wrapped_value),
        source=self.get_source(),
    )
    unspec_var = wrap_fx_proxy_cls(
        UnspecializedPythonVariable,
        tx=self.tx,
        proxy=proxy,
        example_value=wrapped_value,
        **options,
    )
    assert isinstance(unspec_var, UnspecializedPythonVariable)
    self.tx.output.unspec_variable_map[self.name] = unspec_var
    if self.tx.export and not isinstance(self.get_source(), LocalSource):
        raise AssertionError(
            f"Dynamo attempts to add additional input during export: value={wrapped_value}, source={self.get_source()}"
        )
    fake_tensor_value = unspec_var.proxy.node.meta["example_value"]
    assert is_fake(fake_tensor_value)
    # Both halves are f-strings so the tracer's fake mode is actually
    # interpolated into the message.
    assert fake_tensor_value.fake_mode is self.tx.fake_mode, (
        f"fake mode ({fake_tensor_value.fake_mode}) from fake tensor metadata doesn't match mode"
        f" ({self.tx.fake_mode}) from InstructionTranslator"
    )
    # There's something a bit incoherent about pass_arg_as_tensor,
    # specifically regarding sources.
    #
    # Specifically, suppose we have "x: float" local argument. We
    # eventually end up with an UnspecializedPythonVariable denoting
    # torch.as_tensor(x)... but it's source is still L['x'] (which if you
    # accessed it directly is a float!) So you gotta be careful when
    # setting up your guards, because it's still going to be a float at
    # this point, the conversion happens only precisely at the point we're
    # actually calling the FX graph. This happens to be what we want for
    # shape guard generation, but it's kind of unintuitive.
    proxy.node.meta["grapharg"] = GraphArg(
        self.get_source(),
        wrapped_value,
        pass_arg_as_tensor=True,
        fake_tensor=fake_tensor_value,
        is_tensor=False,
        example_strong_ref=wrapped_value,
    )
    # Directly do item to bypass capture_scalar_outputs
    r = wrap_fx_proxy(
        self.tx,
        self.tx.output.create_proxy(
            "call_method",
            "item",
            *proxy_args_kwargs([unspec_var], {}),
        ),
    )
    self.tx.output.tracked_fakes.append(TrackedFake(r.sym_num, self.source, None))
    return r
def wrap_unspecialized_primitive(self, value):
    """Wrap a Python scalar ``value`` as an UnspecializedPythonVariable graph input.

    The result is memoized per ``self.name`` in
    ``self.tx.output.unspec_variable_map``, so repeated lookups of the same
    name return the same variable.

    A TYPE_MATCH guard is installed on the source unless it is a
    RandomValueSource. A graph input (placeholder) is created whose example
    value is ``torch.tensor(value)``. For non-constant sources, a GraphArg with
    ``pass_arg_as_tensor=True`` is attached to that placeholder's metadata.

    Raises:
        AssertionError: under export, if the source is not a LocalSource (i.e.
            Dynamo would have to add an extra graph input).
    """
    if self.name in self.tx.output.unspec_variable_map:
        return self.tx.output.unspec_variable_map[self.name]

    # presumably ``value`` is a Python number; it is materialized as a tensor
    # so it can be fed to the graph as a tensor input.
    wrapped_value = torch.tensor(value)
    if not isinstance(self.get_source(), RandomValueSource):
        install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH))
    options = {"source": self.get_source(), "raw_value": value}

    # Graph input names must be valid identifiers; collapse anything else to "_".
    proxy = self.tx.output.root_tracer.create_graph_input(
        re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
        type(wrapped_value),
        source=self.get_source(),
    )

    unspec_var = wrap_fx_proxy_cls(
        UnspecializedPythonVariable,
        tx=self.tx,
        proxy=proxy,
        example_value=wrapped_value,
        **options,
    )
    self.tx.output.unspec_variable_map[self.name] = unspec_var
    if not is_constant_source(self.get_source()):
        if self.tx.export and not isinstance(self.get_source(), LocalSource):
            raise AssertionError(
                f"Dynamo attempts to add additional input during export: value={wrapped_value}, source={self.get_source()}"
            )
        if isinstance(unspec_var, ConstantVariable):
            # TODO: when can this happen?
            example_value = unspec_var.value
        else:
            example_value = unspec_var.proxy.node.meta["example_value"]
        assert is_fake(example_value)

        fake_tensor_value = example_value
        # Both halves of the message are f-strings so the translator's fake
        # mode is actually interpolated into the error text.
        assert fake_tensor_value.fake_mode is self.tx.fake_mode, (
            f"fake mode ({fake_tensor_value.fake_mode}) from fake tensor metadata doesn't match mode "
            f"({self.tx.fake_mode}) from InstructionTranslator"
        )

        proxy.node.meta["grapharg"] = GraphArg(
            self.get_source(),
            wrapped_value,
            pass_arg_as_tensor=True,
            fake_tensor=fake_tensor_value,
            is_tensor=False,
            example_strong_ref=wrapped_value,
        )
    return unspec_var
def _dataclasses_fields_lambda(obj):
    """Model ``dataclasses.fields(...)`` on a traced dataclass object or class.

    ``obj`` may be a UserDefinedObjectVariable (an instance; its ``value`` is
    inspected) or a DataClassVariable (its ``user_cls`` is inspected). Any
    other kind is reported through ``unimplemented``.

    Returns a TupleVariable of UserDefinedObjectVariable, one per field. When
    ``obj`` has a source, each element's source is
    ``obj.__dataclass_fields__[field.name]``; otherwise it is None.
    """
    if isinstance(obj, UserDefinedObjectVariable):
        target = obj.value
    elif isinstance(obj, DataClassVariable):
        target = obj.user_cls
    else:
        unimplemented(f"Dataclass fields handling fails for type {obj}")

    def _field_source(field_name):
        # Sourceless objects yield sourceless field variables.
        if not obj.source:
            return None
        return GetItemSource(
            AttrSource(obj.source, "__dataclass_fields__"), field_name
        )

    return TupleVariable(
        [
            UserDefinedObjectVariable(fld, source=_field_source(fld.name))
            for fld in dataclasses.fields(target)
        ]
    )
def wrap_fx_proxy(
    tx, proxy, example_value=None, subclass_type=None, **options
) -> VariableTracker:
    """Wrap ``proxy`` as a tensor-like VariableTracker via ``wrap_fx_proxy_cls``.

    With no ``subclass_type``, the target class is TensorVariable. Otherwise
    it is TensorWithTFOverrideVariable, and ``install_global(tx)`` is invoked
    on the resulting variable before it is returned. All other keyword
    ``options`` are forwarded unchanged.
    """
    is_subclass = subclass_type is not None
    target_cls = TensorWithTFOverrideVariable if is_subclass else TensorVariable
    result = wrap_fx_proxy_cls(
        target_cls=target_cls,
        tx=tx,
        proxy=proxy,
        example_value=example_value,
        subclass_type=subclass_type,
        **options,
    )
    if is_subclass:
        result.install_global(tx)
    return result
# Note: Unfortunate split due to some gross classes existing that subclass TensorVariable
# Should be compositional instead
#
# This is a horribly complicated function that does too many things. To
# explain what it does, let's first talk about the classic usage of wrap_fx_proxy
# for a TensorVariable. There are two primary modes of use:
#
# 1. Wrapping a pre-existing Tensor. In this case, example_value is set
# to the pre-existing Tensor. (Note that this example_value will NOT
# be the final example_value we put into node.meta['example_value'],
# instead it is converted into a fake tensor using
# wrap_to_fake_tensor_and_record and registered as a graph input.)
#
# 2. "Wrapping" the result of some Tensor operation Dynamo traced over. In
# this case, example_value is None (and we are going to figure it out
# ourselves using FakeTensors, via get_fake_value, which will run
# the operation represented by the (singular!) FX node referenced by
# the passed in proxy.)
#
# The expectation is you end up with a Tensor output, and everything is
# straightforwardly traced into the graph.
#
# In all cases, the returned `TensorVariable` subclass will have an `example_value`
# and that `example_value` must be a `FakeTensor` produced by the currently running
# instance of Dynamo.
#
# Upon closer inspection, you may notice that there is a slurry of non-Tensor
# output cases. What gives? Well, we sometimes trace operations into the
# graph that don't involve tensors.
#
# * Some operators return tuples; we need to recursively handle their
# contents
#
# * Some operators have side effects that will affect subsequent AOTAutograd
# tracing but don't otherwise return anything.
#
# * Some operators return symbolic ints/floats/bools which can go in the
# graph and be traced (but only if they're actually symbolic! If they're
# static you don't want to put them in the graph, which means you
# shouldn't call this function.)
#
# The common theme is that you only use this function WHEN YOU ARE TRACING
# SOMETHING INTO THE GRAPH. This is sort of obvious, because you can't call
# this function without a proxy.
def wrap_fx_proxy_cls(
    target_cls, tx, proxy, example_value=None, subclass_type=None, **options
):
    """Turn an FX ``proxy`` into the appropriate VariableTracker.

    ``example_value`` handling:
      * None: it is computed by running the proxy's node under fake tensors
        (``get_fake_value``).
      * already a value from ``tx.fake_mode``: used as-is (recursive calls).
      * a real tensor: it is fakeified and recorded via
        ``wrap_to_fake_tensor_and_record``. This path requires
        ``options["source"]`` to be set.

    Tensor results become ``target_cls``. Tuples and lists are unpacked
    element-wise through ``operator.getitem`` proxies. A fixed set of
    non-tensor results (SymInt/SymFloat/SymBool, streams, events, selected
    int/bool-returning ops, ...) get dedicated variable types. Anything else
    is reported through ``unimplemented``.

    Raises:
        InternalTorchDynamoError: if a tensor example value does not belong to
            ``tx.fake_mode``.
    """
    from ..symbolic_convert import InstructionTranslatorBase

    assert isinstance(tx, InstructionTranslatorBase)
    if "guards" in options and options["guards"] is not None:
        tx.output.guards.update(options["guards"])

    # Each node must be wrapped at most once; its example_value is set below.
    assert "example_value" not in proxy.node.meta, f"{proxy.node.meta['example_value']}"

    def _clone_input(value):
        if isinstance(value, torch.Tensor):
            # tensor subclasses will not be converted to FakeTensors and need to be cloned
            if not (
                isinstance(value, FakeTensor)
                or (
                    # Is functional tensor fakeified by this instance of Dynamo
                    torch._is_functional_tensor(value)
                    and maybe_get_fake_mode(value) is tx.fake_mode
                )
                or value.is_nested
            ):
                # NB: ensure strides are preserved
                value = clone_input(value)

        return value

    # with preserve_rng_state():
    if example_value is None:
        # only allow_non_graph_fake in this instance because we handle the non-fake
        # cases properly below.
        example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)

    # Handle recursive calls here
    elif maybe_get_fake_mode(example_value) is tx.fake_mode:
        pass

    elif isinstance(example_value, torch.Tensor):
        if tx.export:
            # The legacy behavior for real value cache with subclasses was
            # to perform a clone WITHOUT preserving the subclass. It's
            # not entirely clear this is what you actually want though.
            with torch._C.DisableTorchFunctionSubclass():
                proxy.tracer.real_value_cache[proxy.node] = _clone_input(example_value)
        # NB: If we're ignoring subclass, then the expectation is you will
        # take the returned TensorVariable and wrap it into a more
        # accurate TensorVariable that is able to track subclass-ness;
        # otherwise this is wrong!
        kwargs = {
            "is_tensor": target_cls in (TensorVariable, TensorWithTFOverrideVariable),
        }
        assert "source" in options and options["source"] is not None
        kwargs["source"] = options["source"]
        example_value = wrap_to_fake_tensor_and_record(example_value, tx=tx, **kwargs)

    if isinstance(example_value, torch.Tensor) and (
        maybe_get_fake_mode(example_value) is not tx.fake_mode
    ):
        raise InternalTorchDynamoError(
            "`example_value` needs to be a `FakeTensor` "
            f"wrapped by this instance of Dynamo. Found: {example_value}"
        )

    if isinstance(example_value, torch.Tensor):
        is_parameter = isinstance(example_value, torch.nn.Parameter)

        # NB: In most (all?) cases, this does not actually do a clone.
        # (WARNING: this means that if we mutate metadata on the fake
        # tensor, the stored example value will update too!)
        example_value = _clone_input(example_value)
        set_example_value(proxy.node, example_value)
        specialized_props = target_cls.specialize(example_value)
        # TODO: not sure about this fake mode test
        if (
            isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor)
            and example_value.fake_mode is tx.fake_mode
        ):
            tensor_type = subclass_type if subclass_type else torch.Tensor
            specialized_props["class_type"] = (
                torch.nn.Parameter if is_parameter else tensor_type
            )

        options.update(specialized_props)
        return target_cls(proxy, **options)
    elif (
        hasattr(proxy.node.target, "__name__")
        and proxy.node.target.__name__ == "set_state"
        and isinstance(proxy.node.target.__self__, torch._C.Generator)
        or proxy.node.target == torch.random.set_rng_state
    ):
        return TorchInGraphFunctionVariable(proxy.node.target)
    elif (
        proxy.node.target == torch._C._DisableFuncTorch
        or proxy.node.target == torch.cuda._is_in_bad_fork
    ):
        return UserDefinedObjectVariable(example_value)
    elif istype(example_value, torch.Size) and all(
        isinstance(x, int) for x in example_value
    ):
        sizes = [ConstantVariable.create(x) for x in example_value]
        return SizeVariable(sizes, **options)
    elif isinstance(example_value, (tuple, list)):
        set_example_value(proxy.node, example_value)
        unpacked = []
        for i, val in enumerate(example_value):
            if val is None:
                # nn.MultiheadAttention() can return None, see issue #175
                unpacked.append(
                    ConstantVariable.create(None, **options),
                )
            else:
                proxy_i = proxy.tracer.create_proxy(
                    kind="call_function",
                    target=operator.getitem,
                    args=(proxy, i),
                    kwargs={},
                )

                if "source" in options:
                    source = options["source"]
                    options_i = options.copy()
                    options_i["source"] = GetItemSource(
                        base=source, index=i, index_is_slice=False
                    )
                else:
                    # use the same options object as parent
                    options_i = options

                # WARNING: this assumes the same target_cls as this tuple/list call
                unpacked.append(
                    wrap_fx_proxy_cls(
                        target_cls=target_cls,
                        tx=tx,
                        proxy=proxy_i,
                        example_value=val,
                        **options_i,
                    )
                )
        if isinstance(example_value, torch.Size):
            # NB: Keep the old proxy around. See SizeVariable for an
            # explanation why
            return SizeVariable(unpacked, proxy, **options)
        elif istype(example_value, tuple):
            return TupleVariable(unpacked, **options)
        elif istype(example_value, (list, immutable_list)):
            return ListVariable(unpacked, mutable_local=MutableLocal(), **options)
        else:
            assert example_value.__class__.__module__ == "torch.return_types" or hasattr(
                example_value, "_fields"
            ), f"expected {example_value.__class__.__module__} == torch.return_types or named tuple but got {type(example_value)}"
            return NamedTupleVariable(unpacked, example_value.__class__, **options)
    elif example_value is None or proxy.node.target is torch.manual_seed:
        return ConstantVariable.create(None, **options)
    elif isinstance(example_value, (torch.SymInt, torch.SymFloat, torch.SymBool)):
        set_example_value(proxy.node, example_value)
        return SymNodeVariable(proxy, example_value, **options)
    elif (
        inspect.isclass(proxy.node.target)
        and issubclass(proxy.node.target, _StreamBase)
    ) or proxy.node.target in [
        device_interface.current_stream
        for _, device_interface in get_registered_device_interfaces()
    ]:
        set_example_value(proxy.node, example_value)
        return StreamVariable(proxy, example_value, example_value.device, **options)
    elif (
        inspect.isclass(proxy.node.target) and issubclass(proxy.node.target, _EventBase)
    ) or proxy.node.target in [
        device_interface.Event
        for _, device_interface in get_registered_device_interfaces()
    ]:
        set_example_value(proxy.node, example_value)
        return EventVariable(proxy, example_value, **options)
    elif proxy.node.target == "query" and proxy.node.op == "call_method":
        set_example_value(proxy.node, example_value)
        return ConstantVariable(example_value, **options)
    elif (
        example_value is not None
        and isinstance(example_value, _EventBase)
        and proxy.node.target == "record_event"
        and proxy.node.op == "call_method"
    ):
        set_example_value(proxy.node, example_value)
        return EventVariable(proxy, example_value, **options)
    elif isinstance(example_value, int) and proxy.node.target in [
        torch.sym_int,
        getattr,
        operator.getitem,
        torch._utils._element_size,
        torch.seed,
        operator.mod,
        torch._functorch.vmap._validate_and_get_batch_size,
        # some mac builds are missing torch.distributed.get_rank()
        getattr(torch.distributed, "get_rank", _missing),
        getattr(torch.distributed, "get_world_size", _missing),
        # This always wants to be in the graph, even if the constraint
        # results in a constant int
        torch._constrain_as_size,
    ]:
        set_example_value(proxy.node, example_value)
        return ConstantVariable.create(example_value, **options)
    elif isinstance(example_value, torch.backends.cuda.SDPAParams):
        from .sdpa import SDPAParamsVariable

        set_example_value(proxy.node, example_value)
        return SDPAParamsVariable(proxy, **options)
    elif isinstance(example_value, bool) and proxy.node.target in [
        torch.backends.cuda.can_use_flash_attention,
        torch.backends.cuda.can_use_efficient_attention,
    ]:
        set_example_value(proxy.node, example_value)
        return ConstantVariable.create(example_value, **options)
    elif (
        isinstance(example_value, (int, float, bool))
        and proxy.node.target is call_torchbind
    ):
        set_example_value(proxy.node, example_value)
        return ConstantVariable.create(example_value, **options)
    else:
        unimplemented(
            "torch.* op returned non-Tensor "
            + f"{typestr(example_value)} {proxy.node.op} {proxy.node.target}"
        )
# Tracks the sources of all fake tensors we wrap in Dynamo.
# Used by shape guard computation.
@dataclasses.dataclass
class TrackedFake:
fake: Union[FakeTensor, SymInt]
source: Source
# Is None when fake is SymInt
symbolic_context: Optional[SymbolicContext]
def __hash__(self) -> int:
return hash((self.fake, self.source.name()))
def __eq__(self, other: object) -> bool:
if isinstance(other, TrackedFake):
return self.fake is other.fake and self.source.name() == other.source.name()
return False
# Performs automatic dynamic dim determination.
# Returns a SymbolicContext
def _automatic_dynamic(
    e, tx, source, static_shapes, outer_only=False
) -> SymbolicContext:
    """Decide, per dimension of tensor ``e``, whether it is STATIC, DUCK or DYNAMIC.

    Args:
        e: the tensor being fakeified.
        tx: the instruction translator; ``tx.output`` supplies the tracing
            context, frame state, export constraints and shape env.
        source: Source of ``e``; ``source.name()`` keys the frame state.
        static_shapes: when true, every dimension is STATIC (early return).
        outer_only: when true, a traceable wrapper subclass is treated as a
            plain tensor (no recursion into its inner tensors).

    Returns:
        A SubclassSymbolicContext for traceable wrapper subclasses (when
        ``outer_only`` is False), otherwise a StatefulSymbolicContext. Views
        also carry a context computed for their ``_base``.

    Side effects (only on the final, non-early-return path):
        * ``tx.output.frame_state[name]`` is created or updated. Sizes that
          differ from previously seen sizes, and dims marked dynamic, are
          set to None there.
        * debug names from export constraints are recorded in
          ``tx.output.shape_env.source_name_to_debug_name``.

    Strided (non-NestedTensor) nested tensors are reported via
    ``unimplemented`` (presumably raising).
    """
    # strided NT not supported
    if e.is_nested and not isinstance(
        e, torch.nested._internal.nested_tensor.NestedTensor
    ):
        unimplemented("torch.compile does not support strided NestedTensor")

    name = source.name()
    # Reuse the symbol cache from a previously computed context for this tensor, if any.
    prior_policy = tx.output.tracing_context.tensor_to_context.get(e, None)
    shape_env_to_source_to_symbol_cache = (
        prior_policy.shape_env_to_source_to_symbol_cache if prior_policy else None
    )

    # Get base context if the tensor is a view
    view_base_context: Optional[SymbolicContext] = None
    if e._is_view():
        base_source = AttrSource(source, "_base")
        view_base_context = _automatic_dynamic(e._base, tx, base_source, static_shapes)

    if is_traceable_wrapper_subclass(e) and not outer_only:
        # Get symbolic context for outer tensor
        outer_context = _automatic_dynamic(
            e, tx, source, static_shapes, outer_only=True
        )

        # Get symbolic contexts for inner tensors
        attrs, _ = type(e).__tensor_flatten__(e)
        inner_contexts = {}  # mapping from attr -> symbolic context
        for attr in attrs:
            inner_tensor = getattr(e, attr)
            inner_source = AttrSource(source, attr)
            inner_context = _automatic_dynamic(
                inner_tensor, tx, inner_source, static_shapes
            )
            inner_contexts[attr] = inner_context

        return SubclassSymbolicContext(
            dynamic_sizes=outer_context.dynamic_sizes,
            constraint_sizes=outer_context.constraint_sizes,
            view_base_context=view_base_context,
            tensor_source=outer_context.tensor_source,
            shape_env_to_source_to_symbol_cache=outer_context.shape_env_to_source_to_symbol_cache,
            inner_contexts=inner_contexts,
        )

    if static_shapes:
        return StatefulSymbolicContext(
            dynamic_sizes=[DimDynamic.STATIC] * e.dim(),
            constraint_sizes=[None] * e.dim(),
            view_base_context=view_base_context,
            tensor_source=source,
            shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache,
        )

    # We preserve the dynamism of inputs. For example, when users call
    # make_fx(torch.cond, tracing_mode="symbolic")(*args), inputs have SymInt sizes.
    from torch.fx.experimental.symbolic_shapes import is_nested_int

    if any(isinstance(s, SymInt) and not is_nested_int(s) for s in e.size()):
        return StatefulSymbolicContext(
            dynamic_sizes=[
                DimDynamic.DYNAMIC if isinstance(s, SymInt) else DimDynamic.STATIC
                for s in e.size()
            ],
            constraint_sizes=[None] * e.dim(),
            view_base_context=view_base_context,
            tensor_source=source,
            shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache,
        )

    # Prep for automatic dynamic
    frame_state_entry = None
    if name not in tx.output.frame_state:
        # If there is no entry for this source, add the tensor to frame state with its current static size.
        # E.g., {} -> {"x": [2, 4]}
        frame_state_entry = FrameStateSizeEntry(None, None)
        frame_state_entry.size = list(e.size())
    else:
        frame_state_entry = tx.output.frame_state[name]
        if frame_state_entry.size is not None:
            if e.ndim != len(frame_state_entry.size):
                # If there is already an entry, and the dim mismatches, replace the frame state entry with None.
                # E.g. {"x": [2, 3, 4]} -> {"x": None}
                log.debug(
                    "automatic dynamic %s dim %s != %s",
                    name,
                    e.ndim,
                    frame_state_entry.size,
                )
                frame_state_entry.size = None
            else:
                # If there is already an entry, and the dim matches, for every size in the frame state which
                # disagrees with the current static size, replace it with None. E.g., {"x": [2, 3]} -> {"x": [2, None]}
                for i, dim in enumerate(frame_state_entry.size):
                    if dim is not None and e.size()[i] != dim:
                        log.debug(
                            "automatic dynamic %s size(%s) %s != %s",
                            name,
                            i,
                            e.size(i),
                            dim,
                        )
                        frame_state_entry.size[i] = None

    # TODO: index export_constraints ahead of time so we don't have to
    # do a linear scan every time here
    t_id = id(e)
    # Maps dim index -> (constraint range, debug name) for this tensor.
    dim2constraint = {}

    def update_dim2constraint(dim, constraint_range, debug_name):
        # Multiple constraints on the same dim are intersected.
        if dim in dim2constraint:
            from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint

            old_constraint_range, old_debug_name = dim2constraint[dim]
            new_constraint_range = StrictMinMaxConstraint(
                vr=constraint_range.vr & old_constraint_range.vr,
                warn_only=False,
            )
            # It is possible for (non-None) old_debug_name and debug_name to be different
            # but this will only happen when the corresponding Dims can be derived equal.
            new_debug_name = old_debug_name or debug_name
            dim2constraint[dim] = new_constraint_range, new_debug_name
        else:
            dim2constraint[dim] = constraint_range, debug_name

    if tx.output.export_constraints:
        for constraint in tx.output.export_constraints:
            if constraint.t_id == t_id:
                update_dim2constraint(
                    constraint.dim, constraint.constraint_range, constraint.debug_name
                )
            if constraint.shared is not None and constraint.shared.t_id == t_id:
                # We process constraint ranges for each shared dimension separately
                # so that we can directly check range constraint violations on them
                # without looking up which other shared dimensions have this info.
                # In other words, for this t_id, we will have processed all of its
                # constraint ranges, no matter where / how they were specified, by
                # the end of this loop.
                update_dim2constraint(
                    constraint.shared.dim,
                    constraint.constraint_range,
                    constraint.debug_name,
                )

    dynamic_dims = []
    constraint_dims = []
    for i in range(e.dim()):
        # NB: mark dynamic has precedence over static
        marked_dynamic = i in getattr(e, "_dynamo_dynamic_indices", set())
        marked_weak_dynamic = i in getattr(e, "_dynamo_weak_dynamic_indices", set())
        marked_static = i in getattr(e, "_dynamo_static_indices", set())

        # NB: both static and dynamic markings have precedence over automatic_dynamic
        automatic_dynamic = config.automatic_dynamic_shapes and (
            frame_state_entry.size is None or frame_state_entry.size[i] is None
        )

        # Reflect the user directive in the frame_state
        # For dynamic, apply None always
        if frame_state_entry.size and marked_dynamic:
            log.debug("automatic dynamic %s marked dynamic", name)
            frame_state_entry.size[i] = None

        # We will process constraints first, as they will imply that we
        # have a dynamic dimension
        # Precedence: export constraints > eager constraints
        constraint = dim2constraint.get(i)
        if constraint is None:
            if marked_dynamic and not config.allow_ignore_mark_dynamic:
                if hasattr(e, "_dynamo_dynamic_range"):
                    # NOTE(review): assumes a range entry exists for every
                    # marked-dynamic dim; pop() raises IndexError otherwise.
                    dim_range = [
                        dr for dr in e._dynamo_dynamic_range if dr.dim == i
                    ].pop()
                    if dim_range.min is None and dim_range.max is None:
                        constraint_dim = RelaxedUnspecConstraint(warn_only=False)
                    else:
                        from torch.fx.experimental.symbolic_shapes import (
                            StrictMinMaxConstraint,
                        )

                        constraint_dim = StrictMinMaxConstraint(
                            vr=ValueRanges(lower=dim_range.min, upper=dim_range.max),
                            warn_only=False,
                        )
                else:
                    constraint_dim = RelaxedUnspecConstraint(warn_only=False)

            elif not marked_static and automatic_dynamic:
                # warn_only=True: an automatically-dynamic dim is not a hard user requirement.
                constraint_dim = RelaxedUnspecConstraint(warn_only=True)
            else:
                constraint_dim = None
        else:
            constraint_dim, debug_name = constraint
            if debug_name is not None:
                dim_name = f"{name}.size()[{i}]"
                tx.output.shape_env.source_name_to_debug_name[dim_name] = debug_name
        constraint_dims.append(constraint_dim)

        # Now, figure out if the dim is dynamic/duck/static
        if (
            constraint_dim is not None
            or marked_dynamic
            or marked_weak_dynamic
            or is_nested_int(e.shape[i])
        ):
            # NB: We could assert static_shapes is False here, but it
            # seems better to allow the user to override symbolic_context in this
            # case
            dynamic = DimDynamic.DYNAMIC
        elif static_shapes or config.assume_static_by_default or marked_static:
            dynamic = DimDynamic.STATIC
        else:
            dynamic = DimDynamic.DUCK

        dynamic_dims.append(dynamic)

    # Persist the (possibly updated) size observations for the next compile of this frame.
    tx.output.frame_state[name] = frame_state_entry

    return StatefulSymbolicContext(
        dynamic_sizes=dynamic_dims,
        constraint_sizes=constraint_dims,
        view_base_context=view_base_context,
        tensor_source=source,
        shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache,
    )
# See note [Tensor Fakification and Symbol Caching]
def wrap_to_fake_tensor_and_record(
    e, tx, *, source: Optional[Source], is_tensor: bool, parent_context=None
):
    """Fakeify tensor ``e`` under ``tx.fake_mode`` and record bookkeeping for it.

    Non-tensor values are returned unchanged. For tensors this:
      * picks a symbolic context, either via ``_automatic_dynamic`` or, when
        ``parent_context`` is given (recursion into a wrapper subclass's inner
        tensors), from ``parent_context.inner_contexts[source.member]``;
      * converts ``e`` with ``tx.fake_mode.from_tensor``;
      * appends TrackedFake entries (for a memoized ``item()`` value and, when
        applicable, for the fake tensor itself);
      * recurses into inner tensors of traceable wrapper subclasses;
      * records the context in ``tracing_context.tensor_to_context`` and the
        size/stride info in ``input_source_to_sizes_strides``.

    Returns the fake tensor (or ``e`` unchanged for non-tensors).
    """
    is_fakeifiable = (
        type(e) in (torch.Tensor, torch.nn.Parameter, FakeTensor)
        or isinstance(e, torch.Tensor)
        or is_traceable_wrapper_subclass(e)
    )
    if not is_fakeifiable:
        return e

    assert source is not None
    static_shapes, _reason = tensor_always_has_static_shape(
        e, is_tensor, guard_source=source.guard_source()
    )

    if parent_context:
        # Parent contexts are passed in when we are recursively creating
        # fake tensors for subclasses. A better design would be not to create a
        # parent/child relationship, but to recursively call _automatic_dynamic
        # as we recursively call wrap_to_fake_tensor_and_record. This runs
        # into bugs around how meta_utils knows and works to create fake tensors
        # with tensor subclasses. Ideally, dynamo would drive both the recursive
        # wrap_to_fake_tensor_and_record and _automatic_dynamic policy creation.
        assert isinstance(source, AttrSource)
        symbolic_context = parent_context.inner_contexts[source.member]
    else:
        symbolic_context = _automatic_dynamic(e, tx, source, static_shapes)

    log.debug(
        "wrap_to_fake %s %s %s %s",
        source.name(),
        tuple(e.shape),
        symbolic_context,
        type(e),
    )

    fake_e = wrap_fake_exception(
        lambda: tx.fake_mode.from_tensor(
            e,
            source=source,
            symbolic_context=symbolic_context,
        )
    )

    if source is not None and isinstance(fake_e, FakeTensor):
        item_sym = fake_e.item_memo
        if item_sym is not None:
            tx.output.tracked_fakes.append(
                TrackedFake(item_sym, CallMethodItemSource(source), symbolic_context)
            )

    if is_traceable_wrapper_subclass(fake_e):
        inner_attrs, _ = fake_e.__tensor_flatten__()
        for attr_name in inner_attrs:
            fake_inner = getattr(fake_e, attr_name)
            wrap_to_fake_tensor_and_record(
                getattr(e, attr_name),
                tx,
                source=AttrSource(source, attr_name),
                is_tensor=isinstance(fake_inner, torch.Tensor),
                parent_context=symbolic_context,
            )

    tx.output.tracing_context.tensor_to_context[e] = symbolic_context

    if is_sparse_any(fake_e):
        # TODO: for TensorGuards, this eventually may need more
        # fields for the size/stride of any other constituents
        values = fake_e._values() if fake_e.is_sparse else fake_e.values()
        sizes_strides = {
            "size": fake_e.size(),
            # TODO: revise this, but for now this stride instead of ()
            # avoids SegFault with PYTORCH_TEST_WITH_DYNAMO=1
            "stride": (1,) * fake_e.ndim,
            "values_size": values.size(),
            "values_stride": values.stride(),
        }
    else:
        sizes_strides = {
            "size": fake_e.size(),
            "stride": fake_e.stride(),
        }
    tx.output.input_source_to_sizes_strides[source] = sizes_strides

    should_track = (
        is_tensor
        and not (static_shapes and source.is_nn_module())
        and not is_constant_source(source)
    )
    if should_track:
        tx.output.tracked_fakes.append(
            TrackedFake(fake_e, source, symbolic_context)
        )
        tx.output.tracked_fakes_id_to_source[id(e)].append(source)

    return fake_e
class SourcelessBuilder:
    """
    Like builder, but stateless and does not require a source. Useful for simple type->VT objects, or objects
    that are being created/evaporated during inlining (ex: consider a locally made list of tensors we then iterate
    over). Such a list should not show up as an artifact from inputs, nor in reconstruction, nor in the graph.
    However, there may be reasons to represent it as a ListVariable internally.

    NOTE - Objects produced here are born UNGUARDED due to the nature of sources!

    NOTE - This class is very new! It will have some rough edges, but it was created to stem the bleeding of giant
    if/else type->VariableTracker trees that were cropping up all over dynamo.
    """

    def __init__(self):
        raise AssertionError("Use SourcelessBuilder.create()")

    @staticmethod
    def create(tx, value) -> VariableTracker:
        """Build a sourceless VariableTracker for ``value``.

        Exact-type matches in ``_type_handlers`` are tried first. Otherwise the
        value is classified by isinstance/callable checks. Unsupported types
        are reported through ``unimplemented``.
        """
        value_type = type(value)
        fast_handler = SourcelessBuilder._type_handlers.get(value_type)
        if fast_handler:
            return fast_handler(tx, value)

        if isinstance(value, VariableTracker):
            # This is always valid to call, and useful for recursive calls.
            return value
        elif isinstance(value, dataclasses._HAS_DEFAULT_FACTORY_CLASS):
            return UserDefinedObjectVariable(value)
        elif ConstantVariable.is_literal(value):
            return ConstantVariable.create(value)
        elif (
            callable(value)
            and (callable_vt_cls := trace_rules.lookup_callable(value)) is not None
        ):
            # Look the rule up once and reuse it, rather than querying
            # trace_rules twice for the same value.
            if is_callable_allowed(value):
                tx.output.has_user_defined_allowed_in_graph = True
            return callable_vt_cls(value)
        elif is_function_or_wrapper(value):
            return trace_rules.lookup(value)(value)
        elif isinstance(value, enum.Enum):
            return EnumVariable(value)
        elif isinstance(value, (type, abc.ABCMeta)):
            return UserDefinedClassVariable(value)
        elif isinstance(value, types.MethodWrapperType):
            return MethodWrapperVariable(value)
        elif isinstance(value, torch.fx.graph_module.GraphModule):
            return SourcelessGraphModuleVariable(value)
        elif isinstance(
            value, (torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec)
        ):
            return UserDefinedObjectVariable(value)
        elif PlacementVariable.is_placement(value):
            return PlacementVariable(value)
        elif DeviceMeshVariable.is_device_mesh(value):
            return DeviceMeshVariable(value)
        elif isinstance(value, re.Pattern):
            return RegexPatternVariable(value)
        unimplemented(
            f"Unexpected type in sourceless builder {value_type.__module__}.{value_type.__qualname__}"
        )

    @staticmethod
    def wrap_constant_literal(value):
        """Wrap a literal (per ``ConstantVariable.is_literal``) as a ConstantVariable."""
        assert ConstantVariable.is_literal(value)
        return ConstantVariable.create(value=value)

    @staticmethod
    def make_type_handlers():
        """Build the exact-type -> handler(tx, value) dispatch table used by ``create``."""
        create = SourcelessBuilder.create
        handlers = {}
        for t in common_constant_types:
            handlers[t] = lambda tx, value: ConstantVariable(value)
        handlers[set] = lambda tx, value: SetVariable(
            [create(tx, x) for x in value], mutable_local=MutableLocal()
        )
        handlers[dict] = lambda tx, value: ConstDictVariable(
            {create(tx, k): create(tx, v) for k, v in value.items()},
            type(value),
            mutable_local=MutableLocal(),
        )
        handlers[list] = lambda tx, value: ListVariable(
            [create(tx, x) for x in value], mutable_local=MutableLocal()
        )
        handlers[tuple] = lambda tx, value: TupleVariable(
            [create(tx, x) for x in value]
        )
        handlers[torch.Size] = lambda tx, value: SizeVariable(
            [create(tx, x) for x in value]
        )
        handlers[collections.OrderedDict] = handlers[dict]
        handlers[immutable_dict] = handlers[dict]
        handlers[immutable_list] = handlers[list]
        handlers[types.ModuleType] = lambda tx, value: PythonModuleVariable(value)

        def passthrough(tx, value):
            return value

        # Existing VariableTrackers are returned as-is.
        for cls in VariableTrackerMeta.all_subclasses:
            handlers[cls] = passthrough
        return handlers


SourcelessBuilder._type_handlers = SourcelessBuilder.make_type_handlers()
|