#!/usr/bin/env python3
from typing import Any, Callable

import torch
import torch.nn.functional as F
from captum._utils.common import _format_output, _format_tensor_into_tuples, _is_tuple
from captum._utils.typing import TensorOrTupleOfTensorsGeneric
from captum.robust._core.fgsm import FGSM
from captum.robust._core.perturbation import Perturbation
from torch import Tensor


class PGD(Perturbation):
    r"""
    Projected Gradient Descent is an iterative version of the one-step attack
    FGSM that can generate adversarial examples. It takes multiple gradient
    steps to search for an adversarial perturbation within the desired
    neighborhood ball around the original inputs. In a non-targeted attack, the
    formulation is::

        x_0 = x
        x_(t+1) = Clip_r(x_t + alpha * sign(gradient of L(theta, x_t, y)))

    where Clip_r denotes the function that projects its argument onto the
    r-neighborhood ball around x so that the perturbation stays bounded, alpha
    is the step size, and L(theta, x, y) is the model's loss function with
    respect to model parameters, inputs and labels. Here t indexes the
    iteration and y is the true label.
    In a targeted attack, the step direction is reversed so that the loss on
    the target class y_target decreases::

        x_0 = x
        x_(t+1) = Clip_r(x_t - alpha * sign(gradient of L(theta, x_t, y_target)))

    More details on Projected Gradient Descent can be found in the original
    paper:
    https://arxiv.org/pdf/1706.06083.pdf
    """

    def __init__(
        self,
        forward_func: Callable,
        loss_func: Callable = None,
        lower_bound: float = float("-inf"),
        upper_bound: float = float("inf"),
    ) -> None:
        r"""
        Args:
            forward_func (callable): The pytorch model for which the attack is
                        computed.
            loss_func (callable, optional): Loss function whose gradient is
                        computed. The loss function should take in outputs of the
                        model and labels, and return the loss for each input tensor.
                        The default loss function is negative log.
            lower_bound (float, optional): Lower bound of input values.
            upper_bound (float, optional): Upper bound of input values.
                        e.g. image pixels must be in the range 0-255

        Attributes:
            bound (Callable): A function that bounds the input values based on
                        given lower_bound and upper_bound. Can be overwritten for
                        custom use cases if necessary.
        """
        super().__init__()
        self.forward_func = forward_func
        self.fgsm = FGSM(forward_func, loss_func)
        self.bound = lambda x: torch.clamp(x, min=lower_bound, max=upper_bound)

    def perturb(
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        radius: float,
        step_size: float,
        step_num: int,
        target: Any,
        additional_forward_args: Any = None,
        targeted: bool = False,
        random_start: bool = False,
        norm: str = "Linf",
    ) -> TensorOrTupleOfTensorsGeneric:
        r"""
        This method computes and returns the perturbed input for each input tensor.
        It supports both targeted and non-targeted attacks.

        Args:

            inputs (tensor or tuple of tensors): Input for which adversarial
                        attack is computed. It can be provided as a single
                        tensor or a tuple of multiple tensors. If multiple
                        input tensors are provided, the batch sizes must be
                        aligned across all tensors.
            radius (float): Radius of the neighborhood ball centered around
                        inputs. The perturbation is kept within this range.
            step_size (float): Step size of each gradient step.
            step_num (int): Number of gradient steps. It is usually chosen so
                        that the perturbation can reach the border of the ball.
            target (any): True labels of inputs if non-targeted attack is
                        desired. Target class of inputs if targeted attack
                        is desired. Target will be passed to the loss function
                        to compute loss, so the type needs to match the
                        argument type of the loss function.

                        If using the default negative log as loss function,
                        labels should be of type int, tuple, tensor or list.
                        For general 2D outputs, labels can be either:

                        - a single integer or a tensor containing a single
                          integer, which is applied to all input examples

                        - a list of integers or a 1D tensor, with length matching
                          the number of examples in inputs (dim 0). Each integer
                          is applied as the label for the corresponding example.

                        For outputs with > 2 dimensions, labels can be either:

                        - A single tuple, which contains #output_dims - 1
                          elements. This label index is applied to all examples.

                        - A list of tuples with length equal to the number of
                          examples in inputs (dim 0), and each tuple containing
                          #output_dims - 1 elements. Each tuple is applied as the
                          label for the corresponding example.
            additional_forward_args (any, optional): If the forward function
                        requires additional arguments other than the inputs
                        to be perturbed, this argument can be provided. These
                        arguments are passed to forward_func in order,
                        following the arguments in inputs.
                        Default: None.
            targeted (bool, optional): If attack should be targeted.
                        Default: False.
            random_start (bool, optional): If a random initialization is added to
                        inputs. Default: False.
            norm (str, optional): Specifies the norm to calculate distance from
                        original inputs: 'Linf'|'L2'.
                        Default: 'Linf'.

        Returns:

            - **perturbed inputs** (*tensor* or tuple of *tensors*):
                        Perturbed input for each
                        input tensor. The perturbed inputs have the same shape and
                        dimensionality as the inputs.
                        If a single tensor is provided as inputs, a single tensor
                        is returned. If a tuple is provided for inputs, a tuple of
                        corresponding sized tensors is returned.
        """

        def _clip(inputs: Tensor, outputs: Tensor) -> Tensor:
            diff = outputs - inputs
            if norm == "Linf":
                return inputs + torch.clamp(diff, -radius, radius)
            elif norm == "L2":
                return inputs + torch.renorm(diff, 2, 0, radius)
            else:
                raise AssertionError("Norm constraint must be L2 or Linf.")

        is_inputs_tuple = _is_tuple(inputs)
        formatted_inputs = _format_tensor_into_tuples(inputs)
        perturbed_inputs = formatted_inputs
        if random_start:
            perturbed_inputs = tuple(
                self.bound(self._random_point(formatted_inputs[i], radius, norm))
                for i in range(len(formatted_inputs))
            )
        for _i in range(step_num):
            perturbed_inputs = self.fgsm.perturb(
                perturbed_inputs, step_size, target, additional_forward_args, targeted
            )
            perturbed_inputs = tuple(
                _clip(formatted_inputs[j], perturbed_inputs[j])
                for j in range(len(perturbed_inputs))
            )
            # Detaching inputs to avoid dependency of gradient between steps
            perturbed_inputs = tuple(
                self.bound(perturbed_inputs[j]).detach()
                for j in range(len(perturbed_inputs))
            )
        return _format_output(is_inputs_tuple, perturbed_inputs)

    def _random_point(self, center: Tensor, radius: float, norm: str) -> Tensor:
        r"""
        A helper function that returns a uniform random point within the ball
        with the given center and radius. Norm should be either L2 or Linf.
        """
        if norm == "L2":
            u = torch.randn_like(center)
            unit_u = F.normalize(u.view(u.size(0), -1)).view(u.size())
            d = torch.numel(center[0])
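            r = (torch.rand(u.size(0), device=center.device) ** (1.0 / d)) * radius
            # Add trailing singleton dims so r broadcasts against unit_u per example.
            r = r[(...,) + (None,) * (center.dim() - 1)]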
            x = r * unit_u
            return center + x
        elif norm == "Linf":
            x = torch.rand_like(center) * radius * 2 - radius
            return center + x
        else:
            raise AssertionError("Norm constraint must be L2 or Linf.")
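

# Illustrative sketch, not part of the captum API.
# The iteration described in the PGD docstring (gradient-sign step followed by
# projection onto the radius ball) can be sketched for the Linf, non-targeted
# case using plain torch only. The helper name `_pgd_linf_sketch`, the use of
# cross entropy as the loss, the toy linear model and every value below are
# assumptions made for illustration; PGD.perturb above delegates the step to
# FGSM and also supports L2, targeted attacks and input bounds.
# Imports are repeated here so the sketch can be run on its own.
import torch
import torch.nn.functional as F


def _pgd_linf_sketch(model, x, y, radius, step_size, step_num):
    x_adv = x.clone()
    for _ in range(step_num):
        # Detach so gradients do not flow between iterations.
        x_adv = x_adv.detach().requires_grad_(True)
        loss = F.cross_entropy(model(x_adv), y)
        (grad,) = torch.autograd.grad(loss, x_adv)
        # Non-targeted step: follow the gradient sign to increase the loss.
        x_adv = x_adv + step_size * grad.sign()
        # Project back onto the Linf ball of the given radius around x.
        x_adv = x + torch.clamp(x_adv - x, -radius, radius)
    return x_adv.detach()


if __name__ == "__main__":
    torch.manual_seed(0)
    toy_model = torch.nn.Linear(4, 3)  # hypothetical stand-in classifier
    x0 = torch.rand(2, 4)
    y0 = torch.tensor([0, 2])
    adv = _pgd_linf_sketch(
        toy_model, x0, y0, radius=0.1, step_size=0.03, step_num=10
    )
    # The largest per-element change should not exceed the radius.
    print((adv - x0).abs().max())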