File size: 1,371 Bytes
269cbe7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
#!/usr/bin/env python3
import tree_ring_watermark as trk
from diffusers import DiffusionPipeline, DDIMScheduler
from pathlib import Path
from huggingface_hub import HfApi, login
import torch
# login()  # uncomment on first use; log in with an account that is connected to the `trk-demo` org
trk.set_org("trk-demo")
model_id = 'stabilityai/stable-diffusion-2-1-base'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Use half precision only on GPU: many CPU kernels have no float16 implementation,
# so the CPU fallback would fail at inference time if it loaded the model in float16.
dtype = torch.float16 if device == 'cuda' else torch.float32
# The model hash should be the latest commit hash in the history of `model_id`:
# https://huggingface.co/stabilityai/stable-diffusion-2-1-base/commits/main
model_hash = "dcd3ee64f0c1aba2eb9e0c0c16041c6cae40d780"
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)
# Replace the default scheduler with DDIM, reusing the existing scheduler's config.
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to(device)
# Build the initial latent noise in the shape the UNet expects as input.
batch_size = 1
n_channels = pipe.unet.config.in_channels
sample_size = pipe.unet.config.sample_size
shape = (batch_size, n_channels, sample_size, sample_size)
latents = trk.get_noise(shape, model_hash=model_hash)
# Move the latents to the pipeline's device and dtype so the UNet accepts them.
latents = latents.to(device=pipe.device, dtype=dtype)
# Generate an image from the noise returned by trk.get_noise, then check it with trk.detect.
image = pipe(prompt="an astronaut", latents=latents).images[0]
is_watermarked = trk.detect(image, pipe, model_hash)
print(f'is_watermarked: {is_watermarked}')
|