File size: 41,113 Bytes
c61ccee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
import dataclasses
import functools
import inspect
import logging
import re
import time
import warnings
from contextlib import contextmanager, nullcontext
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import torch
import torch._dynamo
import torch.fx

import torch.utils._pytree as pytree
from torch._dynamo.exc import UserError, UserErrorType
from torch._export.non_strict_utils import (
    make_constraints,
    make_fake_inputs,
    make_fake_params_buffers,
)
from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
    _AddRuntimeAssertionsForInlineConstraintsPass,
)
from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass
from torch._export.passes.lift_constants_pass import (
    ConstantAttrMap,
    lift_constants_pass,
    rewrite_script_object_meta,
)
from torch._export.wrappers import _wrap_submodules
from torch._functorch.aot_autograd import aot_export_module
from torch._guards import detect_fake_mode
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch._utils_internal import log_export_usage
from torch.export.exported_program import OutputKind
from torch.fx.experimental.symbolic_shapes import (
    ConstraintViolationError,
    free_unbacked_symbols,
    GuardOnDataDependentSymNode,
    ShapeEnv,
)
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from torch.utils._sympy.value_ranges import ValueRangeError

from ._safeguard import AutogradStateOpsFailSafeguard

from .dynamic_shapes import _process_constraints, Constraint
from .exported_program import (
    _disable_prexisiting_fake_mode,
    ExportedProgram,
    InputKind,
    ModuleCallEntry,
    ModuleCallSignature,
)
from .graph_signature import (
    _sig_to_specs,
    ArgumentSpec,
    ConstantArgument,
    CustomObjArgument,
    ExportGraphSignature,
    SymIntArgument,
    TensorArgument,
)


# Module-level logger, named after this module.
log = logging.getLogger(__name__)


@dataclasses.dataclass
class ExportDynamoConfig:
    """Export-specific overrides for ``torch._dynamo.config``.

    Field names are Dynamo config option names: an instance is converted with
    ``dataclasses.asdict`` and handed to ``torch._dynamo.config.patch``.
    """

    # Dynamo config option ``allow_rnn``; enabled during export.
    allow_rnn: bool = True
    # Dynamo config option ``reorderable_logging_functions``. Each instance
    # gets its own fresh, empty set by default.
    reorderable_logging_functions: Set[Callable] = dataclasses.field(
        default_factory=set
    )


# Dynamo settings used by export: the dataclass defaults, with the standard
# logging functions, ``print`` and ``warnings.warn`` registered as
# reorderable logging functions.
DEFAULT_EXPORT_DYNAMO_CONFIG = ExportDynamoConfig(
    reorderable_logging_functions={
        print,
        warnings.warn,
        logging.log,
        logging.debug,
        logging.info,
        logging.warning,
        logging.error,
        logging.exception,
        logging.critical,
    }
)


@contextmanager
def _ignore_backend_decomps():
    orig_mkldnn_flag = torch.backends.mkldnn.set_flags(False)
    orig_nnpack_flag = torch.backends.nnpack.set_flags(False)
    try:
        yield
    finally:
        torch.backends.mkldnn.set_flags(*orig_mkldnn_flag)
        torch.backends.nnpack.set_flags(*orig_nnpack_flag)


def _convert_input_to_fake(gm, args, kwargs):
    """Build fake counterparts of ``args``, ``kwargs`` and ``gm``'s params/buffers.

    Positional tensor arguments are replaced, in order, by the tensors stored
    in the ``"val"`` metadata of ``gm``'s placeholder nodes. The real tensors
    are ignored; placeholders are consumed sequentially. Keyword tensors and
    parameters/buffers are converted with ``from_tensor`` of the fake mode
    detected from those placeholder values. If none is detected, a new
    ``FakeTensorMode`` with its own ``ShapeEnv`` is used. Parameters and
    buffers are converted with static shapes.

    Returns:
        ``(fake_args, fake_kwargs, fake_params_buffers, fake_mode)``. When both
        ``args`` and ``kwargs`` are empty, ``((), {}, params_buffers, fake_mode)``
        is returned with the *real* params/buffers.
    """
    params_buffers = _get_params_buffers(gm)
    placeholder_fakes: List[torch.Tensor] = [
        node.meta["val"]
        for node in gm.graph.nodes
        if node.op == "placeholder"
        and isinstance(node.meta.get("val"), torch.Tensor)
    ]

    fake_mode = detect_fake_mode(placeholder_fakes) or FakeTensorMode(
        shape_env=ShapeEnv()
    )

    if len(args) == 0 and len(kwargs) == 0:
        return (), {}, params_buffers, fake_mode

    cursor = 0

    def _take_next_placeholder_fake(_unused_real_tensor):
        # Hand out placeholder fakes in graph order, one per tensor leaf.
        nonlocal cursor
        fake = placeholder_fakes[cursor]
        cursor += 1
        return fake

    fake_args = pytree.tree_map_only(
        torch.Tensor, _take_next_placeholder_fake, args
    )
    # TODO properly use the cached fake tensor
    fake_kwargs = pytree.tree_map_only(torch.Tensor, fake_mode.from_tensor, kwargs)
    static_from_tensor = functools.partial(fake_mode.from_tensor, static_shapes=True)
    fake_params_buffers = pytree.tree_map_only(
        torch.Tensor, static_from_tensor, params_buffers
    )
    return fake_args, fake_kwargs, fake_params_buffers, fake_mode


