|
import torch |
|
from torchvision import transforms |
|
|
|
|
|
class Attack:
    """Generate an image from a prompt and perturb it with a one-step gradient-sign attack.

    The image is produced by ``pipe``, shrunk for ``classifer`` and then nudged
    by ``epsilon`` along the sign of the loss gradient with respect to the
    pixels (the FGSM construction).

    Attributes:
        device: Device used for the random generator, the classifier outputs
            and the label tensor.
        pipe: Callable image generator. It is invoked with ``prompt``,
            ``negative_prompt``, ``height``, ``width``, ``guidance_scale``,
            ``num_inference_steps`` and ``generator`` keyword arguments and
            must return an object exposing ``.images``.
            NOTE(review): this looks like a diffusers pipeline -- confirm.
        generator: ``torch.Generator`` seeded with 1024. It is shared across
            calls, so its state advances every time the pipeline draws from it.
        classifer: Model under attack (attribute name kept as spelled by the
            original interface). It must provide ``zero_grad()``.
    """

    def __init__(self, pipe, classifer, device="cpu"):
        """Store the collaborators and create the seeded random generator."""
        self.device = device
        self.pipe = pipe
        self.generator = torch.Generator(device=self.device).manual_seed(1024)
        self.classifer = classifer

    def __call__(
        self,
        prompt,
        negative_prompt="",
        size=512,
        guidance_scale=8,
        epsilon=0,
        label=0,
        num_inference_steps=30,
    ):
        """Generate an image for ``prompt`` and return it with its adversarial version.

        Args:
            prompt: Text prompt forwarded to the pipeline.
            negative_prompt: Negative prompt forwarded to the pipeline.
            size: Height and width requested from the pipeline.
            guidance_scale: Guidance scale forwarded to the pipeline.
            epsilon: Step size of the perturbation; ``0`` leaves the pixels
                unchanged apart from clamping to ``[0, 1]``.
            label: Class index the loss is computed against (previously
                hard-coded to ``0``).
            num_inference_steps: Number of pipeline steps (previously
                hard-coded to ``30``).

        Returns:
            A tuple ``(init_image, perturbed_image)``: the first image returned
            by the pipeline, and the perturbed tensor produced by
            ``untargeted_attack``.
        """
        pipe_output = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=size,
            width=size,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=self.generator,
        )
        init_image = pipe_output.images[0]

        image = self.transform(init_image)
        # Put the pixels on the classifier's device *before* enabling gradients.
        # Moving afterwards would create a non-leaf copy whose ``.grad`` stays
        # ``None``, and not moving at all breaks classifiers that live on GPU.
        classifier_device = self._classifier_device()
        if classifier_device is not None:
            image = image.to(classifier_device)
        image = image.detach().requires_grad_(True)

        outputs = self.classifer(image).to(self.device)
        target = torch.tensor([label]).to(self.device)

        return (
            init_image,
            self.untargeted_attack(image, outputs, target, epsilon),
        )

    def _classifier_device(self):
        """Return the device of the classifier's first parameter, or ``None`` if unknown."""
        parameters = getattr(self.classifer, "parameters", None)
        if not callable(parameters):
            return None
        first_parameter = next(iter(parameters()), None)
        return None if first_parameter is None else first_parameter.device

    def transform(self, image):
        """Resize ``image`` to 32 pixels on its shorter side and return a batched tensor.

        Args:
            image: Image accepted by the torchvision transforms (for example a
                PIL image).

        Returns:
            The image as a tensor with a leading batch dimension of size 1.
        """
        # Build the pipeline once; it has no per-call state.
        img_tfms = getattr(self, "_img_tfms", None)
        if img_tfms is None:
            img_tfms = transforms.Compose(
                [transforms.Resize(32), transforms.ToTensor()]
            )
            self._img_tfms = img_tfms
        return torch.unsqueeze(img_tfms(image), dim=0)

    def untargeted_attack(self, image, pred, target, epsilon):
        """Take one gradient-sign step that increases the loss of ``pred`` against ``target``.

        Args:
            image: Leaf tensor with ``requires_grad=True`` from which ``pred``
                was computed.
            pred: Classifier output for ``image``. It is passed to
                ``nll_loss``, which expects log-probabilities.
                NOTE(review): assumes the classifier ends in ``log_softmax``
                -- confirm.
            target: Tensor of class indices for the loss.
            epsilon: Step size of the perturbation.

        Returns:
            ``image + epsilon * sign(grad)`` clamped to ``[0, 1]``, detached
            from the autograd graph.
        """
        loss = torch.nn.functional.nll_loss(pred, target)

        # Clear stale parameter gradients before back-propagating.
        self.classifer.zero_grad()
        loss.backward()

        # The result is plain data: build it without recording autograd history.
        with torch.no_grad():
            gradient_sign = image.grad.sign()
            perturbed_image = torch.clamp(image + epsilon * gradient_sign, 0, 1)

        return perturbed_image
|
|