File size: 4,603 Bytes
65abdbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import argparse
import os
from typing import Mapping

import torch.nn.functional as F
from PIL import Image
from torch.nn import CrossEntropyLoss
from torchvision import transforms
from torchvision.utils import save_image
from transformers import BeitFeatureExtractor, BeitForImageClassification

from attacker import *

# Run on the GPU whenever CUDA is present; otherwise fall back to the CPU.
use_gpu = torch.cuda.is_available()
device = torch.device('cuda') if use_gpu else torch.device('cpu')


def make_args(args_=None):
    """Build the command-line parser and parse the attack options.

    Args:
        args_: optional list of argument strings; ``None`` parses ``sys.argv``.

    Returns:
        argparse.Namespace with ``inputs``, ``out_dir``, ``target``, ``eps``,
        ``step_size``, ``steps`` and ``test_atk``.

    Note:
        ``--target`` is restricted to 'auto', 'ai' or 'human'; any other value
        makes argparse print an error and exit instead of failing later.
    """
    parser = argparse.ArgumentParser(
        description='Adversarial (PGD) noise attack against the anime AI-art detector')

    parser.add_argument('inputs', type=str)
    parser.add_argument('--out_dir', type=str, default='./output')
    parser.add_argument('--target', type=str, default='auto', choices=['auto', 'ai', 'human'],
                        help='[auto, ai, human]')
    parser.add_argument('--eps', type=float, default=8 / 8, help='Noise intensity ')
    parser.add_argument('--step_size', type=float, default=1.087313 / 8, help='Attack step size')
    parser.add_argument('--steps', type=int, default=20, help='Attack step count')

    # When set, the classifier output is evaluated before and after the attack.
    parser.add_argument('--test_atk', action='store_true')

    return parser.parse_args(args_)


# Lower-case image file extensions (with leading dot); a tuple so it can be
# passed directly to ``str.endswith``.
IMAGE_EXTS = (
    '.bmp', '.dib',
    '.png',
    '.jpg', '.jpeg',
    '.pbm', '.pgm', '.ppm',
    '.tif', '.tiff',
)