def _replace_param_buffer_names(param_buffer_table, sig):
    for spec in sig.input_specs:
        if spec.kind in (
            InputKind.PARAMETER,
            InputKind.BUFFER,
        ):
            spec.target = param_buffer_table[spec.target]
    for spec in sig.output_specs:
        if spec.kind in (
            OutputKind.BUFFER_MUTATION,
            OutputKind.GRADIENT_TO_PARAMETER,
        ):
            spec.target = param_buffer_table[spec.target]


def _convert_to_positional_args(orig_arg_names, args, kwargs):
    assert len(orig_arg_names) == len(args) + len(kwargs), (
        f"Total number of arg names is expected to be {len(orig_arg_names)} "
        f"but got {len(args)} positional args, {len(kwargs)} kwargs."
    )
    reordered_kwargs = [kwargs[kw_name] for kw_name in orig_arg_names[len(args) :]]
    return (
        *args,
        *reordered_kwargs,
    )


def _normalize_nn_module_stack(gm_torch_level, root_cls):
    # Append a root module to every nn_module_stack.
    root = "L['self']"
    root_key = re.sub(r"[^a-zA-Z0-9]", "_", root)
    for gm in gm_torch_level.modules():
        if not isinstance(gm, torch.fx.GraphModule):
            continue
        for node in gm.graph.nodes:
            if node.op in ["placeholder", "output"]:
                continue
            add_root = True
            if nn_module_stack := node.meta.get("nn_module_stack", {}):
                path, ty = next(iter(nn_module_stack.values()))
                # After deserializing the class `ty` might not exist anymore so
                # it could be a string
                if inspect.isclass(ty) and issubclass(ty, torch.nn.Module):
                    # TODO Figure out why sometimes we have root sometimes we don't.
                    if path == root and ty is root_cls:
                        add_root = False
                else:
                    assert isinstance(ty, str)
            if add_root:

                def normalize_path(path):
                    try:
                        parts = []

                        class Path:
                            def __getattr__(self, name):
                                parts.append(name)
                                return self

                            def __getitem__(self, idx):
                                parts.append(str(idx))
                                return self

                        eval(path, {"L": {"self": Path()}})
                        return ".".join(parts)
                    except Exception:  # TODO(zhxchen17) Remove this.
                        return path

                nn_module_stack = {root_key: (root, root_cls), **nn_module_stack}
                node.meta["nn_module_stack"] = {
                    key: (normalize_path(path), ty)
                    for key, (path, ty) in nn_module_stack.items()
                }


def _get_param_buffer_mapping(

    original_module: torch.nn.Module,

    traced_module: torch.nn.Module,

) -> Dict[str, str]:
    """

    Returns a mapping of parameter/buffer names from the new module to the

    original model. This is to help with restoring the FQN for parameter/buffers

    of a traced module to what the original module contains.

    """

    param_lookup: Dict[int, List[str]] = {}
    buffer_lookup: Dict[int, List[str]] = {}
    for name, param in original_module.named_parameters(remove_duplicate=False):
        param_lookup.setdefault(id(param), []).append(name)
    for name, buffer in original_module.named_buffers(remove_duplicate=False):
        buffer_lookup.setdefault(id(buffer), []).append(name)

    param_buffer_table: Dict[str, str] = {}
    for dynamo_name, dynamo_param in traced_module.named_parameters(
        remove_duplicate=False
    ):
        assert dynamo_name not in param_buffer_table
        if id(dynamo_param) in param_lookup:
            param_buffer_table[dynamo_name] = param_lookup[id(dynamo_param)].pop()

    for dynamo_name, dynamo_buffer in traced_module.named_buffers(
        remove_duplicate=False
    ):
        assert dynamo_name not in param_buffer_table
        if id(dynamo_buffer) in buffer_lookup:
            param_buffer_table[dynamo_name] = buffer_lookup[id(dynamo_buffer)].pop()

    return param_buffer_table


def _remap_constants(

    orig_constant_attrs: ConstantAttrMap,

    graph_signature: ExportGraphSignature,

    constants: Dict[str, Union[torch.Tensor, torch.ScriptObject]],

) -> None:
    """Rewrite the graph signature and constants table to use the FQN from the original module."""
    remap_table: Dict[str, str] = {}
    for name, value in constants.items():
        if value in orig_constant_attrs:
            remap_table[name] = orig_constant_attrs[value]

    for spec in graph_signature.input_specs:
        if spec.kind in (
            InputKind.CONSTANT_TENSOR,
            InputKind.CUSTOM_OBJ,
        ):
            orig_target = spec.target
            assert orig_target is not None
            spec.target = remap_table.get(orig_target, orig_target)

            constant = constants[orig_target]
            del constants[orig_target]
            constants[spec.target] = constant


