|
# Gradio demo for WOUAF: Weight Modulation for User Attribution and
# Fingerprinting in Text-to-Image Diffusion Models.
import math
import os
from typing import List

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights

from diffusers import AutoencoderKL, StableDiffusionPipeline

from attribution import MappingNetwork
from customization import customize_vae_decoder
|
|
|
|
|
PRETRAINED_MODEL_NAME_OR_PATH = "./checkpoints/" |
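# The checkpoint directory is expected to provide the three fine-tuned weights
# loaded below: vae_decoder.pth, mapping_network.pth, and decoding_network.pth.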
|
|
|
|
|
def get_image_grid(images: List[Image.Image]) -> Image.Image:
    # Tile the images into a single grid, three per row.
    num_images = len(images)
    cols = 3
    rows = math.ceil(num_images / cols)
    width, height = images[0].size
    grid_image = Image.new('RGB', (cols * width, rows * height))
    for i, img in enumerate(images):
        x = i % cols
        y = i // cols
        grid_image.paste(img, (x * width, y * height))
    return grid_image
|
|
|
|
|
class AttributionModel:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.pipe = StableDiffusionPipeline.from_pretrained('stabilityai/stable-diffusion-2')
        self.pipe = self.pipe.to(self.device)
        self.resize_transform = transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR)

        # Swap the stock VAE decoder for the weight-modulated (fingerprinting) decoder.
        self.vae = AutoencoderKL.from_pretrained(
            'stabilityai/stable-diffusion-2', subfolder="vae"
        )
        self.vae = customize_vae_decoder(self.vae, 128, "qkv", "all", False, 1.0)

        # Maps a 32-bit user key to the modulation parameters of the customized decoder.
        self.mapping_network = MappingNetwork(32, 0, 128, None, num_layers=2, w_avg_beta=None, normalization=False)

        # ResNet-50 with a 32-way head that recovers the key from a generated image.
        self.decoding_network = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        self.decoding_network.fc = torch.nn.Linear(2048, 32)

        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only machines.
        self.vae.decoder.load_state_dict(torch.load(os.path.join(PRETRAINED_MODEL_NAME_OR_PATH, 'vae_decoder.pth'), map_location="cpu"))
        self.mapping_network.load_state_dict(torch.load(os.path.join(PRETRAINED_MODEL_NAME_OR_PATH, 'mapping_network.pth'), map_location="cpu"))
        self.decoding_network.load_state_dict(torch.load(os.path.join(PRETRAINED_MODEL_NAME_OR_PATH, 'decoding_network.pth'), map_location="cpu"))

        self.vae = self.vae.to(self.device)
        self.mapping_network = self.mapping_network.to(self.device)
        self.decoding_network = self.decoding_network.to(self.device)

        # ImageNet normalization expected by the ResNet-50 decoding network.
        self.test_norm = transforms.Compose(
            [
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ]
        )
|
|
|
    def infer(self, prompt, negative, steps, guidance_scale):
        with torch.no_grad():
            # Run the diffusion loop once, then decode the same latents twice:
            # once with the fingerprinting decoder and once with the stock decoder.
            out_latents = self.pipe([prompt], negative_prompt=[negative], output_type="latent", num_inference_steps=steps, guidance_scale=guidance_scale).images
            image_attr = self.inference_with_attribution(out_latents)
            image_attr_pil = self.pipe.numpy_to_pil(image_attr[0])

            image_org = self.inference_without_attribution(out_latents)
            image_org_pil = self.pipe.numpy_to_pil(image_org[0])

            # Take the absolute difference so numpy_to_pil's uint8 cast cannot wrap negative values.
            image_diff_pil = self.pipe.numpy_to_pil(abs(image_attr[0] - image_org[0]))

        return image_org_pil[0], image_attr_pil[0], image_diff_pil[0]
|
|
|
    def inference_without_attribution(self, latents):
        # 0.18215 is the VAE latent scaling factor used by Stable Diffusion.
        latents = 1 / 0.18215 * latents
        with torch.no_grad():
            image = self.pipe.vae.decode(latents).sample
        image = image.clamp(-1, 1)
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image
|
|
|
    def get_phis(self, phi_dimension, batch_size, eps=1e-8):
        # Sample a random binary key: phi_dimension independent fair-coin bits,
        # shifted by eps so no bit is exactly zero.
        phi = torch.empty(batch_size, phi_dimension).uniform_(0, 1)
        return torch.bernoulli(phi) + eps
|
|
|
|
|
    def inference_with_attribution(self, latents, key=None):
        if key is None:
            key = self.get_phis(32, 1)

        latents = 1 / 0.18215 * latents
        with torch.no_grad():
            # Condition the customized decoder on the key through the mapping network.
            image = self.vae.decode(latents, self.mapping_network(key.to(self.device))).sample
        image = image.clamp(-1, 1)
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image
|
|
|
    def postprocess(self, image):
        image = self.resize_transform(image)
        return image

    def detect_key(self, image):
        # Recover the key from an image tensor in [-1, 1]: rescale to [0, 1],
        # apply ImageNet normalization, and run the decoding network.
        reconstructed_keys = self.decoding_network(self.test_norm((image / 2 + 0.5).clamp(0, 1)))
        return reconstructed_keys
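
    # Hedged sketch, not part of the original script: a plausible way to check that
    # a key survives the decode/detect round trip. It assumes the decoding network
    # was trained against the binary key with a logit-based objective (e.g. binary
    # cross-entropy), so bits are recovered by thresholding the logits at 0; that
    # threshold is an illustrative assumption.
    def bit_accuracy(self, image, key):
        with torch.no_grad():
            logits = self.detect_key(image.to(self.device))
        predicted_bits = (logits > 0).float()
        target_bits = (key.to(self.device) > 0.5).float()
        return (predicted_bits == target_bits).float().mean().item()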
|
|
|
|
|
attribution_model = AttributionModel()

def get_images(prompt, negative, steps, guidance_scale):
    x1, x2, x3 = attribution_model.infer(prompt, negative, steps, guidance_scale)
    return [x1, x2, x3]
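
# Hedged usage sketch, not wired into the demo UI: embed a fixed key, generate an
# image, and report how many key bits the decoding network recovers. The sampler
# settings and variable names below are illustrative assumptions.
def embed_and_verify(prompt, steps=25, guidance_scale=7.5):
    key = attribution_model.get_phis(32, 1)
    with torch.no_grad():
        latents = attribution_model.pipe([prompt], output_type="latent", num_inference_steps=steps, guidance_scale=guidance_scale).images
    image_np = attribution_model.inference_with_attribution(latents, key=key)
    # Map the [0, 1] numpy output back to the [-1, 1] tensor range detect_key expects.
    image = torch.from_numpy(image_np).permute(0, 3, 1, 2) * 2 - 1
    return attribution_model.bit_accuracy(image, key)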
|
|
|
|
|
image_examples = [
    ["A pikachu fine dining with a view to the Eiffel Tower", "low quality", 50, 10],
    ["A mecha robot in a favela in expressionist style", "low quality, 3d, photorealistic", 50, 10],
]
|
|
|
with gr.Blocks() as demo:
    gr.Markdown(
        """<h1 style="text-align: center;"><b>WOUAF:
        Weight Modulation for User Attribution and Fingerprinting in Text-to-Image Diffusion Models</b> <br> <a href="https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/">Project Page</a></h1>""")

    gr.Markdown(
        """<h3>Demo: Text-to-Image (Stable Diffusion 2) with random user attribution</h3>
        WOUAF can also be applied to other tasks such as in-painting, image editing, and image super-resolution.
        <br>More details in the <a href="https://arxiv.org/abs/2306.04744">paper</a>.
        """
    )
|
|
|
    # gr.Row/gr.Textbox no longer support .style() in current Gradio releases;
    # container=False is now a constructor argument, and the border/rounded
    # options have no direct equivalent.
    with gr.Row(elem_id="prompt-container", equal_height=True):
        with gr.Column():
            text = gr.Textbox(
                label="Enter your prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                elem_id="prompt-text-input",
                container=False,
            )
            negative = gr.Textbox(
                label="Enter your negative prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter a negative prompt",
                elem_id="negative-prompt-text-input",
                container=False,
            )
|
|
|
    with gr.Row():
        steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
        guidance_scale = gr.Slider(
            label="Guidance Scale", minimum=0, maximum=10, value=9, step=0.1
        )

    with gr.Row():
        # full_width was removed in current Gradio; scale=0 keeps the button compact.
        btn = gr.Button(value="Generate Image", scale=0)
|
|
|
    with gr.Row():
        im_2 = gr.Image(type="pil", label="without attribution")
        im_3 = gr.Image(type="pil", label="with attribution")
        im_4 = gr.Image(type="pil", label="pixel-wise difference")

    btn.click(get_images, inputs=[text, negative, steps, guidance_scale], outputs=[im_2, im_3, im_4])
|
|
|
    gr.Examples(
        examples=image_examples,
        inputs=[text, negative, steps, guidance_scale],
        outputs=[im_2, im_3, im_4],
        fn=get_images,
        cache_examples=True,
    )
|
|
|
    gr.HTML(
        """
        <div class="footer">
            <p>Pre-trained model by <a href="https://huggingface.co/stabilityai" style="text-decoration: underline;" target="_blank">StabilityAI</a></p>
            <p>Fine-tuned by the authors for research purposes.</p>
        </div>
        """
    )
    with gr.Accordion(label="Ethics & Privacy", open=False):
        gr.HTML(
            """<div class="acknowledgments">
            <p><h4>Privacy</h4>
            We do not collect any images or key data. This demo is designed solely for fun and for reducing the misuse of AI.
            <p><h4>Biases and content acknowledgment</h4>
            This model has the same biases as the underlying Stable Diffusion 2 model.</div>
            """
        )
|
|
|
if __name__ == "__main__":
    demo.launch()
|
|