import torch
from torchvision import transforms


class Attack:
    """Generate an image with a text-to-image pipeline, then perturb it with one FGSM step.

    The pipeline output is resized and batched for ``classifer``. A single
    fast-gradient-sign step is then taken that increases the classifier's
    ``nll_loss`` with respect to a target label.
    """

    def __init__(self, pipe, classifer, device="cpu"):
        """Store the pipeline and the classifier.

        Args:
            pipe: Callable pipeline. It is called with keyword arguments
                ``prompt``, ``negative_prompt``, ``height``, ``width``,
                ``guidance_scale``, ``num_inference_steps`` and ``generator``.
                Its result's ``images[0]`` must be accepted by torchvision's
                ``Resize``/``ToTensor``; presumably a PIL image, so confirm this.
            classifer: Classifier applied to the transformed image. It must
                support ``zero_grad()``. Its output is fed to ``nll_loss``, so it
                is presumably expected to return log-probabilities (TODO confirm).
                The attribute keeps this spelling for compatibility.
            device: Device for the random generator, the classifier output and
                the target label.
        """
        self.device = device
        self.pipe = pipe
        # Seeded once here: the first call is reproducible, and later calls
        # continue the same random stream rather than restarting it.
        self.generator = torch.Generator(device=self.device).manual_seed(1024)
        self.classifer = classifer

    def __call__(
        self,
        prompt,
        negative_prompt="",
        size=512,
        guidance_scale=8,
        epsilon=0,
        *,
        num_inference_steps=30,
        target_label=0,
    ):
        """Generate an image from ``prompt`` and return it with an FGSM-perturbed copy.

        Args:
            prompt: What to generate.
            negative_prompt: What NOT to generate.
            size: Height and width requested from the pipeline.
            guidance_scale: How strongly the pipeline follows the prompt.
            epsilon: FGSM step size. With 0, the classifier input is returned
                only clamped to [0, 1].
            num_inference_steps: Number of pipeline steps (default 30).
            target_label: Class index whose ``nll_loss`` the step increases
                (default 0).

        Returns:
            Tuple ``(init_image, perturbed_image)``. The first element is the raw
            pipeline image. The second is the resized, batched and perturbed
            tensor.
        """
        pipe_output = self.pipe(
            prompt=prompt,  # What to generate
            negative_prompt=negative_prompt,  # What NOT to generate
            height=size,
            width=size,  # Specify the image size
            guidance_scale=guidance_scale,  # How strongly to follow the prompt
            num_inference_steps=num_inference_steps,  # How many steps to take
            generator=self.generator,  # Seeded random generator
        )
        # Resulting image:
        init_image = pipe_output.images[0]

        image = self.transform(init_image)
        # Feed the classifier on the device its weights live on. Previously the
        # input always stayed on the CPU, which broke GPU-resident classifiers.
        classifier_device = self._classifier_device()
        if classifier_device is not None:
            image = image.to(classifier_device)
        # Must be set after the device move so that `image` is the leaf tensor
        # whose .grad gets populated.
        image.requires_grad_(True)

        outputs = self.classifer(image).to(self.device)
        target = torch.tensor([target_label], device=self.device)
        return (
            init_image,
            self.untargeted_attack(image, outputs, target, epsilon),
        )

    def _classifier_device(self):
        """Return the device of the classifier's first parameter, or None if it has none."""
        parameters = getattr(self.classifer, "parameters", None)
        if not callable(parameters):
            return None
        first = next(iter(parameters()), None)
        return None if first is None else first.device

    def transform(self, image):
        """Convert a pipeline image into a batched tensor for the classifier.

        ``Resize(32)`` scales the shorter side to 32, so a square input becomes
        32x32. ``ToTensor`` produces a float tensor, and a leading batch
        dimension of size 1 is then added.
        """
        img_tfms = transforms.Compose(
            [transforms.Resize(32), transforms.ToTensor()]
        )
        return torch.unsqueeze(img_tfms(image), dim=0)

    def untargeted_attack(self, image, pred, target, epsilon):
        """Take one fast-gradient-sign step that increases the loss on ``target``.

        Args:
            image: Leaf tensor with ``requires_grad=True`` from which ``pred``
                was computed.
            pred: Classifier output for ``image``. It is passed to ``nll_loss``,
                which expects log-probabilities.
            target: Class-index tensor for ``nll_loss``.
            epsilon: Step size of the perturbation.

        Returns:
            ``image + epsilon * sign(grad)`` clamped to [0, 1] and detached from
            the autograd graph.

        Raises:
            ValueError: If no gradient reached ``image``.
        """
        loss = torch.nn.functional.nll_loss(pred, target)
        self.classifer.zero_grad()
        # Reset any stale gradient so repeated attacks on the same tensor do not
        # take the sign of an accumulated sum.
        image.grad = None
        loss.backward()
        if image.grad is None:
            raise ValueError(
                "image.grad is None: image must be a leaf tensor with "
                "requires_grad=True from which pred was computed"
            )
        gradient_sign = image.grad.sign()
        # Detach so callers receive a plain tensor (e.g. usable with .numpy())
        # instead of one still tied to the graph built for the attack.
        perturbed_image = image.detach() + epsilon * gradient_sign
        return torch.clamp(perturbed_image, 0, 1)