def _restore_state_dict(
    original_module: torch.nn.Module, traced_module: torch.fx.GraphModule
) -> None:
    """Restore the state dict of the traced module to that of the original module.

    Each traced parameter/buffer that maps to an original FQN (see
    ``_get_param_buffer_mapping``) is re-registered under that FQN, with "."
    replaced by "_". Plain tensors become buffers and parameters stay
    parameters. ``get_attr`` nodes are retargeted to the new names, and the
    module is recompiled. ``traced_module`` is modified in place.
    """
    # Since the graph module is flattened (no module hierarchy), we need to
    # normalize each FQN by replacing "." with "_". If we don't, it will try
    # to save the weight to a submodule which no longer exists.
    param_buffer_table = {
        name: fqn.replace(".", "_")
        for name, fqn in _get_param_buffer_mapping(original_module, traced_module).items()
    }

    # Replace state dict attr names with the fqn
    for name, fqn in param_buffer_table.items():
        # Skip attributes already stored under their final name: re-registering
        # and then deleting `name` would remove the attribute altogether.
        if name == fqn or not hasattr(traced_module, name):
            continue

        attr = getattr(traced_module, name)
        if isinstance(attr, torch.Tensor) and not isinstance(attr, torch.nn.Parameter):
            traced_module.register_buffer(fqn, attr)
        else:
            setattr(traced_module, fqn, attr)
        delattr(traced_module, name)

    # Replace graph getattr nodes with the correct name
    for node in traced_module.graph.nodes:
        if node.op == "get_attr" and node.target in param_buffer_table:
            node.target = param_buffer_table[node.target]

    traced_module.recompile()


def _export_to_torch_ir(
    f: Callable,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    constraints: Optional[List[Constraint]] = None,
    *,
    preserve_module_call_signature: Tuple[str, ...] = (),
    disable_constraint_solver: bool = False,
    restore_fqn: bool = True,
    _log_export_usage: bool = True,
) -> torch.fx.GraphModule:
    """Trace ``f`` (an nn.Module's forward or a plain callable) into a torch-IR GraphModule.

    Tracing uses ``torch._dynamo.export`` with ``assume_static_by_default=True``
    and ``tracing_mode="symbolic"``, under the Dynamo overrides from
    ``DEFAULT_EXPORT_DYNAMO_CONFIG``. The call specs collected for
    ``preserve_module_call_signature`` are stored in
    ``gm.meta["module_call_specs"]``. If ``f`` is an ``nn.Module`` and
    ``restore_fqn`` is true, the original parameter/buffer FQNs are restored.

    Raises:
        UserError: ``INVALID_INPUT`` if ``args`` is not a tuple;
            ``CONSTRAINT_VIOLATION`` on constraint or value-range violations;
            ``ANTI_PATTERN`` when tracing guards on a data-dependent symbol.
    """
    if _log_export_usage:
        log_export_usage(event="export.private_api", flags={"_export_to_torch_ir"})

    if not kwargs:
        kwargs = {}

    if not isinstance(args, tuple):
        raise UserError(
            UserErrorType.INVALID_INPUT,
            f"Expecting `args` to be a tuple of example positional inputs, got {type(args)}",
        )

    module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {}
    dynamo_overrides = dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)
    with torch._dynamo.config.patch(dynamo_overrides):
        try:
            with _wrap_submodules(
                f, preserve_module_call_signature, module_call_specs
            ), _ignore_backend_decomps():
                exporter = torch._dynamo.export(
                    f,
                    constraints=constraints,  # type: ignore[arg-type]
                    assume_static_by_default=True,
                    tracing_mode="symbolic",
                    disable_constraint_solver=disable_constraint_solver,
                    _log_export_usage=_log_export_usage,
                )
                gm_torch_level, _ = exporter(*args, **kwargs)
        except (ConstraintViolationError, ValueRangeError) as e:
            raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: TRY200
        except GuardOnDataDependentSymNode as e:
            raise UserError(  # noqa: TRY200
                UserErrorType.ANTI_PATTERN,
                f"Consider annotating your code using torch._constrain_as_*(). {str(e)}",
                case_name="constrain_as_size_example",
            )

    gm_torch_level.meta["module_call_specs"] = module_call_specs

    if restore_fqn and isinstance(f, torch.nn.Module):
        _restore_state_dict(f, gm_torch_level)

    return gm_torch_level


def _gather_constant_attrs(m: torch.nn.Module) -> ConstantAttrMap:
    """Search the module hierarchy, gathering up all tensor and ScriptObject constants.



    Returns a dictionary mapping hash(value) to the name of the constant. We

    have to abuse `hash` here unfortunately, see: [ScriptObject hash].

    """
    constants = ConstantAttrMap()
    buffers_parameters = set(m.buffers())
    buffers_parameters.update(m.parameters())

    def inner(m: torch.nn.Module, prefix_atoms: List[str], constants):
        for k, v in m.__dict__.items():
            if isinstance(v, (torch.Tensor, torch.ScriptObject)):
                if v in buffers_parameters:
                    # filter out buffers and parameters, leaving only constants
                    continue

                fqn = ".".join(prefix_atoms + [k])
                if v in constants:
                    raise ValueError(
                        f"Duplicate reference to constant attribute found: '{constants[v]}' and '{fqn}'."
                    )

                constants[v] = fqn
        for k, v in m.named_children():
            inner(v, prefix_atoms + [k], constants)

    inner(m, [], constants)
    return constants


