#!/usr/bin/env python3
import threading
import typing
import warnings
from collections import defaultdict
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union

import torch
from captum._utils.common import (
    _reduce_list,
    _run_forward,
    _sort_key_list,
    _verify_select_neuron,
)
from captum._utils.sample_gradient import SampleGradientWrapper
from captum._utils.typing import (
    Literal,
    ModuleOrModuleList,
    TargetType,
    TensorOrTupleOfTensorsGeneric,
)
from torch import device, Tensor
from torch.nn import Module


def apply_gradient_requirements(
    inputs: Tuple[Tensor, ...], warn: bool = True
) -> List[bool]:
    """
    Iterates through tuple on input tensors and sets requires_grad to be true on
    each Tensor, and ensures all grads are set to zero. To ensure that the input
    is returned to its initial state, a list of flags representing whether or not
     a tensor originally required grad is returned.
    """
    assert isinstance(
        inputs, tuple
    ), "Inputs should be wrapped in a tuple prior to preparing for gradients"
    grad_required = []
    for index, input in enumerate(inputs):
        assert isinstance(input, torch.Tensor), "Given input is not a torch.Tensor"
        grad_required.append(input.requires_grad)
        inputs_dtype = input.dtype
        # Note: torch 1.2 doesn't support is_complex for dtype, which is why
        # we check for the existence of the is_complex attribute.
        if not inputs_dtype.is_floating_point and not (
            hasattr(inputs_dtype, "is_complex") and inputs_dtype.is_complex
        ):
            if warn:
                warnings.warn(
                    "Input Tensor %d has a dtype of %s. Gradients cannot be "
                    "activated for these data types." % (index, str(inputs_dtype))
                )
        elif not input.requires_grad:
            if warn:
                warnings.warn(
                    "Input Tensor %d did not already require gradients, "
                    "required_grads has been set automatically." % index
                )
            input.requires_grad_()
    return grad_required


def undo_gradient_requirements(
    inputs: Tuple[Tensor, ...], grad_required: List[bool]
) -> None:
    """
    Iterates through list of tensors, zeros each gradient, and sets required
    grad to false if the corresponding index in grad_required is False.
    This method is used to undo the effects of prepare_gradient_inputs, making
    grads not required for any input tensor that did not initially require
    gradients.
    """

    assert isinstance(
        inputs, tuple
    ), "Inputs should be wrapped in a tuple prior to preparing for gradients."
    assert len(inputs) == len(
        grad_required
    ), "Input tuple length should match gradient mask."
    for index, input in enumerate(inputs):
        assert isinstance(input, torch.Tensor), "Given input is not a torch.Tensor"
        if not grad_required[index]:
            input.requires_grad_(False)
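
# A minimal sketch (illustrative, not part of the public API) of the intended
# apply/undo round trip, assuming a float tensor `x` defined by the caller:
#
#     x = torch.ones(2, 3)                       # requires_grad is False
#     flags = apply_gradient_requirements((x,))  # flags == [False]
#     ...                                        # work needing input gradients
#     undo_gradient_requirements((x,), flags)    # x.requires_grad is False again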


def compute_gradients(
    forward_fn: Callable,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    target_ind: TargetType = None,
    additional_forward_args: Any = None,
) -> Tuple[Tensor, ...]:
    r"""
    Computes gradients of the output with respect to inputs for an
    arbitrary forward function.

    Args:

        forward_fn: forward function. This can be for example model's
                    forward function.
        inputs:     Input at which gradients are evaluated,
                    will be passed to forward_fn.
        target_ind: Index of the target class for which gradients
                    must be computed (classification only).
        additional_forward_args: Additional input arguments that forward
                    function requires. It takes an empty tuple (no additional
                    arguments) if no additional arguments are required
    """
    with torch.autograd.set_grad_enabled(True):
        # runs forward pass
        outputs = _run_forward(forward_fn, inputs, target_ind, additional_forward_args)
        assert outputs[0].numel() == 1, (
            "Target not provided when necessary, cannot"
            " take gradient with respect to multiple outputs."
        )
        # torch.unbind(outputs) yields a tuple of scalar tensors,
        # containing batch_size * #steps elements
        grads = torch.autograd.grad(torch.unbind(outputs), inputs)
    return grads
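
# A minimal usage sketch, assuming a hypothetical single-output model `net`
# (the names below are illustrative and are not defined in this module):
#
#     net = torch.nn.Linear(3, 1)
#     x = torch.randn(4, 3, requires_grad=True)
#     (grads,) = compute_gradients(net, x)  # grads has the same shape as x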


def _neuron_gradients(
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    saved_layer: Dict[device, Tuple[Tensor, ...]],
    key_list: List[device],
    gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
) -> Tuple[Tensor, ...]:
    with torch.autograd.set_grad_enabled(True):
        gradient_tensors = []
        for key in key_list:
            current_out_tensor = _verify_select_neuron(
                saved_layer[key], gradient_neuron_selector
            )
            gradient_tensors.append(
                torch.autograd.grad(
                    torch.unbind(current_out_tensor)
                    if current_out_tensor.numel() > 1
                    else current_out_tensor,
                    inputs,
                )
            )
        _total_gradients = _reduce_list(gradient_tensors, sum)
    return _total_gradients


@typing.overload
def _forward_layer_eval(
    forward_fn: Callable,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    layer: Module,
    additional_forward_args: Any = None,
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
    grad_enabled: bool = False,
) -> Tuple[Tensor, ...]:
    ...


@typing.overload
def _forward_layer_eval(
    forward_fn: Callable,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    layer: List[Module],
    additional_forward_args: Any = None,
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
    grad_enabled: bool = False,
) -> List[Tuple[Tensor, ...]]:
    ...


def _forward_layer_eval(
    forward_fn: Callable,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    layer: ModuleOrModuleList,
    additional_forward_args: Any = None,
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
    grad_enabled: bool = False,
) -> Union[Tuple[Tensor, ...], List[Tuple[Tensor, ...]]]:
    return _forward_layer_eval_with_neuron_grads(
        forward_fn,
        inputs,
        layer,
        additional_forward_args=additional_forward_args,
        gradient_neuron_selector=None,
        grad_enabled=grad_enabled,
        device_ids=device_ids,
        attribute_to_layer_input=attribute_to_layer_input,
    )
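
# Illustrative sketch: capturing a hidden layer's output for an input batch.
# `net` is a hypothetical two-layer model, not defined in this module:
#
#     net = torch.nn.Sequential(torch.nn.Linear(3, 5), torch.nn.ReLU())
#     x = torch.randn(4, 3)
#     (layer_out,) = _forward_layer_eval(net, x, layer=net[0])
#     # layer_out has shape (4, 5): the Linear output, before the ReLU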


@typing.overload
def _forward_layer_distributed_eval(
    forward_fn: Callable,
    inputs: Any,
    layer: ModuleOrModuleList,
    target_ind: TargetType = None,
    additional_forward_args: Any = None,
    attribute_to_layer_input: bool = False,
    forward_hook_with_return: Literal[False] = False,
    require_layer_grads: bool = False,
) -> Dict[Module, Dict[device, Tuple[Tensor, ...]]]:
    ...


@typing.overload
def _forward_layer_distributed_eval(
    forward_fn: Callable,
    inputs: Any,
    layer: ModuleOrModuleList,
    target_ind: TargetType = None,
    additional_forward_args: Any = None,
    attribute_to_layer_input: bool = False,
    *,
    forward_hook_with_return: Literal[True],
    require_layer_grads: bool = False,
) -> Tuple[Dict[Module, Dict[device, Tuple[Tensor, ...]]], Tensor]:
    ...


def _forward_layer_distributed_eval(
    forward_fn: Callable,
    inputs: Any,
    layer: ModuleOrModuleList,
    target_ind: TargetType = None,
    additional_forward_args: Any = None,
    attribute_to_layer_input: bool = False,
    forward_hook_with_return: bool = False,
    require_layer_grads: bool = False,
) -> Union[
    Tuple[Dict[Module, Dict[device, Tuple[Tensor, ...]]], Tensor],
    Dict[Module, Dict[device, Tuple[Tensor, ...]]],
]:
    r"""
    A helper function that allows to set a hook on model's `layer`, run the forward
    pass and returns intermediate layer results, stored in a dictionary,
    and optionally also the output of the forward function. The keys in the
    dictionary are the device ids and the values are corresponding intermediate layer
    results, either the inputs or the outputs of the layer depending on whether we set
    `attribute_to_layer_input` to True or False.
    This is especially useful when we execute forward pass in a distributed setting,
    using `DataParallel`s for example.
    """
    saved_layer: Dict[Module, Dict[device, Tuple[Tensor, ...]]] = defaultdict(dict)
    lock = threading.Lock()
    all_layers: List[Module] = [layer] if isinstance(layer, Module) else layer

    # Set a forward hook on specified module and run forward pass to
    # get layer output tensor(s).
    # For DataParallel models, each partition adds entry to dictionary
    # with key as device and value as corresponding Tensor.
    def hook_wrapper(original_module):
        def forward_hook(module, inp, out=None):
            eval_tsrs = inp if attribute_to_layer_input else out
            is_eval_tuple = isinstance(eval_tsrs, tuple)

            if not is_eval_tuple:
                eval_tsrs = (eval_tsrs,)
            if require_layer_grads:
                apply_gradient_requirements(eval_tsrs, warn=False)
            with lock:
                nonlocal saved_layer
                # Note that the cloning behaviour of `eval_tsrs` differs when
                # `forward_hook_with_return` is set to True: the original
                # tensors are saved and clones are returned from the hook,
                # since otherwise `backward()` on the last output layer
                # would not execute.
                if forward_hook_with_return:
                    saved_layer[original_module][eval_tsrs[0].device] = eval_tsrs
                    eval_tsrs_to_return = tuple(
                        eval_tsr.clone() for eval_tsr in eval_tsrs
                    )
                    if not is_eval_tuple:
                        eval_tsrs_to_return = eval_tsrs_to_return[0]
                    return eval_tsrs_to_return
                else:
                    saved_layer[original_module][eval_tsrs[0].device] = tuple(
                        eval_tsr.clone() for eval_tsr in eval_tsrs
                    )

        return forward_hook

    all_hooks = []
    try:
        for single_layer in all_layers:
            if attribute_to_layer_input:
                all_hooks.append(
                    single_layer.register_forward_pre_hook(hook_wrapper(single_layer))
                )
            else:
                all_hooks.append(
                    single_layer.register_forward_hook(hook_wrapper(single_layer))
                )
        output = _run_forward(
            forward_fn,
            inputs,
            target=target_ind,
            additional_forward_args=additional_forward_args,
        )
    finally:
        for hook in all_hooks:
            hook.remove()

    if len(saved_layer) == 0:
        raise AssertionError("Forward hook did not obtain any outputs for given layer")

    if forward_hook_with_return:
        return saved_layer, output
    return saved_layer
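
# An illustrative sketch of the returned structure for a standard
# (non-distributed) model; the names below are not defined in this module:
#
#     net = torch.nn.Sequential(torch.nn.Linear(3, 5), torch.nn.ReLU())
#     saved = _forward_layer_distributed_eval(net, torch.randn(4, 3), net[0])
#     # saved[net[0]] maps a torch.device to a tuple of layer output tensors,
#     # e.g. saved[net[0]][torch.device("cpu")][0].shape == (4, 5)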


def _gather_distributed_tensors(
    saved_layer: Dict[device, Tuple[Tensor, ...]],
    device_ids: Union[None, List[int]] = None,
    key_list: Union[None, List[device]] = None,
) -> Tuple[Tensor, ...]:
    r"""
    A helper function to concatenate intermediate layer results stored on
    different devices in `saved_layer`. `saved_layer` is a dictionary that
    contains `device_id` as a key and intermediate layer results (either
    the input or the output of the layer) stored on the device corresponding to
    the key.
    `key_list` is a list of devices in appropriate ordering for concatenation
    and if not provided, keys are sorted based on device ids.

    If only one key exists (standard model), key list simply has one element.
    """
    if key_list is None:
        key_list = _sort_key_list(list(saved_layer.keys()), device_ids)
    return _reduce_list([saved_layer[device_id] for device_id in key_list])


def _extract_device_ids(
    forward_fn: Callable,
    saved_layer: Dict[Module, Dict[device, Tuple[Tensor, ...]]],
    device_ids: Union[None, List[int]],
) -> Union[None, List[int]]:
    r"""
    A helper function to extract device_ids from `forward_function` in case it is
    provided as part of a `DataParallel` model or if is accessible from
    `forward_fn`.
    In case input device_ids is not None, this function returns that value.
    """
    # Multiple devices / keys implies a DataParallel model, so we look for
    # device IDs if given or available from forward function
    # (DataParallel model object).
    if (
        max(len(saved_layer[single_layer]) for single_layer in saved_layer) > 1
        and device_ids is None
    ):
        if (
            hasattr(forward_fn, "device_ids")
            and cast(Any, forward_fn).device_ids is not None
        ):
            device_ids = cast(Any, forward_fn).device_ids
        else:
            raise AssertionError(
                "Layer tensors are saved on multiple devices, however unable to access"
                " device ID list from the `forward_fn`. Device ID list must be"
                " accessible from `forward_fn`. For example, they can be retrieved"
                " if `forward_fn` is a model of type `DataParallel`. It is used"
                " for identifying device batch ordering."
            )
    return device_ids


@typing.overload
def _forward_layer_eval_with_neuron_grads(
    forward_fn: Callable,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    layer: Module,
    additional_forward_args: Any = None,
    *,
    gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
    grad_enabled: bool = False,
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
    ...


@typing.overload
def _forward_layer_eval_with_neuron_grads(
    forward_fn: Callable,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    layer: Module,
    additional_forward_args: Any = None,
    gradient_neuron_selector: None = None,
    grad_enabled: bool = False,
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
) -> Tuple[Tensor, ...]:
    ...


@typing.overload
def _forward_layer_eval_with_neuron_grads(
    forward_fn: Callable,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    layer: List[Module],
    additional_forward_args: Any = None,
    gradient_neuron_selector: None = None,
    grad_enabled: bool = False,
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
) -> List[Tuple[Tensor, ...]]:
    ...


def _forward_layer_eval_with_neuron_grads(
    forward_fn: Callable,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    layer: ModuleOrModuleList,
    additional_forward_args: Any = None,
    gradient_neuron_selector: Union[
        None, int, Tuple[Union[int, slice], ...], Callable
    ] = None,
    grad_enabled: bool = False,
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
) -> Union[
    Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]],
    Tuple[Tensor, ...],
    List[Tuple[Tensor, ...]],
]:
    """
    This method computes forward evaluation for a particular layer using a
    forward hook. If a gradient_neuron_selector is provided, then gradients with
    respect to that neuron in the layer output are also returned.

    These functionalities are combined due to the behavior of DataParallel models
    with hooks, in which hooks are executed once per device. We need to internally
    combine the separated tensors from devices by concatenating based on device_ids.
    Any necessary gradients must be taken with respect to each independent batched
    tensor, so the gradients are computed and combined appropriately.

    More information regarding the behavior of forward hooks with DataParallel models
    can be found in the PyTorch data parallel documentation. We maintain the separate
    evals in a dictionary protected by a lock, analogous to the gather implementation
    for the core PyTorch DataParallel implementation.
    """
    grad_enabled = True if gradient_neuron_selector is not None else grad_enabled

    with torch.autograd.set_grad_enabled(grad_enabled):
        saved_layer = _forward_layer_distributed_eval(
            forward_fn,
            inputs,
            layer,
            additional_forward_args=additional_forward_args,
            attribute_to_layer_input=attribute_to_layer_input,
        )
    device_ids = _extract_device_ids(forward_fn, saved_layer, device_ids)
    # Identifies correct device ordering based on device ids.
    # key_list is a list of devices in appropriate ordering for concatenation.
    # If only one key exists (standard model), key list simply has one element.
    key_list = _sort_key_list(list(next(iter(saved_layer.values())).keys()), device_ids)
    if gradient_neuron_selector is not None:
        assert isinstance(
            layer, Module
        ), "Cannot compute neuron gradients for multiple layers simultaneously!"
        inp_grads = _neuron_gradients(
            inputs, saved_layer[layer], key_list, gradient_neuron_selector
        )
        return (
            _gather_distributed_tensors(saved_layer[layer], key_list=key_list),
            inp_grads,
        )
    else:
        if isinstance(layer, Module):
            return _gather_distributed_tensors(saved_layer[layer], key_list=key_list)
        else:
            return [
                _gather_distributed_tensors(saved_layer[curr_layer], key_list=key_list)
                for curr_layer in layer
            ]
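
# Hedged example: evaluating a layer and taking gradients of one of its
# neurons with respect to the inputs (model and names are illustrative only):
#
#     net = torch.nn.Sequential(torch.nn.Linear(3, 5), torch.nn.ReLU())
#     x = torch.randn(4, 3, requires_grad=True)
#     (layer_out,), (neuron_grads,) = _forward_layer_eval_with_neuron_grads(
#         net, x, net[0], gradient_neuron_selector=(2,)
#     )
#     # neuron_grads has the same shape as x: d(layer_out[:, 2]) / dx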


@typing.overload
def compute_layer_gradients_and_eval(
    forward_fn: Callable,
    layer: Module,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    target_ind: TargetType = None,
    additional_forward_args: Any = None,
    *,
    gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
    output_fn: Union[None, Callable] = None,
) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...], Tuple[Tensor, ...]]:
    ...


@typing.overload
def compute_layer_gradients_and_eval(
    forward_fn: Callable,
    layer: List[Module],
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    target_ind: TargetType = None,
    additional_forward_args: Any = None,
    gradient_neuron_selector: None = None,
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
    output_fn: Union[None, Callable] = None,
) -> Tuple[List[Tuple[Tensor, ...]], List[Tuple[Tensor, ...]]]:
    ...


@typing.overload
def compute_layer_gradients_and_eval(
    forward_fn: Callable,
    layer: Module,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    target_ind: TargetType = None,
    additional_forward_args: Any = None,
    gradient_neuron_selector: None = None,
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
    output_fn: Union[None, Callable] = None,
) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
    ...


def compute_layer_gradients_and_eval(
    forward_fn: Callable,
    layer: ModuleOrModuleList,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    target_ind: TargetType = None,
    additional_forward_args: Any = None,
    gradient_neuron_selector: Union[
        None, int, Tuple[Union[int, slice], ...], Callable
    ] = None,
    device_ids: Union[None, List[int]] = None,
    attribute_to_layer_input: bool = False,
    output_fn: Union[None, Callable] = None,
) -> Union[
    Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]],
    Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...], Tuple[Tensor, ...]],
    Tuple[List[Tuple[Tensor, ...]], List[Tuple[Tensor, ...]]],
]:
    r"""
    Computes gradients of the output with respect to a given layer as well
    as the output evaluation of the layer for an arbitrary forward function
    and given input.

    For data parallel models, hooks are executed once per device, so we
    need to internally combine the separated tensors from devices by
    concatenating based on device_ids. Any necessary gradients must be taken
    with respect to each independent batched tensor, so the gradients are
    computed and combined appropriately.

    More information regarding the behavior of forward hooks with DataParallel
    models can be found in the PyTorch data parallel documentation. We maintain
    the separate inputs in a dictionary protected by a lock, analogous to the
    gather implementation for the core PyTorch DataParallel implementation.

    NOTE: To properly handle inplace operations, a clone of the layer output
    is stored. This structure inhibits execution of a backward hook on the last
    module for the layer output when computing the gradient with respect to
    the input, since we store an intermediate clone, as
    opposed to the true module output. If backward module hooks are necessary
    for the final module when computing input gradients, utilize
    _forward_layer_eval_with_neuron_grads instead.

    Args:

        forward_fn: forward function. This can be for example model's
                    forward function.
        layer:      Layer for which gradients / output will be evaluated.
        inputs:     Input at which gradients are evaluated,
                    will be passed to forward_fn.
        target_ind: Index of the target class for which gradients
                    must be computed (classification only).
        output_fn:  An optional function that is applied to the layer inputs or
                    outputs depending on whether `attribute_to_layer_input` is
                    set to `True` or `False`.
        additional_forward_args: Additional input arguments that the forward
                    function requires. It takes an empty tuple (no additional
                    arguments) if no additional arguments are required.


    Returns:
        2-element tuple of **gradients**, **evals** (a 3-element tuple with
        **neuron gradients** appended when `gradient_neuron_selector` is
        provided):
        - **gradients**:
            Gradients of output with respect to target layer output.
        - **evals**:
            Target layer output for given input.
    """
    with torch.autograd.set_grad_enabled(True):
        # saved_layer is a dictionary mapping device to a tuple of
        # layer evaluations on that device.
        saved_layer, output = _forward_layer_distributed_eval(
            forward_fn,
            inputs,
            layer,
            target_ind=target_ind,
            additional_forward_args=additional_forward_args,
            attribute_to_layer_input=attribute_to_layer_input,
            forward_hook_with_return=True,
            require_layer_grads=True,
        )
        assert output[0].numel() == 1, (
            "Target not provided when necessary, cannot"
            " take gradient with respect to multiple outputs."
        )

        device_ids = _extract_device_ids(forward_fn, saved_layer, device_ids)

        # Identifies correct device ordering based on device ids.
        # key_list is a list of devices in appropriate ordering for concatenation.
        # If only one key exists (standard model), key list simply has one element.
        key_list = _sort_key_list(
            list(next(iter(saved_layer.values())).keys()), device_ids
        )
        all_outputs: Union[Tuple[Tensor, ...], List[Tuple[Tensor, ...]]]
        if isinstance(layer, Module):
            all_outputs = _reduce_list(
                [
                    saved_layer[layer][device_id]
                    if output_fn is None
                    else output_fn(saved_layer[layer][device_id])
                    for device_id in key_list
                ]
            )
        else:
            all_outputs = [
                _reduce_list(
                    [
                        saved_layer[single_layer][device_id]
                        if output_fn is None
                        else output_fn(saved_layer[single_layer][device_id])
                        for device_id in key_list
                    ]
                )
                for single_layer in layer
            ]
        all_layers: List[Module] = [layer] if isinstance(layer, Module) else layer
        grad_inputs = tuple(
            layer_tensor
            for single_layer in all_layers
            for device_id in key_list
            for layer_tensor in saved_layer[single_layer][device_id]
        )
        saved_grads = torch.autograd.grad(torch.unbind(output), grad_inputs)

        offset = 0
        all_grads: List[Tuple[Tensor, ...]] = []
        for single_layer in all_layers:
            num_tensors = len(next(iter(saved_layer[single_layer].values())))
            curr_saved_grads = [
                saved_grads[i : i + num_tensors]
                for i in range(
                    offset, offset + len(key_list) * num_tensors, num_tensors
                )
            ]
            offset += len(key_list) * num_tensors
            if output_fn is not None:
                curr_saved_grads = [
                    output_fn(curr_saved_grad) for curr_saved_grad in curr_saved_grads
                ]

            all_grads.append(_reduce_list(curr_saved_grads))

        layer_grads: Union[Tuple[Tensor, ...], List[Tuple[Tensor, ...]]]
        layer_grads = all_grads
        if isinstance(layer, Module):
            layer_grads = all_grads[0]

        if gradient_neuron_selector is not None:
            assert isinstance(
                layer, Module
            ), "Cannot compute neuron gradients for multiple layers simultaneously!"
            inp_grads = _neuron_gradients(
                inputs, saved_layer[layer], key_list, gradient_neuron_selector
            )
            return (
                cast(Tuple[Tensor, ...], layer_grads),
                cast(Tuple[Tensor, ...], all_outputs),
                inp_grads,
            )
    return layer_grads, all_outputs  # type: ignore
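
# A minimal sketch of the common single-layer case (illustrative names only):
#
#     net = torch.nn.Sequential(torch.nn.Linear(3, 5), torch.nn.Linear(5, 1))
#     x = torch.randn(4, 3)
#     (layer_grads,), (layer_eval,) = compute_layer_gradients_and_eval(
#         net, net[0], x
#     )
#     # layer_eval is net[0]'s output; layer_grads holds the gradients of the
#     # model output with respect to that layer output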


def construct_neuron_grad_fn(
    layer: Module,
    neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
    device_ids: Union[None, List[int]] = None,
    attribute_to_neuron_input: bool = False,
) -> Callable:
    def grad_fn(
        forward_fn: Callable,
        inputs: TensorOrTupleOfTensorsGeneric,
        target_ind: TargetType = None,
        additional_forward_args: Any = None,
    ) -> Tuple[Tensor, ...]:
        _, grads = _forward_layer_eval_with_neuron_grads(
            forward_fn,
            inputs,
            layer,
            additional_forward_args,
            gradient_neuron_selector=neuron_selector,
            device_ids=device_ids,
            attribute_to_layer_input=attribute_to_neuron_input,
        )
        return grads

    return grad_fn
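
# Sketch: building a gradient function for a single neuron and applying it
# (the model and selector below are illustrative):
#
#     net = torch.nn.Sequential(torch.nn.Linear(3, 5), torch.nn.ReLU())
#     grad_fn = construct_neuron_grad_fn(net[0], neuron_selector=(0,))
#     x = torch.randn(4, 3, requires_grad=True)
#     (grads,) = grad_fn(net, x)  # gradients of neuron 0's activation w.r.t. x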


def _compute_jacobian_wrt_params(
    model: Module,
    inputs: Tuple[Any, ...],
    labels: Optional[Tensor] = None,
    loss_fn: Optional[Union[Module, Callable]] = None,
) -> Tuple[Tensor, ...]:
    r"""
    Computes the Jacobian of a batch of test examples given a model, and optional
    loss function and target labels. This method is equivalent to calculating the
    gradient for every individual example in the minibatch.

    Args:
        model (torch.nn.Module): The trainable model providing the forward pass
        inputs (tuple of Any): The minibatch for which the forward pass is computed.
                It is unpacked before passing to `model`, so it must be a tuple.  The
                individual elements of `inputs` can be anything.
        labels (Tensor or None): Labels for input if computing a loss function.
        loss_fn (torch.nn.Module or Callable or None): The loss function. If a library
                defined loss function is provided, it would be expected to be a
                torch.nn.Module. If a custom loss is provided, it can be either type,
                but must behave as a library loss function would if `reduction='none'`.

    Returns:
        grads (Tuple of Tensor): Returns the Jacobian for the minibatch as a
                tuple of gradients corresponding to the tuple of trainable parameters
                returned by `model.parameters()`. Each object grads[i] refers to the
                gradients for the parameters in the i-th trainable layer of the model.
                Each grads[i] object is a tensor with the gradients for the `inputs`
                batch. For example, grads[i][j] would reference the gradients for the
                parameters of the i-th layer, for the j-th member of the minibatch.
    """
    with torch.autograd.set_grad_enabled(True):
        out = model(*inputs)
        assert out.dim() != 0, "Please ensure model output has at least one dimension."

        if labels is not None and loss_fn is not None:
            loss = loss_fn(out, labels)
            if hasattr(loss_fn, "reduction"):
                msg0 = "Please ensure loss_fn.reduction is set to `none`"
                assert loss_fn.reduction == "none", msg0  # type: ignore
            else:
                msg1 = (
                    "Loss function appears to be applying a reduction. "
                    f"Please ensure output shape ({out.shape}) and loss shape "
                    f"({loss.shape}) match."
                )
                assert loss.dim() != 0, msg1
                assert out.shape[0] == loss.shape[0], msg1
            out = loss

        grads_list = [
            torch.autograd.grad(
                outputs=out[i],
                inputs=model.parameters(),  # type: ignore
                grad_outputs=torch.ones_like(out[i]),
                retain_graph=True,
            )
            for i in range(out.shape[0])
        ]

        grads = tuple(torch.stack(x) for x in zip(*grads_list))

        return grads
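
# A hedged usage sketch with a tiny model and a per-sample loss; all names
# below are illustrative and not defined in this module:
#
#     model = torch.nn.Linear(3, 1)
#     inp = torch.randn(4, 3)
#     labels = torch.randn(4, 1)
#     loss_fn = torch.nn.MSELoss(reduction="none")
#     grads = _compute_jacobian_wrt_params(model, (inp,), labels, loss_fn)
#     # grads[0].shape == (4, 1, 3): per-sample gradients of the weight matrix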


def _compute_jacobian_wrt_params_with_sample_wise_trick(
    model: Module,
    inputs: Tuple[Any, ...],
    labels: Optional[Tensor] = None,
    loss_fn: Optional[Union[Module, Callable]] = None,
    reduction_type: Optional[str] = "sum",
) -> Tuple[Any, ...]:
    r"""
    Computes the Jacobian of a batch of test examples given a model, and optional
    loss function and target labels. This method uses sample-wise gradients per
    batch trick to fully vectorize the Jacobian calculation. Currently, only
    linear and conv2d layers are supported.

    User must `add_hooks(model)` before calling this function.

    Args:
        model (torch.nn.Module): The trainable model providing the forward pass
        inputs (tuple of Any): The minibatch for which the forward pass is computed.
                It is unpacked before passing to `model`, so it must be a tuple.  The
                individual elements of `inputs` can be anything.
        labels (Tensor or None): Labels for input if computing a loss function.
        loss_fn (torch.nn.Module or Callable or None): The loss function. If a library
                defined loss function is provided, it would be expected to be a
                torch.nn.Module. If a custom loss is provided, it can be either type,
                but must behave as a library loss function would if `reduction='sum'` or
                `reduction='mean'`.
        reduction_type (str): The type of reduction applied. If a loss_fn is passed,
                this should match `loss_fn.reduction`. Else if gradients are being
                computed on direct model outputs (scores), then 'sum' should be used.
                Defaults to 'sum'.

    Returns:
        grads (Tuple of Tensor): Returns the Jacobian for the minibatch as a
                tuple of gradients corresponding to the tuple of trainable parameters
                returned by `model.parameters()`. Each object grads[i] refers to the
                gradients for the parameters in the i-th trainable layer of the model.
                Each grads[i] object is a tensor with the gradients for the `inputs`
                batch. For example, grads[i][j] would reference the gradients for the
                parameters of the i-th layer, for the j-th member of the minibatch.
    """
    with torch.autograd.set_grad_enabled(True):
        sample_grad_wrapper = SampleGradientWrapper(model)
        try:
            sample_grad_wrapper.add_hooks()

            out = model(*inputs)
            assert (
                out.dim() != 0
            ), "Please ensure model output has at least one dimension."

            if labels is not None and loss_fn is not None:
                loss = loss_fn(out, labels)
                # TODO: allow loss_fn to be Callable
                if isinstance(loss_fn, Module) and hasattr(loss_fn, "reduction"):
                    msg0 = (
                        "Please ensure that loss_fn.reduction is set to `sum` or `mean`"
                    )

                    assert loss_fn.reduction != "none", msg0
                    msg1 = (
                        f"loss_fn.reduction ({loss_fn.reduction}) does not match"
                        f"reduction type ({reduction_type}). Please ensure they are"
                        " matching."
                    )
                    assert loss_fn.reduction == reduction_type, msg1
                msg2 = (
                    "Please ensure custom loss function is applying either a "
                    "sum or mean reduction."
                )
                assert out.shape != loss.shape, msg2

                if reduction_type != "sum" and reduction_type != "mean":
                    raise ValueError(
                        f"{reduction_type} is not a valid value for reduction_type. "
                        "Must be either 'sum' or 'mean'."
                    )
                out = loss

            sample_grad_wrapper.compute_param_sample_gradients(
                out, loss_mode=reduction_type
            )

            grads = tuple(
                param.sample_grad  # type: ignore
                for param in model.parameters()
                if hasattr(param, "sample_grad")
            )
        finally:
            sample_grad_wrapper.remove_hooks()

        return grads
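

# A hedged sketch of the sample-wise-trick variant; only linear and conv2d
# layers are supported, and the loss reduction must match `reduction_type`
# (all names below are illustrative):
#
#     model = torch.nn.Linear(3, 1)
#     inp = torch.randn(4, 3)
#     labels = torch.randn(4, 1)
#     loss_fn = torch.nn.MSELoss(reduction="sum")
#     grads = _compute_jacobian_wrt_params_with_sample_wise_trick(
#         model, (inp,), labels, loss_fn, reduction_type="sum"
#     )
#     # grads[0] again holds per-sample weight gradients, shape (4, 1, 3)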