Spaces:
Build error
Build error
#!/usr/bin/env python3 | |
from typing import Any, Callable | |
import torch | |
import torch.nn.functional as F | |
from captum._utils.common import _format_output, _format_tensor_into_tuples, _is_tuple | |
from captum._utils.typing import TensorOrTupleOfTensorsGeneric | |
from captum.robust._core.fgsm import FGSM | |
from captum.robust._core.perturbation import Perturbation | |
from torch import Tensor | |
class PGD(Perturbation):
    r"""
    Projected Gradient Descent is an iterative version of the one-step attack
    FGSM that can generate adversarial examples. It takes multiple gradient
    steps to search for an adversarial perturbation within the desired
    neighbor ball around the original inputs. In a non-targeted attack, the
    formulation is::

        x_0 = x
        x_(t+1) = Clip_r(x_t + alpha * sign(gradient of L(theta, x_t, y)))

    where Clip denotes the function that projects its argument to the r-neighbor
    ball around x so that the perturbation will be bounded. Alpha is the step
    size. L(theta, x, y) is the model's loss function with respect to model
    parameters, inputs and targets.
    In a targeted attack, the formulation is similar::

        x_0 = x
        x_(t+1) = Clip_r(x_t - alpha * sign(gradient of L(theta, x_t, y)))

    More details on Projected Gradient Descent can be found in the original
    paper:
    https://arxiv.org/pdf/1706.06083.pdf
    """

    def __init__(
        self,
        forward_func: Callable,
        loss_func: Callable = None,
        lower_bound: float = float("-inf"),
        upper_bound: float = float("inf"),
    ) -> None:
        r"""
        Args:
            forward_func (callable): The pytorch model for which the attack is
                        computed.
            loss_func (callable, optional): Loss function of which the gradient
                        computed. The loss function should take in outputs of the
                        model and labels, and return the loss for each input tensor.
                        The default loss function is negative log.
            lower_bound (float, optional): Lower bound of input values.
            upper_bound (float, optional): Upper bound of input values.
                        e.g. image pixels must be in the range 0-255

        Attributes:
            bound (Callable): A function that bounds the input values based on
                        given lower_bound and upper_bound. Can be overwritten for
                        custom use cases if necessary.
        """
        super().__init__()
        self.forward_func = forward_func
        # Every PGD iteration is a single FGSM step on the current iterate.
        self.fgsm = FGSM(forward_func, loss_func)
        # Kept as a plain attribute so callers can swap in a custom bound.
        self.bound = lambda x: torch.clamp(x, min=lower_bound, max=upper_bound)

    def perturb(
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        radius: float,
        step_size: float,
        step_num: int,
        target: Any,
        additional_forward_args: Any = None,
        targeted: bool = False,
        random_start: bool = False,
        norm: str = "Linf",
    ) -> TensorOrTupleOfTensorsGeneric:
        r"""
        This method computes and returns the perturbed input for each input tensor.
        It supports both targeted and non-targeted attacks.

        Args:
            inputs (tensor or tuple of tensors): Input for which adversarial
                        attack is computed. It can be provided as a single
                        tensor or a tuple of multiple tensors. If multiple
                        input tensors are provided, the batch sizes must be
                        aligned across all tensors.
            radius (float): Radius of the neighbor ball centered around inputs.
                        The perturbation should be within this range.
            step_size (float): Step size of each gradient step.
            step_num (int): Step numbers. It usually guarantees that the perturbation
                        can reach the border.
            target (any): True labels of inputs if non-targeted attack is
                        desired. Target class of inputs if targeted attack
                        is desired. Target will be passed to the loss function
                        to compute loss, so the type needs to match the
                        argument type of the loss function.
                        If using the default negative log as loss function,
                        labels should be of type int, tuple, tensor or list.
                        For general 2D outputs, labels can be either:

                        - a single integer or a tensor containing a single
                          integer, which is applied to all input examples
                        - a list of integers or a 1D tensor, with length matching
                          the number of examples in inputs (dim 0). Each integer
                          is applied as the label for the corresponding example.

                        For outputs with > 2 dimensions, labels can be either:

                        - A single tuple, which contains #output_dims - 1
                          elements. This label index is applied to all examples.
                        - A list of tuples with length equal to the number of
                          examples in inputs (dim 0), and each tuple containing
                          #output_dims - 1 elements. Each tuple is applied as the
                          label for the corresponding example.
            additional_forward_args (any, optional): If the forward function
                        requires additional arguments other than the inputs
                        being perturbed, this argument can be provided. It is
                        forwarded unchanged to the underlying FGSM step.
                        Default: None.
            targeted (bool, optional): If attack should be targeted.
                        Default: False.
            random_start (bool, optional): If a random initialization is added to
                        inputs. Default: False.
            norm (str, optional): Specifies the norm to calculate distance from
                        original inputs: 'Linf'|'L2'.
                        Default: 'Linf'.

        Returns:
            - **perturbed inputs** (*tensor* or tuple of *tensors*):
                        Perturbed input for each
                        input tensor. The perturbed inputs have the same shape and
                        dimensionality as the inputs.
                        If a single tensor is provided as inputs, a single tensor
                        is returned. If a tuple is provided for inputs, a tuple of
                        corresponding sized tensors is returned.

        Raises:
            AssertionError: If ``norm`` is neither 'Linf' nor 'L2' and at least
                        one projection or random start is performed.
        """

        def _clip(inputs: Tensor, outputs: Tensor) -> Tensor:
            # Project `outputs` back into the radius-ball centered at `inputs`.
            diff = outputs - inputs
            if norm == "Linf":
                return inputs + torch.clamp(diff, -radius, radius)
            elif norm == "L2":
                # renorm over dim 0 rescales each example's difference whose
                # 2-norm exceeds `radius`.
                return inputs + torch.renorm(diff, 2, 0, radius)
            else:
                raise AssertionError("Norm constraint must be L2 or Linf.")

        is_inputs_tuple = _is_tuple(inputs)
        formatted_inputs = _format_tensor_into_tuples(inputs)
        perturbed_inputs = formatted_inputs
        if random_start:
            perturbed_inputs = tuple(
                self.bound(self._random_point(original, radius, norm))
                for original in formatted_inputs
            )
        for _ in range(step_num):
            perturbed_inputs = self.fgsm.perturb(
                perturbed_inputs, step_size, target, additional_forward_args, targeted
            )
            # Project onto the ball around the ORIGINAL inputs, apply the value
            # bounds, then detach so gradients do not chain across steps.
            perturbed_inputs = tuple(
                self.bound(_clip(original, perturbed)).detach()
                for original, perturbed in zip(formatted_inputs, perturbed_inputs)
            )
        return _format_output(is_inputs_tuple, perturbed_inputs)

    def _random_point(self, center: Tensor, radius: float, norm: str) -> Tensor:
        r"""
        A helper function that returns a uniform random point within the ball
        with the given center and radius. Norm should be either L2 or Linf.

        Args:
            center (tensor): Center of the ball; dim 0 is treated as the batch
                        dimension, so one point is drawn per example.
            radius (float): Radius of the ball.
            norm (str): 'L2' or 'Linf'.

        Returns:
            Tensor with the same shape as ``center``.

        Raises:
            AssertionError: If ``norm`` is neither 'L2' nor 'Linf'.
        """
        if norm == "L2":
            batch_size = center.size(0)
            direction = torch.randn_like(center)
            # reshape (rather than view) also handles non-contiguous layouts.
            unit_direction = F.normalize(
                direction.reshape(batch_size, -1)
            ).reshape(center.shape)
            d = torch.numel(center[0])
            # One radius per example, created on the same device/dtype as
            # `center` so the product below neither fails on device mismatch
            # nor promotes the dtype. U ** (1/d) makes the point uniform in
            # volume rather than concentrated near the center.
            r = (
                torch.rand(batch_size, device=center.device, dtype=center.dtype)
                ** (1.0 / d)
            ) * radius
            # Shape (batch, 1, ..., 1): broadcast each example's radius over
            # all of its feature dimensions. A 1-D `r` would instead be
            # aligned with the LAST dimension of `unit_direction`.
            r = r.reshape((batch_size,) + (1,) * (center.dim() - 1))
            return center + r * unit_direction
        elif norm == "Linf":
            # Uniform in [-radius, radius) independently per element.
            x = torch.rand_like(center) * radius * 2 - radius
            return center + x
        else:
            raise AssertionError("Norm constraint must be L2 or Linf.")