def _export_non_strict(

    mod: torch.nn.Module,

    fake_args,

    fake_kwargs,

    fake_params_buffers,

    constant_attrs: ConstantAttrMap,

    *,

    transform=lambda x: x,  # TODO(zhxchen17) Revisit if this is needed later.

    pre_dispatch=False,

):
    """Trace ``mod`` with ``aot_export_module`` (non-strict export path).

    Args:
        mod: module to trace. While tracing, its parameters/buffers are replaced
            by ``fake_params_buffers`` via ``_reparametrize_module``.
        fake_args: positional example inputs forwarded to ``aot_export_module``.
        fake_kwargs: keyword example inputs forwarded to ``aot_export_module``.
        fake_params_buffers: replacement parameters/buffers for ``mod``.
            Presumably a name -> fake tensor dict, e.g. as built by
            ``_get_params_buffers`` plus fakification; verify against callers.
        constant_attrs: original constant attributes, forwarded to
            ``lift_constants_pass``.
        transform: wrapper applied to ``aot_export_module`` before it is called
            (identity by default).
        pre_dispatch: forwarded to ``aot_export_module``. When true,
            ``replace_set_grad_with_hop_pass`` is also run on the result.

    Returns:
        An object with fields ``gm`` (the traced GraphModule), ``sig`` (an
        ``ExportGraphSignature`` built from the AOT graph signature) and
        ``constants``. ``constants`` holds the entries from
        ``rewrite_script_object_meta`` updated with those returned by
        ``lift_constants_pass``.
    """
    # [NOTE] If the user is exporting under training mode, we want to detect if there is any
    # state change in the autograd global state and error. If the user is exporting under inference
    # mode, we don't care.
    is_grad_enabled = torch._C.is_grad_enabled()
    grad_safe_guard = (
        AutogradStateOpsFailSafeguard() if is_grad_enabled else nullcontext()
    )

    # Sets `torch.compiler._is_compiling_flag` to True for the duration of the
    # trace and restores the previous value afterwards, even on error.
    @contextmanager
    def _compiling_state_context():
        old_value = torch.compiler._is_compiling_flag
        try:
            torch.compiler._is_compiling_flag = True
            yield
        finally:
            torch.compiler._is_compiling_flag = old_value

    # This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode,
    # otherwise aot_export_module will error out because it sees a mix of fake_modes.
    # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.
    with torch.nn.utils.stateless._reparametrize_module(
        mod, fake_params_buffers
    ), grad_safe_guard, _ignore_backend_decomps(), _compiling_state_context():  # type: ignore[attr-defined]
        gm, graph_signature = transform(aot_export_module)(
            mod,
            fake_args,
            trace_joint=False,
            pre_dispatch=pre_dispatch,
            kwargs=fake_kwargs,
        )
    # TODO unfortunately preserving graph-level metadata is not
    # working well with aot_export. So we manually copy it.
    # (The node-level meta is addressed above.)
    if isinstance(mod, torch.fx.GraphModule) and hasattr(mod, "meta"):
        gm.meta.update(mod.meta)

    if pre_dispatch:
        from torch._export.passes.replace_set_grad_with_hop_pass import (
            replace_set_grad_with_hop_pass,
        )

        gm = replace_set_grad_with_hop_pass(gm)

    # NOTE: aot_export adds symint metadata for placeholders with int values;
    # since these become specialized, we replace such metadata with the original values
    flat_args = pytree.tree_leaves((fake_args, fake_kwargs))
    index = 0
    # Number of leading placeholders that are not user inputs (params, buffers,
    # input tokens). NOTE(review): the index arithmetic below assumes all of
    # these precede the user-input placeholders in graph order.
    total_non_user_inputs = (
        len(graph_signature.parameters)
        + len(graph_signature.buffers)
        + len(graph_signature.input_tokens)
    )
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            if index >= total_non_user_inputs:
                # Match the placeholder positionally with the flattened user input.
                user_arg = flat_args[index - total_non_user_inputs]
                if not isinstance(user_arg, torch.Tensor):
                    node.meta["val"] = user_arg
            index += 1

    # True when AOT produced a backward signature. With trace_joint=False above
    # this is presumably always False here; NOTE(review) confirm.
    is_joint = graph_signature.backward_signature is not None

    # Classify a graph input/output: Python scalars and None become
    # ConstantArgument; nodes are classified by the type of their "val" meta.
    def make_argument_spec(node) -> ArgumentSpec:
        if isinstance(node, (int, bool, float, type(None))):
            # For const outputs we just directly return this
            return ConstantArgument(value=node)

        assert (
            "val" in node.meta
        ), f"{node} is not a constant or a node with a 'val' metadata field"
        val = node.meta["val"]
        if isinstance(val, FakeTensor):
            return TensorArgument(name=node.name)
        elif isinstance(val, torch.SymInt):
            return SymIntArgument(name=node.name)
        elif isinstance(val, torch.ScriptObject):
            return CustomObjArgument(
                name=node.name, class_fqn=val._type().qualified_name()  # type: ignore[attr-defined]
            )
        else:
            # TODO: this branch is likely wrong, all permissible ConstantArgument type
            # should have been handled already
            return ConstantArgument(value=val)

    input_specs, output_specs = _sig_to_specs(
        user_inputs=set(graph_signature.user_inputs),
        inputs_to_parameters=graph_signature.inputs_to_parameters,  # type: ignore[arg-type]
        inputs_to_buffers=graph_signature.inputs_to_buffers,  # type: ignore[arg-type]
        user_outputs=set(graph_signature.user_outputs),  # type: ignore[arg-type]
        buffer_mutations=graph_signature.buffers_to_mutate,  # type: ignore[arg-type]
        user_input_mutations=graph_signature.user_inputs_to_mutate,  # type: ignore[arg-type]
        grad_params=graph_signature.backward_signature.gradients_to_parameters if is_joint else {},  # type: ignore[arg-type, union-attr]
        grad_user_inputs=graph_signature.backward_signature.gradients_to_user_inputs if is_joint else {},  # type: ignore[arg-type, union-attr]
        loss_output=graph_signature.backward_signature.loss_output if is_joint else None,  # type: ignore[arg-type, union-attr]
        inputs=[
            make_argument_spec(node)
            for node in gm.graph.nodes
            if node.op == "placeholder"
        ],
        # Flattened args of the graph's last node (its output node).
        outputs=[
            make_argument_spec(node)
            for node in pytree.tree_leaves(next(iter(reversed(gm.graph.nodes))).args)
        ],
        input_tokens=graph_signature.input_tokens,
        output_tokens=graph_signature.output_tokens,
    )
    export_graph_signature = ExportGraphSignature(
        input_specs=input_specs, output_specs=output_specs
    )

    constants = rewrite_script_object_meta(gm)
    constants.update(lift_constants_pass(gm, export_graph_signature, constant_attrs))

    # Lightweight result container returned to the caller.
    @dataclasses.dataclass
    class _ExportedProgramNonStrict:
        gm: torch.fx.GraphModule
        sig: ExportGraphSignature
        constants: Dict[str, Union[torch.Tensor, torch._C.ScriptObject]]

    return _ExportedProgramNonStrict(
        gm,
        export_graph_signature,
        constants,
    )


