File size: 1,371 Bytes
269cbe7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#!/usr/bin/env python3
import tree_ring_watermark as trk
from diffusers import DiffusionPipeline, DDIMScheduler
from pathlib import Path
from huggingface_hub import HfApi, login
import torch


# login()  # uncomment to authenticate; use an account that is connected to `trk-demo`
trk.set_org("trk-demo")

model_id = 'stabilityai/stable-diffusion-2-1-base'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Half precision is only used on GPU: many CPU kernels lack float16 support,
# so fall back to float32 when no CUDA device is available.
torch_dtype = torch.float16 if device == 'cuda' else torch.float32

# The model hash should be the latest commit hash of the model repo's history,
# for the same repo as `model_id`:
# https://huggingface.co/stabilityai/stable-diffusion-2-1-base/commits/main
model_hash = "dcd3ee64f0c1aba2eb9e0c0c16041c6cae40d780"

# Load the pipeline and swap its scheduler for DDIM, reusing the existing
# scheduler configuration.
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to(device)

# Build the latent shape from the UNet configuration.
batch_size = 1
n_channels = pipe.unet.config.in_channels
sample_size = pipe.unet.config.sample_size

shape = (batch_size, n_channels, sample_size, sample_size)

# Get the initial noise keyed by the model hash, then move it to the
# pipeline's device and cast it to the dtype the pipeline was loaded with.
latents = trk.get_noise(shape, model_hash=model_hash)
latents = latents.to(device=pipe.device, dtype=torch_dtype)

# Generate an image starting from the latents returned by `trk.get_noise`.
# NOTE(review): the previous comment said "generation without watermarking",
# but these latents come from `trk.get_noise` and are then checked with
# `trk.detect` -- presumably they carry the watermark; confirm against the
# tree_ring_watermark documentation.
image = pipe(prompt="an astronaut", latents=latents).images[0]

# Run watermark detection on the generated image and report the result.
is_watermarked = trk.detect(image, pipe, model_hash)
print(f'is_watermarked: {is_watermarked}')