Pie31415's picture
update
cedb7e1
import torch
from torchvision import transforms
class Attack:
    """Generate an image from a prompt and perturb it against a classifier.

    Calling an instance runs ``pipe`` to produce an image, feeds a
    down-scaled copy of it to ``classifer`` and applies one sign-gradient
    step (FGSM-style) that increases the classifier's NLL loss.

    Args:
        pipe: Callable invoked with ``prompt``, ``negative_prompt``,
            ``height``, ``width``, ``guidance_scale``,
            ``num_inference_steps`` and ``generator`` keyword arguments; its
            result must expose ``.images`` (presumably a diffusers
            text-to-image pipeline -- confirm against the caller).
        classifer: Model called on a ``(1, C, H, W)`` float tensor. It must
            provide ``zero_grad()``. Its output is fed to ``nll_loss``, so it
            is assumed to return log-probabilities -- TODO confirm.
        device: Device used for the random generator, the classifier output
            and the target tensor.
    """

    def __init__(self, pipe, classifer, device="cpu"):
        self.device = device
        self.pipe = pipe
        # Seeded once here and then shared by every call. The generator is
        # stateful, so the seed pins the *sequence* of generations of this
        # instance rather than making each call return the same image.
        self.generator = torch.Generator(device=self.device).manual_seed(1024)
        self.classifer = classifer

    def __call__(
        self,
        prompt,
        negative_prompt="",
        size=512,
        guidance_scale=8,
        epsilon=0,
        *,
        target_label=0,
        num_inference_steps=30,
    ):
        """Generate an image and return it with its perturbed version.

        Args:
            prompt: Text describing what to generate.
            negative_prompt: Text describing what not to generate.
            size: Height and width passed to the pipeline.
            guidance_scale: Guidance scale passed to the pipeline.
            epsilon: Step size of the perturbation; ``0`` leaves the
                (down-scaled) image unchanged apart from clamping.
            target_label: Class index used as the NLL target; the attack
                step increases the loss for this class. Defaults to ``0``.
            num_inference_steps: Number of pipeline steps. Defaults to ``30``.

        Returns:
            A tuple ``(init_image, perturbed_image)``: the first image
            produced by the pipeline, and the perturbed tensor returned by
            :meth:`untargeted_attack`.
        """
        pipe_output = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=size,
            width=size,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=self.generator,
        )
        init_image = pipe_output.images[0]

        # NOTE(review): the image tensor is never moved to self.device; only
        # the classifier output is. This works when the classifier lives on
        # the CPU -- confirm that is the intended setup.
        image = self.transform(init_image)
        # The attack needs d(loss)/d(image), so track gradients on the input.
        image.requires_grad = True

        outputs = self.classifer(image).to(self.device)
        target = torch.tensor([target_label], device=self.device)
        return init_image, self.untargeted_attack(image, outputs, target, epsilon)

    def transform(self, image):
        """Convert an image to a ``(1, C, H, W)`` tensor for the classifier.

        The image is resized with ``Resize(32)``, converted with
        ``ToTensor()`` and given a leading batch dimension.
        """
        img_tfms = transforms.Compose(
            [transforms.Resize(32), transforms.ToTensor()]
        )
        return torch.unsqueeze(img_tfms(image), dim=0)

    def untargeted_attack(self, image, pred, target, epsilon):
        """Apply one sign-gradient step that increases the loss on ``target``.

        Args:
            image: Input tensor with ``requires_grad=True`` from which
                ``pred`` was computed.
            pred: Classifier output for ``image`` (passed to ``nll_loss``).
            target: Tensor of class indices for ``nll_loss``.
            epsilon: Step size multiplied with the gradient sign.

        Returns:
            ``image + epsilon * sign(d loss / d image)`` clamped to
            ``[0, 1]``, detached from the autograd graph.

        Raises:
            RuntimeError: If no gradient reached ``image`` during backward.
        """
        loss = torch.nn.functional.nll_loss(pred, target)
        self.classifer.zero_grad()
        loss.backward()

        if image.grad is None:
            raise RuntimeError(
                "image has no gradient; it must require grad and pred must "
                "be computed from it"
            )

        gradient_sign = image.grad.sign()
        # Build the result from the detached input so the returned tensor
        # does not keep the classifier's autograd graph alive.
        perturbed_image = image.detach() + epsilon * gradient_sign
        return torch.clamp(perturbed_image, 0, 1)