File size: 5,712 Bytes
0b00c74
d2794b1
ca86cf6
8b44d8d
 
ca86cf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab4f056
 
 
 
ca86cf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b44d8d
ca86cf6
 
 
 
 
 
 
 
 
 
8b44d8d
 
ca86cf6
 
8b44d8d
ca86cf6
 
d2794b1
ca86cf6
 
 
 
 
 
0b00c74
ca86cf6
 
 
 
 
 
 
ac121a8
ca86cf6
 
 
 
 
 
 
 
 
 
0b00c74
ca86cf6
0b00c74
ca86cf6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import gradio as gr

import os
import torch

import numpy as np

# Compute device used by the pipeline, RNG and checkpoint loading: CUDA when present, else CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from diffusers import DiffusionPipeline
import torchvision.transforms as transforms

from copy import deepcopy
from collections import OrderedDict

import requests
import json

from PIL import Image, ImageEnhance
import base64
import io

class BZHStableSignatureDemo(object):
    """Backend for the watermarked SDXL-Turbo demo.

    Holds one SDXL-Turbo ``DiffusionPipeline`` together with several VAE
    state dicts: the pipeline's original weights ("no watermark") and patched
    decoder weights loaded from local checkpoints ("weak" .. "extreme").
    ``generate`` swaps the selected weights into the VAE before sampling;
    ``attack_detect`` degrades an image and asks a remote API whether the
    watermark is still detectable.
    """

    # Remote watermark detection endpoint (receives a base64-encoded JPEG).
    DETECT_URL = 'https://bzh.imatag.com/bzh/api/v1.0/detect'
    # Seconds to wait for the detection API; requests has no default timeout
    # and would otherwise block the worker forever on a stalled connection.
    DETECT_TIMEOUT = 60

    # Patched decoder checkpoints, keyed by the strength label shown in the UI.
    PATCHED_DECODERS = (
        ("weak", "models/checkpoint_000.pth.50000"),
        ("medium", "models/checkpoint_000.pth.150000"),
        ("strong", "models/checkpoint_000.pth.500000"),
        ("extreme", "models/checkpoint_000.pth.1500000"),
    )

    def __init__(self, *args, **kwargs):
        """Load the SDXL-Turbo pipeline and every selectable VAE weight set.

        Populates ``self.pipe`` and ``self.decoders`` (an OrderedDict mapping
        strength label -> state dict, starting with "no watermark").
        """
        super().__init__(*args, **kwargs)
        # Half precision is poorly supported on CPU, so only use fp16 on CUDA.
        # The fp16 weight variant is downloaded either way and cast to `dtype`.
        dtype = torch.float16 if device.type == 'cuda' else torch.float32
        self.pipe = DiffusionPipeline.from_pretrained(
            "stabilityai/sdxl-turbo", torch_dtype=dtype, variant="fp16").to(device)

        # load the patched VQ-VAEs
        # state_dict() tensors alias the live parameters, so snapshot them
        # before any patched weights are loaded over them.
        decoders = OrderedDict([("no watermark", deepcopy(self.pipe.vae.state_dict()))])
        for name, ckpt_path in self.PATCHED_DECODERS:
            # NOTE: torch.load unpickles the file -- only point it at trusted
            # checkpoints. map_location keeps the weights on the pipeline's
            # device even if the checkpoint was saved on a different one.
            patched = torch.load(ckpt_path, map_location=device)['ldm_decoder']
            # Load once now purely to report key mismatches: with strict=False
            # missing/unexpected keys are returned instead of raising.
            msg = self.pipe.vae.load_state_dict(patched, strict=False)
            print(f"loaded LDM decoder state_dict with message\n{msg}")
            print("you should check that the decoder keys are correctly matched")
            decoders[name] = patched
        self.decoders = decoders

    def generate(self, mode, seed, prompt):
        """Generate one image for ``prompt`` with the VAE weights named ``mode``.

        Args:
            mode: key of ``self.decoders`` (e.g. "no watermark", "medium").
            seed: integer seed; a falsy value (None or 0) leaves sampling
                unseeded, i.e. non-reproducible.
            prompt: text prompt passed to the pipeline.

        Returns:
            The first PIL image produced by the pipeline.

        Raises:
            KeyError: if ``mode`` is not a known decoder label.
        """
        # A dedicated, seeded generator makes the result reproducible without
        # touching the process-wide torch RNG. None -> pipeline's default RNG.
        generator = torch.Generator(device=device).manual_seed(int(seed)) if seed else None

        # load the patched VAE decoder (strict=False: only matching keys change)
        self.pipe.vae.load_state_dict(self.decoders[mode], strict=False)

        # SDXL-Turbo is distilled for few-step sampling without CFG.
        output = self.pipe(prompt, num_inference_steps=4, guidance_scale=0.0,
                           output_type="pil", generator=generator)
        return output.images[0]

    @staticmethod
    def pad(img, padding, mode="edge"):
        """Pad an image spatially using ``numpy.pad``.

        Args:
            img: PIL image (or array-like) of shape (H, W) or (H, W, C).
            padding: indexable (left, top, right, bottom) pixel counts.
            mode: ``numpy.pad`` mode, "edge" by default.

        Returns:
            A new PIL image; channels, if any, are never padded.
        """
        npimg = np.asarray(img)
        # rows get (top, bottom), columns get (left, right); extra axes untouched
        nppad = ((padding[1], padding[3]), (padding[0], padding[2])) + ((0, 0),) * (npimg.ndim - 2)
        npimg = np.pad(npimg, nppad, mode=mode)
        return Image.fromarray(npimg)

    def attack_detect(self, img, jpeg_compression, downscale, saturation):
        """Degrade ``img`` and ask the detection API whether a watermark remains.

        Attacks are applied in order: bicubic downscale by ``downscale``,
        colour-saturation scaling by ``saturation``, then JPEG re-encoding at
        quality ``jpeg_compression``. The JPEG bytes are what the API sees.

        Returns:
            (attacked_image, message): the JPEG-decoded image and a
            human-readable detection verdict.

        Raises:
            ValueError: if ``img`` is None (nothing generated/uploaded yet).
            requests.HTTPError: if the detection API returns an error status.
        """
        if img is None:
            raise ValueError("No image to attack: generate or upload an image first.")

        # JPEG cannot store alpha or palette modes; normalise to RGB up front.
        img = img.convert("RGB")

        # attack: downscale (clamped so tiny images never become 0 px wide)
        if downscale != 1:
            width, height = img.size
            size = (max(1, int(width / downscale)), max(1, int(height / downscale)))
            img = img.resize(size, Image.BICUBIC)

        # attack: colour saturation
        img = ImageEnhance.Color(img).enhance(saturation)

        # attack: JPEG compression. Pillow requires an int quality; UI sliders
        # may deliver floats.
        mf = io.BytesIO()
        img.save(mf, format='JPEG', quality=int(jpeg_compression))

        pvalue = self._query_detector(mf.getvalue())

        mf.seek(0)
        img0 = Image.open(mf)  # reload to show JPEG attack
        return (img0, self._detection_message(pvalue))

    def _query_detector(self, jpeg_bytes):
        """POST JPEG bytes to the detection API and return its 'p-value' field."""
        payload = {'image': base64.b64encode(jpeg_bytes).decode('utf8')}
        headers = {}
        api_key = os.environ.get('BZH_API_KEY', None)
        if api_key:
            headers['BZH_API_KEY'] = api_key
        response = requests.post(self.DETECT_URL, json=payload, headers=headers,
                                 timeout=self.DETECT_TIMEOUT)
        response.raise_for_status()
        return response.json()['p-value']

    @staticmethod
    def _detection_message(pvalue):
        """Turn a detection p-value into the verdict shown to the user.

        p >= 1e-3 (or NaN) -> no watermark; p < 1e-3 -> weak; p < 1e-6 -> strong.
        """
        if not pvalue < 1e-3:
            return "No watermark detected."
        if pvalue <= 0:
            # The p-value underflowed to zero: 1 / pvalue would divide by zero.
            return "Strong watermark detected (virtually no chance of being wrong)"
        chances = int(1 / pvalue + 1)
        if pvalue < 1e-6:
            return "Strong watermark detected (< 1/%d chances of being wrong)" % chances
        return "Weak watermark detected (< 1/%d chances of being wrong)" % chances


