Vivien Chappelier commited on
Commit
464ec84
·
1 Parent(s): fbe5687

use packaged VAEs

Browse files
Files changed (1) hide show
  1. app.py +22 -30
app.py CHANGED
@@ -7,7 +7,7 @@ import numpy as np
7
 
8
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
9
 
10
- from diffusers import DiffusionPipeline
11
  import torchvision.transforms as transforms
12
 
13
  from copy import deepcopy
@@ -26,42 +26,35 @@ class BZHStableSignatureDemo(object):
26
 
27
  def __init__(self, *args, **kwargs):
28
  super().__init__(*args, **kwargs)
 
29
  self.pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16").to("cuda")
30
- try:
31
- print("self.pipe.watermark = ", self.pipe.watermark)
32
- except:
33
- print("no self.pipe.watermark")
34
-
35
- # load the patched VQ-VAEs
36
- sd1 = deepcopy(self.pipe.vae.state_dict()) # save initial state dict
37
- self.decoders = decoders = OrderedDict([("no watermark", sd1)])
38
- for name, patched_decoder_ckpt in (
39
- ("weak", "models/checkpoint_000.pth.50000"),
40
- ("medium", "models/checkpoint_000.pth.150000"),
41
- ("strong", "models/checkpoint_000.pth.500000"),
42
- ("extreme", "models/checkpoint_000.pth.1500000")):
43
- sd2 = torch.load(patched_decoder_ckpt)['ldm_decoder']
44
- msg = self.pipe.vae.load_state_dict(sd2, strict=False)
45
- print(f"loaded LDM decoder state_dict with message\n{msg}")
46
- print("you should check that the decoder keys are correctly matched")
47
- decoders[name] = sd2
48
  self.decoders = decoders
49
 
50
  def generate(self, mode, seed, prompt):
51
  generator = torch.Generator(device=device)
52
- #if seed:
53
  torch.manual_seed(seed)
54
 
55
- # load the patched VAE decoder
56
- sd = self.decoders[mode]
57
- self.pipe.vae.load_state_dict(sd, strict=False)
58
 
59
  output = self.pipe(prompt, num_inference_steps=4, guidance_scale=0.0, output_type="pil")
60
- return output.images[0] #{ "background": output.images[0], "layers": [], "composite": None }
61
 
62
  def attack_detect(self, img, jpeg_compression, downscale, crop, saturation):
63
 
64
- #img = img_edit["composite"]
65
  img = img.convert("RGB")
66
 
67
  # attack
@@ -69,6 +62,7 @@ class BZHStableSignatureDemo(object):
69
  size = img.size
70
  size = (int(size[0] / downscale), int(size[1] / downscale))
71
  img = img.resize(size, Image.Resampling.LANCZOS)
 
72
  if crop != 0:
73
  width, height = img.size
74
  area = width * height
@@ -108,17 +102,15 @@ class BZHStableSignatureDemo(object):
108
 
109
  mf.seek(0)
110
  img0 = Image.open(mf) # reload to show JPEG attack
111
- #result = "resolution = %dx%d p-value = %e" % (img.size[0], img.size[1], pvalue))
112
  result = "No watermark detected."
113
- chances = int(1 / pvalue + 1)
114
  rpv = 10**int(math.log10(pvalue))
115
  if pvalue < 1e-3:
116
- result = "Watermark detected with low confidence (p-value<%.0e)" % rpv # (< 1/%d chances of being wrong)" % chances
117
  if pvalue < 1e-9:
118
- result = "Watermark detected with high confidence (p-value<%.0e)" % rpv # (< 1/%d chances of being wrong)" % chances
119
  return (img0, result)
120
 
121
-
122
  def interface():
123
  prompt = "sailing ship in storm by Rembrandt"
124
 
 
7
 
8
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
9
 
10
+ from diffusers import DiffusionPipeline, AutoencoderKL
11
  import torchvision.transforms as transforms
12
 
13
  from copy import deepcopy
 
26
 
27
  def __init__(self, *args, **kwargs):
28
  super().__init__(*args, **kwargs)
29
+
30
  self.pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16").to("cuda")
31
+
32
+ # disable invisible-watermark
33
+ self.pipe.watermark = None
34
+
35
+ # save the original VAE
36
+ decoders = OrderedDict([("no watermark", self.pipe.vae)])
37
+
38
+ # load the patched VAEs
39
+ for name in ("weak", "medium", "strong", "extreme"):
40
+ vae = AutoencoderKL.from_pretrained(f"imatag/stable-signature-bzh-sdxl-vae-{name}", torch_dtype=torch.float16).to("cuda")
41
+ decoders[name] = vae
42
+
 
 
 
 
 
 
43
  self.decoders = decoders
44
 
45
  def generate(self, mode, seed, prompt):
46
  generator = torch.Generator(device=device)
 
47
  torch.manual_seed(seed)
48
 
49
+ # load the patched VAE
50
+ vae = self.decoders[mode]
51
+ self.pipe.vae = vae
52
 
53
  output = self.pipe(prompt, num_inference_steps=4, guidance_scale=0.0, output_type="pil")
54
+ return output.images[0]
55
 
56
  def attack_detect(self, img, jpeg_compression, downscale, crop, saturation):
57
 
 
58
  img = img.convert("RGB")
59
 
60
  # attack
 
62
  size = img.size
63
  size = (int(size[0] / downscale), int(size[1] / downscale))
64
  img = img.resize(size, Image.Resampling.LANCZOS)
65
+
66
  if crop != 0:
67
  width, height = img.size
68
  area = width * height
 
102
 
103
  mf.seek(0)
104
  img0 = Image.open(mf) # reload to show JPEG attack
105
+
106
  result = "No watermark detected."
 
107
  rpv = 10**int(math.log10(pvalue))
108
  if pvalue < 1e-3:
109
+ result = "Watermark detected with low confidence (p-value<%.0e)" % rpv
110
  if pvalue < 1e-9:
111
+ result = "Watermark detected with high confidence (p-value<%.0e)" % rpv
112
  return (img0, result)
113
 
 
114
  def interface():
115
  prompt = "sailing ship in storm by Rembrandt"
116