class Attacker:
    """PGD attacker against the ``saltacc/anime-ai-detect`` BEiT classifier.

    Perturbs images with adversarial noise and saves the results as PNG files
    in ``args.out_dir``. The attack label is chosen by ``args.target``
    ('auto', 'ai' or 'human').
    """

    def __init__(self, args):
        """Load the model and configure the PGD attack.

        Args:
            args: namespace produced by ``make_args`` (this method reads
                ``out_dir``, ``target``, ``eps``, ``step_size`` and ``steps``;
                ``attack_`` additionally reads ``test_atk``).

        Raises:
            ValueError: if ``args.target`` is not 'auto', 'ai' or 'human'.
        """
        self.args = args
        # Resolve the target before the (slow) model load so a bad value fails fast.
        label = self._target_label(args.target)
        # None means 'auto': attack_() picks the label per image.
        self.target = None if label is None else torch.tensor([label]).to(device)
        os.makedirs(args.out_dir, exist_ok=True)

        print('正在加载模型...')  # "Loading model..."
        self.feature_extractor = BeitFeatureExtractor.from_pretrained('saltacc/anime-ai-detect')
        self.model = BeitForImageClassification.from_pretrained('saltacc/anime-ai-detect')
        if use_gpu:
            self.model = self.model.cuda()
        print('加载完毕')  # "Loading complete"

        # (x - 0.5) / 0.5 maps [0, 1] pixels to [-1, 1]. Presumably this matches the
        # feature extractor's normalisation — confirm against its config.
        # PGD receives the pair (pixel -> model space, model -> pixel space).
        dataset_mean_t = torch.tensor([0.5, 0.5, 0.5]).view(1, -1, 1, 1)
        dataset_std_t = torch.tensor([0.5, 0.5, 0.5]).view(1, -1, 1, 1)
        if use_gpu:
            dataset_mean_t = dataset_mean_t.cuda()
            dataset_std_t = dataset_std_t.cuda()
        self.pgd = PGD(self.model, img_transform=(
            lambda x: (x - dataset_mean_t) / dataset_std_t, lambda x: x * dataset_std_t + dataset_mean_t))
        # eps / step_size are presumably given in 0-255 pixel units. The /255
        # converts them to [0, 1], and the *2 converts to the normalised range
        # (std = 0.5). This assumes PGD measures eps/alpha in model-input space —
        # confirm in attacker.PGD.
        self.pgd.set_para(eps=(args.eps * 2) / 255, alpha=lambda: (args.step_size * 2) / 255, iters=args.steps)
        self.pgd.set_loss(CrossEntropyLoss())

    @staticmethod
    def _target_label(target):
        """Map a ``--target`` value to a class index, or ``None`` for 'auto'.

        Raises:
            ValueError: for any value other than 'auto', 'ai' or 'human'.
        """
        if target == 'auto':
            return None
        labels = {'ai': 1, 'human': 0}  # 'ai': attack so the image is classified as AI
        try:
            return labels[target]
        except KeyError:
            raise ValueError(f"target must be one of 'auto', 'ai', 'human'; got {target!r}") from None

    @staticmethod
    def _atk_name(img_name):
        """Return the output file name: *img_name* without its extension + '_atk.png'."""
        # splitext (unlike slicing at rfind('.')) leaves extension-less names intact.
        return f'{os.path.splitext(img_name)[0]}_atk.png'

    def save_image(self, image, noise, img_name):
        """Add *noise* to the full-resolution *image* and save it as PNG in ``out_dir``.

        When sizes differ, only the noise is resized (bicubic), so the image
        itself keeps its original resolution.

        Args:
            image: the PIL image that was attacked.
            noise: perturbation in pixel space as returned by ``attack_``; must be
                4-D (N, C, h, w) because bicubic interpolation requires 4-D input.
            img_name: base file name from which the output name is derived.
        """
        W, H = image.size
        noise = F.interpolate(noise, size=(H, W), mode='bicubic')
        img_save = transforms.ToTensor()(image) + noise
        # Inside this body ``save_image`` is the module-level torchvision function,
        # not this method (method names are not in scope in their own body).
        save_image(img_save, os.path.join(self.args.out_dir, self._atk_name(img_name)))

    def attack_(self, image, step_func=None):
        """Run the PGD attack on one PIL image.

        Args:
            image: PIL image to perturb.
            step_func: optional callback passed through unchanged to ``PGD.attack``.

        Returns:
            Tuple ``(atk_img, noise)``. ``atk_img`` is the adversarial model input
            (model space, on ``device``). ``noise`` is the perturbation mapped back
            to pixel space, on CPU.
        """
        inputs = self.feature_extractor(images=image, return_tensors="pt")['pixel_values'].to(device)

        if self.target is None:
            # 'auto': use the model's current prediction as the label handed to PGD.
            with torch.no_grad():
                cls = self.model(inputs).logits.argmax(-1).item()
            target = torch.tensor([cls]).to(device)
        else:
            target = self.target

        if self.args.test_atk:
            self.test_image(inputs, 'before attack')

        atk_img = self.pgd.attack(inputs, target, step_func)

        to_pixels = self.pgd.img_transform[1]
        noise = to_pixels(atk_img).detach().cpu() - to_pixels(inputs).detach().cpu()

        if self.args.test_atk:
            self.test_image(atk_img, 'after attack')

        return atk_img, noise

    def attack_one(self, path, step_func=None):
        """Attack the image file at *path* and save the result in ``out_dir``."""
        image = Image.open(path).convert('RGB')
        _atk_img, noise = self.attack_(image, step_func)
        self.save_image(image, noise, os.path.basename(path))

    def attack(self, path, step_func=None):
        """Attack a single image file, or every image directly inside a directory.

        A non-directory *path* is attacked as-is. For a directory, its regular
        files whose lower-cased name ends with one of ``IMAGE_EXTS`` are attacked
        in sorted order (not recursive).
        """
        if not os.path.isdir(path):
            self.attack_one(path, step_func)
            return
        for name in sorted(os.listdir(path)):
            file_path = os.path.join(path, name)
            if os.path.isfile(file_path) and name.lower().endswith(IMAGE_EXTS):
                self.attack_one(file_path, step_func)

    @torch.no_grad()
    def test_image(self, img, pre_fix=None):
        """Classify a preprocessed input and return per-label probabilities.

        Args:
            img: model input tensor, on the model's device.
            pre_fix: optional tag. When given, it is printed together with the
                probabilities; ``attack_`` uses this for the ``--test_atk`` report.

        Returns:
            dict mapping ``model.config.id2label`` names to softmax probabilities.
        """
        logits = self.model(img).logits
        # Flattening before softmax assumes a batch of exactly one image.
        confidences = torch.softmax(logits.reshape(-1), dim=0)
        result = {self.model.config.id2label[i]: float(conf) for i, conf in enumerate(confidences)}
        if pre_fix is not None:
            print(f'{pre_fix}: {result}')
        return result

    @torch.no_grad()
    def image_predict(self, image: "Image.Image") -> Mapping[str, float]:
        """Return the classifier's per-label probabilities for a PIL image."""
        # Inputs must live on the same device as the model (GPU when available).
        inputs = self.feature_extractor(images=image, return_tensors="pt")['pixel_values'].to(device)
        return self.test_image(inputs)


if __name__ == '__main__':
    # Script entry point: parse the CLI flags, build the attacker (this loads
    # the model) and run it on the given input path.
    cli_args = make_args()
    Attacker(cli_args).attack(cli_args.inputs)