def interface():
    """Build and return the Gradio Blocks UI for the watermarking demo.

    Instantiates the backend (which loads the pipeline and every decoder),
    then wires "Generate" to ``backend.generate`` and "Attack & Detect" to
    ``backend.attack_detect``.
    """
    default_prompt = "sailing ship in storm by Rembrandt"

    backend = BZHStableSignatureDemo()
    strengths = [label for label in backend.decoders]

    with gr.Blocks() as demo:
        gr.Markdown("""# Watermarked SDXL-Turbo demo
        This demo presents watermarking of images generated via StableDiffusion XL Turbo.
        Using the method presented in [StableSignature](https://ai.meta.com/blog/stable-signature-watermarking-generative-ai/),
        the VAE decoder of StableDiffusion is fine-tuned to produce images including a specific invisible watermark. We combined
        this method with our in-house decoder which operates in zero-bit mode for improved robustness.""")

        # Generation controls: prompt, seed and watermark strength.
        with gr.Row():
            prompt_box = gr.Textbox(label="Prompt", value=default_prompt)
            seed_box = gr.Number(label="Seed", precision=0)
            strength_menu = gr.Dropdown(choices=strengths, label="Watermark strength", value="medium")
        with gr.Row():
            generate_button = gr.Button("Generate")

        # Generated image on the left, attack parameters and results on the right.
        with gr.Row():
            generated_view = gr.Image(type="pil").style(width=512, height=512)
            with gr.Column():
                downscale_slider = gr.Slider(1, 3, value=1, step=0.1, label="Downscale ratio")
                saturation_slider = gr.Slider(0, 2, value=1, step=0.1, label="Color saturation")
                quality_slider = gr.Slider(value=100, step=5, label="JPEG quality")
                detect_button = gr.Button("Attack & Detect")
                with gr.Row():
                    attacked_view = gr.Image(type="pil", tool="select").style(width=256)
                    verdict_label = gr.Label(label="Detection info")

        # Input order must match the backend method signatures.
        generate_button.click(
            fn=backend.generate,
            inputs=[strength_menu, seed_box, prompt_box],
            outputs=[generated_view],
            api_name="generate",
        )
        detect_button.click(
            fn=backend.attack_detect,
            inputs=[generated_view, quality_slider, downscale_slider, saturation_slider],
            outputs=[attacked_view, verdict_label],
            api_name="detect",
        )

    return demo

if __name__ == "__main__":
    # Build the demo UI and serve it with Gradio's default launch settings.
    demo = interface()
    demo.launch()