def _get_params_buffers(mod: torch.nn.Module) -> Dict[str, torch.Tensor]:
    params_buffers: Dict[str, torch.Tensor] = {}
    for name, param in mod.named_parameters(remove_duplicate=False):
        params_buffers[name] = param

    for name, buffer in mod.named_buffers(remove_duplicate=False):
        params_buffers[name] = buffer
    return params_buffers


def _rewrite_dynamo_tensor_constants(

    orig_mod_buffers: Set[torch.Tensor],

    traced_mod_buffers: Dict[str, torch.Tensor],

    graph_signature: ExportGraphSignature,

    constants: Dict[str, Union[torch.Tensor, torch.ScriptObject]],

):
    """Dynamo erroneously marks tensor attributes on modules as a buffers.



    Rewrite them to be tensor constants.

    """
    for spec in graph_signature.input_specs:
        if spec.kind == InputKind.BUFFER:
            assert spec.target is not None
            value = traced_mod_buffers[spec.target]
            if value not in orig_mod_buffers:
                # This was a tensor constant erroneously marked as a buffer.
                # Convert it int oa constant in the graph signature, and add its
                # value to the constants table.
                spec.kind = InputKind.CONSTANT_TENSOR
                constants[spec.target] = value


def _rewrite_non_persistent_buffers(

    orig_mod: torch.nn.Module,

    graph_signature: ExportGraphSignature,

    constants: Dict[str, Union[torch.Tensor, torch.ScriptObject]],

):
    """Dynamo erroneously drops the persistent flag on buffers.



    Rewrite non-persistent buffers to reflect the original module.

    """
    state_dict = orig_mod.state_dict()
    for spec in graph_signature.input_specs:
        if spec.kind == InputKind.BUFFER:
            assert spec.target is not None
            if spec.target not in state_dict:
                assert spec.target not in constants
                spec.persistent = False
                constants[spec.target] = orig_mod.get_buffer(spec.target)


def get_ep_stats(ep: ExportedProgram) -> Dict[str, Any]:
    op_count = 0
    op_set = set()
    for m in ep.graph_module.modules():
        if not isinstance(m, torch.fx.GraphModule):
            continue
        for node in m.graph.nodes:
            if node.op != "call_function":
                continue
            op_count += 1
            assert hasattr(node.target, "__module__")
            assert hasattr(node.target, "__name__")
            op_set.add(f"{node.target.__module__}.{node.target.__name__}")
    return {"op_count": op_count, "op_set": op_set}


_EXPORT_FLAGS: Optional[Set[str]] = None


def _log_export_wrapper(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        global _EXPORT_FLAGS
        try:
            start = time.time()
            ep = fn(*args, **kwargs)
            end = time.time()
            log_export_usage(
                event="export.time",
                metrics=end - start,
                flags=_EXPORT_FLAGS,
                **get_ep_stats(ep),
            )
        except Exception as e:
            t = type(e)
            error_type = t.__module__ + "." + t.__qualname__
            log_export_usage(
                event="export.error",
                type=error_type,
                message=str(e),
                flags=_EXPORT_FLAGS,
            )
            raise e
        finally:
            _EXPORT_FLAGS = None

        return ep

    return wrapper


@_log_export_wrapper
@_disable_prexisiting_fake_mode
def _export(

    mod: torch.nn.Module,

    args: Tuple[Any, ...],

    kwargs: Optional[Dict[str, Any]] = None,

    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,

    *,

    strict: bool = True,

    preserve_module_call_signature: Tuple[str, ...] = (),

    pre_dispatch: bool = False,

) -> ExportedProgram:
    """
    Traces either an nn.Module's forward function or just a callable with PyTorch
    operations inside and produces an ExportedProgram.

    Args:
        mod: the `nn.Module` to trace.

        args: example positional inputs.

        kwargs: optional example keyword inputs.

        dynamic_shapes:
         An optional argument where the type should either be:
         1) a dict from argument names of the traced function to their dynamic
            shape specifications,
         2) a tuple that specifies dynamic shape specifications for each input
            in original order.
         If you are specifying dynamism on keyword args, you will need to pass
         them in the order that is defined in the original function signature.

         The dynamic shape of a tensor argument can be specified as either
         (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
         not required to include static dimension indices in this dict, but when
         they are, they should be mapped to None; or (2) a tuple / list of
         :func:`Dim` types or None, where the :func:`Dim` types correspond to
         dynamic dimensions, and static dimensions are denoted by None.
         Arguments that are dicts or tuples / lists of tensors are recursively
         specified by using mappings or sequences of contained specifications.

        strict: if True (default), ``mod`` is first traced to a torch-level graph
            with ``_export_to_torch_ir`` and that graph is then lowered with
            ``_export_non_strict``. If False, that first step is skipped and
            ``mod`` itself is lowered with ``_export_non_strict`` on fake inputs
            built by ``make_fake_inputs``.

        preserve_module_call_signature: A list of submodule paths for which the
            original calling conventions are preserved as metadata.

        pre_dispatch: forwarded to ``_export_non_strict``. It is also recorded
            in the telemetry flags as "pre_dispatch" (True) or "aot_dispatch"
            (False).

    Returns:
        An ExportedProgram containing the traced method.
    """
    from .dynamic_shapes import _process_dynamic_shapes

    # Record the export mode for telemetry: `_log_export_wrapper` reads the
    # module-level `_EXPORT_FLAGS` when logging and resets it afterwards.
    global _EXPORT_FLAGS
    flags = set()
    flags.add("strict" if strict else "non_strict")
    flags.add("pre_dispatch" if pre_dispatch else "aot_dispatch")
    log_export_usage(event="export.enter", flags=flags)
    _EXPORT_FLAGS = flags

    # Note: this runs before `kwargs` is normalized below, so it may receive
    # kwargs=None.
    constraints = _process_dynamic_shapes(mod, args, kwargs, dynamic_shapes) or []

    kwargs = kwargs or {}

    constant_attrs = _gather_constant_attrs(mod)

    # `orig_in_spec` becomes the root entry's in_spec in the module call graph
    # in both branches; `flat_args` is used by the strict branch only.
    flat_args, orig_in_spec = pytree.tree_flatten((args, kwargs))

    if not strict:
        # ---- Non-strict path: lower `mod` directly on fake inputs. ----

        # Assigned (via `nonlocal`) by Wrapper.forward below when the wrapped
        # module is traced.
        out_spec = None

        # Filled with in/out specs of the preserved submodules while tracing
        # under `_wrap_submodules` below.
        module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {}

        def strip_root(x):
            # Remove the "_export_root" prefix (and one following ".") that the
            # Wrapper module introduces into FQNs; other values pass through.
            if isinstance(x, str) and x.startswith("_export_root"):
                stripped = x[len("_export_root") :]
                return stripped[1:] if stripped.startswith(".") else stripped
            return x

        def fixup_key(x):
            # nn_module_stack keys are re-rooted under "L__self__" once the
            # wrapper prefix is stripped.
            return "L__self__" + strip_root(x)

        def _tuplify_outputs(aot_export):
            # Adapts `aot_export` in four steps:
            #   1. wrap `mod` in a module whose forward returns a flat tuple of
            #      outputs;
            #   2. record the original output treespec in `out_spec`;
            #   3. run `aot_export` on that wrapper;
            #   4. strip the "_export_root" prefix from the resulting graph
            #      signature and nn_module_stack metadata.
            def _aot_export_non_strict(mod, args, kwargs=None, **flags):
                kwargs = kwargs or {}

                class Wrapper(torch.nn.Module):
                    def __init__(self, mod):
                        super().__init__()
                        self._export_root = mod

                    def forward(self, *args, **kwargs):
                        nonlocal out_spec
                        if isinstance(self._export_root, torch.fx.GraphModule):
                            # Interpret the graph so node metadata is preserved
                            # on the traced nodes.
                            with torch.fx.traceback.preserve_node_meta():
                                tree_out = torch.fx.Interpreter(self._export_root).run(
                                    *args, **kwargs
                                )
                        else:
                            tree_out = self._export_root(*args, **kwargs)
                        flat_outs, out_spec = pytree.tree_flatten(tree_out)
                        return tuple(flat_outs)

                wrapped_mod = Wrapper(mod)
                # Patch export_root to the signatures so that wrapper module correctly populates the
                # in/out spec
                new_preserved_call_signatures = [
                    "_export_root." + i for i in preserve_module_call_signature
                ]
                with _wrap_submodules(
                    wrapped_mod, new_preserved_call_signatures, module_call_specs
                ):
                    gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)

                # Undo the "_export_root" prefix so names match the original module.
                sig.parameters = pytree.tree_map(strip_root, sig.parameters)
                sig.buffers = pytree.tree_map(strip_root, sig.buffers)
                sig.inputs_to_buffers = pytree.tree_map(
                    strip_root, sig.inputs_to_buffers
                )
                sig.inputs_to_parameters = pytree.tree_map(
                    strip_root, sig.inputs_to_parameters
                )
                sig.buffers_to_mutate = pytree.tree_map(
                    strip_root, sig.buffers_to_mutate
                )
                for node in gm.graph.nodes:
                    if "nn_module_stack" in node.meta:
                        nn_module_stack = node.meta["nn_module_stack"]
                        node.meta["nn_module_stack"] = {
                            fixup_key(key): val
                            for key, val in pytree.tree_map(
                                strip_root, nn_module_stack
                            ).items()
                        }

                return gm, sig

            return _aot_export_non_strict

        (
            fake_mode,
            fake_args,
            fake_kwargs,
            equalities_inputs,
            original_signature,
        ) = make_fake_inputs(mod, args, kwargs, constraints)

        fake_params_buffers = make_fake_params_buffers(
            fake_mode, _get_params_buffers(mod)
        )
        # NOTE(review): `_tuplify_outputs` is passed as `transform`, presumably
        # applied to the underlying aot_export function — confirm in
        # `_export_non_strict`.
        with fake_mode:
            ep_non_strict = _export_non_strict(
                mod,
                fake_args,
                fake_kwargs,
                fake_params_buffers,
                constant_attrs,
                pre_dispatch=pre_dispatch,
                transform=_tuplify_outputs,
            )
        try:
            range_constraints = make_constraints(
                fake_mode,
                equalities_inputs,
                original_signature,
                ep_non_strict.gm,
            )
        except (ConstraintViolationError, ValueRangeError) as e:
            # Surface constraint problems as a user-facing error.
            raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: TRY200

        # Wrapper.forward must have run during tracing and set this.
        assert out_spec is not None

        gm = ep_non_strict.gm

        module_call_signatures = {
            strip_root(fqn): ModuleCallSignature(inputs=[], outputs=[], **specs)
            for fqn, specs in module_call_specs.items()
        }

        if len(preserve_module_call_signature) > 0:
            # Re-create each `_export_tracepoint` node with its "path" kwarg
            # stripped of the wrapper prefix (keeping only "path" and "kind").
            for node in gm.graph.nodes:
                if node.target == torch.ops.higher_order._export_tracepoint:
                    if "path" in node.kwargs:
                        path = strip_root(node.kwargs["path"])
                        with gm.graph.inserting_before(node):
                            new_node = gm.graph.create_node(
                                "call_function",
                                torch.ops.higher_order._export_tracepoint,
                                args=node.args,
                                kwargs={
                                    "path": path,
                                    "kind": node.kwargs["kind"],
                                },
                            )
                            node.replace_all_uses_with(new_node)
                            gm.graph.erase_node(node)

            res = CollectTracepointsPass(module_call_signatures, ep_non_strict.sig)(gm)
            assert res is not None
            gm = res.graph_module

        _rewrite_non_persistent_buffers(mod, ep_non_strict.sig, ep_non_strict.constants)
        return ExportedProgram(
            root=gm,
            graph=gm.graph,
            graph_signature=ep_non_strict.sig,
            state_dict=mod.state_dict(keep_vars=True),
            range_constraints=range_constraints,
            module_call_graph=[
                ModuleCallEntry(
                    "",
                    ModuleCallSignature(
                        inputs=[], outputs=[], in_spec=orig_in_spec, out_spec=out_spec
                    ),
                )
            ]
            + [
                ModuleCallEntry(fqn, sig) for fqn, sig in module_call_signatures.items()
            ],
            example_inputs=(args, kwargs),
            constants=ep_non_strict.constants,
        )

    # ---- Strict path: trace to a torch-level graph first, then lower it. ----
    gm_torch_level = _export_to_torch_ir(
        mod,
        args,
        kwargs,
        constraints,
        preserve_module_call_signature=preserve_module_call_signature,
        restore_fqn=False,  # don't need to restore because we will do it later
        _log_export_usage=False,
    )

    # We detect the fake_mode by looking at gm_torch_level's placeholders, this is the fake_mode created in dynamo.
    (
        fake_args,
        fake_kwargs,
        fake_params_buffers,
        dynamo_fake_mode,
    ) = _convert_input_to_fake(gm_torch_level, args, kwargs)

    # First, we want to pass through the graph to try populating
    # val field for getattr if there is anything missing.
    # This can happen when quantization adds extra params and forgets
    # to update "val"
    for node in gm_torch_level.graph.nodes:
        if node.op == "get_attr" and "val" not in node.meta:
            attr = getattr(gm_torch_level, node.target)
            # Checks if it is not a HigherOrderOp branch or a module
            if not isinstance(attr, torch.nn.Module):
                assert (
                    dynamo_fake_mode is not None
                ), "Cannot find dynamo_fake_mode. This could be due to the exported graph module have no placeholders."
                node.meta["val"] = dynamo_fake_mode.from_tensor(
                    attr, static_shapes=True
                )

    # When aot_export lifts the params, we lose the nn_module_stack
    # and source_fn from the param nodes as they are treated as fresh inputs
    # Therefore, we manually extract them before calling into aot_export
    params_buffers_to_node_meta = {}
    for node in gm_torch_level.graph.nodes:
        target = node.target
        meta = node.meta
        if node.op == "call_module":
            # Every param/buffer of the called submodule inherits this node's meta.
            submodule = getattr(gm_torch_level, target)
            if isinstance(submodule, torch.nn.Module):
                for name, _ in submodule.named_parameters(
                    recurse=True, remove_duplicate=False
                ):
                    params_buffers_to_node_meta[target + "." + name] = meta

                for name, _ in submodule.named_buffers(
                    recurse=True, remove_duplicate=False
                ):
                    params_buffers_to_node_meta[target + "." + name] = meta

        if node.op == "get_attr":
            submodule = getattr(gm_torch_level, target)
            if not isinstance(submodule, torch.fx.GraphModule):
                # Stored by reference: the loop below writes into this same dict.
                params_buffers_to_node_meta[target] = meta

        # If the call_function uses param as input, we also need to update params' meta
        # with this call_function node's meta.
        # This is basically the same flow as torch.fx.traceback.preserve_meta()
        if node.op == "call_function" and not isinstance(
            node.target, torch._ops.HigherOrderOperator
        ):
            for arg in node._input_nodes:
                if arg.op == "get_attr":
                    for entry in torch.fx.proxy._COPY_META_FIELDS:
                        if entry in meta:
                            params_buffers_to_node_meta[arg.target][entry] = meta[entry]

    # Fix the graph output signature to be tuple if scalar
    # (`orig_out_spec` keeps the user-visible spec for the module call graph.)
    out_spec = orig_out_spec = gm_torch_level._out_spec
    assert out_spec is not None
    # aot_export expect the return type to always be a tuple.
    if out_spec.type not in (list, tuple):
        out_spec = pytree.TreeSpec(tuple, None, [out_spec])

    orig_arg_names = gm_torch_level.graph._codegen.pytree_info.orig_args  # type: ignore[attr-defined]

    # Install the (possibly tuple-wrapped) out_spec in the graph's codegen.
    gm_torch_level.graph._codegen = _PyTreeCodeGen(
        _PyTreeInfo(
            orig_arg_names,
            gm_torch_level._in_spec,
            out_spec,
        )
    )
    gm_torch_level.recompile()

    _normalize_nn_module_stack(gm_torch_level, type(mod))

    # NOTE: graph module expects only positional args
    ep_non_strict = _export_non_strict(
        gm_torch_level,
        _convert_to_positional_args(orig_arg_names, fake_args, fake_kwargs),
        {},
        fake_params_buffers,
        constant_attrs,
        pre_dispatch=pre_dispatch,
    )

    gm = ep_non_strict.gm
    export_graph_signature = ep_non_strict.sig
    constants = ep_non_strict.constants

    # After aot_export, set the param/buffer metadata back into placeholders
    # Technically, users can still construct this data from param names
    # without relying on this metadata
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            if node.target in export_graph_signature.inputs_to_parameters:
                param_name = export_graph_signature.inputs_to_parameters[node.target]
                if param_name in params_buffers_to_node_meta:
                    for k, v in params_buffers_to_node_meta[param_name].items():
                        node.meta[k] = v
            if node.target in export_graph_signature.inputs_to_buffers:
                buffer_name = export_graph_signature.inputs_to_buffers[node.target]
                if buffer_name in params_buffers_to_node_meta:
                    for k, v in params_buffers_to_node_meta[buffer_name].items():
                        node.meta[k] = v

    # The unbacked symint symbols are updated in aot_export
    # so we serialize them here instead of inside dynamo

    gm.meta["inline_constraints"] = {
        k: v
        for k, v in dynamo_fake_mode.shape_env.var_to_range.items()
        if free_unbacked_symbols(k)
    }

    # Index of the first USER_INPUT spec (or the total number of specs if
    # there is none), i.e. how many lifted inputs precede the user inputs —
    # assumes lifted inputs are ordered first; TODO confirm.
    num_lifted = next(
        (
            i
            for i, s in enumerate(export_graph_signature.input_specs)
            if s.kind == InputKind.USER_INPUT
        ),
        len(export_graph_signature.input_specs),
    )
    range_constraints = _process_constraints(
        dynamo_fake_mode,
        gm,
        num_lifted,
        flat_args,
    )

    # Do some cleanups on the graph module to restore the state dict to the
    # expected form. Each of these steps should probably get fixed upstream.
    # 1. Remove tensor constants that were added as buffers.
    _rewrite_dynamo_tensor_constants(
        orig_mod_buffers=set(mod.buffers()),
        traced_mod_buffers=dict(gm_torch_level.named_buffers()),
        graph_signature=ep_non_strict.sig,
        constants=ep_non_strict.constants,
    )
    # 2. Restore FQN of param/buffers
    param_buffer_table: Dict[str, str] = _get_param_buffer_mapping(mod, gm_torch_level)
    _replace_param_buffer_names(param_buffer_table, export_graph_signature)

    # 3. Mark buffers that are absent from the original module's state_dict as
    #    non-persistent and record them in `constants`.
    _rewrite_non_persistent_buffers(mod, ep_non_strict.sig, ep_non_strict.constants)

    # 4. Rewrite constants to have the same FQN as the original module.
    _remap_constants(constant_attrs, export_graph_signature, constants)

    module_call_signatures = {
        fqn: ModuleCallSignature(inputs=[], outputs=[], **specs)
        for fqn, specs in gm_torch_level.meta["module_call_specs"].items()
    }

    if len(preserve_module_call_signature) > 0:
        res = CollectTracepointsPass(module_call_signatures, export_graph_signature)(gm)
        assert res is not None
        gm = res.graph_module

    assert orig_out_spec is not None
    exported_program = ExportedProgram(
        root=gm,
        graph=gm.graph,
        graph_signature=export_graph_signature,
        state_dict=mod.state_dict(keep_vars=True),
        range_constraints=range_constraints,
        module_call_graph=[
            ModuleCallEntry(
                "",
                ModuleCallSignature(
                    inputs=[], outputs=[], in_spec=orig_in_spec, out_spec=orig_out_spec
                ),
            )
        ]
        + [ModuleCallEntry(fqn, sig) for fqn, sig in module_call_signatures.items()],
        example_inputs=(args, kwargs),
        constants=constants,
    )
    log.debug("Exported program from AOTAutograd:\n%s", exported_program)

    # Add runtime assertions for the range constraints when there are any.
    if len(range_constraints) > 0:
        exported_program = exported_program._transform_do_not_use(
            _AddRuntimeAssertionsForInlineConstraintsPass(range_constraints)
        )

    return exported_program