Spaces:
Runtime error
Runtime error
File size: 7,433 Bytes
4fbd61f fee4293 4fbd61f 01d0d8e 4fbd61f 01d0d8e 4fbd61f abfd4d8 4fbd61f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
import os
# Fetch the cloob-latent-diffusion repository (with its submodules) and install
# the packages it needs; the sys.path entries and imports below depend on it.
import subprocess
import sys

if not os.path.isdir('cloob-latent-diffusion'):
    # Skip the clone when the checkout already exists (e.g. on a restart), where
    # `git clone` would only fail with "destination path already exists".
    # Nothing below can work without the repository, so fail loudly here.
    subprocess.run(
        ['git', 'clone', '--recursive', 'https://github.com/JD-P/cloob-latent-diffusion'],
        check=True,
    )
# Install through the running interpreter so the packages land in the same
# environment that imports them (a bare `pip` on PATH may belong to another one).
# The list is passed without a shell, so no string is shell-interpreted.
_pip_result = subprocess.run(
    [sys.executable, '-m', 'pip', 'install', 'omegaconf', 'pillow',
     'pytorch-lightning==1.6.5', 'einops', 'wandb', 'ftfy', 'regex', './CLIP'],
    cwd='cloob-latent-diffusion',
)
if _pip_result.returncode != 0:
    # Not fatal: the packages may already be installed; report instead of hiding it.
    print('warning: pip install exited with status', _pip_result.returncode)
import argparse
from functools import partial
from pathlib import Path
import sys
# Make the cloned repository and its bundled sub-projects importable.
# Entries are appended, in this order, after the existing search path.
sys.path.extend([
    './cloob-latent-diffusion',
    './cloob-latent-diffusion/cloob-training',
    './cloob-latent-diffusion/latent-diffusion',
    './cloob-latent-diffusion/taming-transformers',
    './cloob-latent-diffusion/v-diffusion-pytorch',
])
from omegaconf import OmegaConf
from PIL import Image
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm import trange
from CLIP import clip
from cloob_training import model_pt, pretrained
import ldm.models.autoencoder
from diffusion import sampling, utils
import train_latent_diffusion as train
from huggingface_hub import hf_hub_url, hf_hub_download
import random
# Download the model files from the Hugging Face Hub.
# hf_hub_download takes the repository id and the file name directly and
# returns the local path of the downloaded (cached) file; it does not accept a
# URL built by hf_hub_url, and `filename` is a required argument.
checkpoint = hf_hub_download(repo_id="huggan/distill-ccld-wa", filename="model_student.ckpt")
ae_model_path = hf_hub_download(repo_id="huggan/ccld_wa", filename="ae_model.ckpt")
ae_config_path = hf_hub_download(repo_id="huggan/ccld_wa", filename="ae_model.yaml")
# Define a few utility functions
def parse_prompt(prompt, default_weight=3.):
    """Split a ``"text:weight"`` prompt into ``(text, weight)``.

    The weight is the number after the last colon. For ``http://`` and
    ``https://`` prompts, the scheme's colon is never treated as a separator.
    If there is no colon, or the text after the last colon is not a number
    (e.g. ``"A painting: the sea"`` or ``"https://host:8080/a.png"``), the
    whole prompt is returned with ``default_weight``.

    Args:
        prompt: Text prompt, or an image path/URL, optionally suffixed with
            ``:weight``.
        default_weight: Weight used when the prompt carries no numeric suffix.

    Returns:
        Tuple ``(text, weight)`` with ``weight`` as a float.
    """
    if prompt.startswith('http://') or prompt.startswith('https://'):
        vals = prompt.rsplit(':', 2)
        # Re-attach the scheme ("https" + ":" + "//host/...").
        vals = [vals[0] + ':' + vals[1], *vals[2:]]
    else:
        vals = prompt.rsplit(':', 1)
    if len(vals) == 2:
        try:
            return vals[0], float(vals[1])
        except ValueError:
            # The suffix is not a weight, so the colon is part of the prompt.
            pass
    return prompt, float(default_weight)
def resize_and_center_crop(image, size):
    """Resize ``image`` so it covers ``size``, then center-crop it to ``size``.

    Args:
        image: A PIL image.
        size: Target ``(width, height)``, in the same order as PIL's
            ``Image.size``.

    Returns:
        The resized and cropped image.
    """
    fac = max(size[0] / image.size[0], size[1] / image.size[1])
    # round() rather than int(): fac * dim can land just below the target
    # (e.g. 223.99999...) through float error, and truncating would make the
    # image one pixel too small, forcing center_crop to pad it.
    new_size = (round(fac * image.size[0]), round(fac * image.size[1]))
    image = image.resize(new_size, Image.LANCZOS)
    # torchvision's center_crop expects (height, width).
    return TF.center_crop(image, size[::-1])
# Load the models
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print('loading models')
# autoencoder
ae_config = OmegaConf.load(ae_config_path)
ae_model = ldm.models.autoencoder.AutoencoderKL(**ae_config.model.params)
ae_model.eval().requires_grad_(False).to(device)
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only host;
# load_state_dict then copies the tensors onto the module's current device.
ae_model.load_state_dict(torch.load(ae_model_path, map_location='cpu'))
# Shape of the sampled latents (channels, height, width), used by generate().
n_ch, side_y, side_x = 4, 32, 32
# diffusion model (constructor arguments must match the downloaded student checkpoint)
model = train.DiffusionModel(192, [1, 1, 2, 2], autoencoder_scale=torch.tensor(4.3084))
model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
model = model.to(device).eval().requires_grad_(False)
# CLOOB
cloob_config = pretrained.get_config('cloob_laion_400m_vit_b_16_16_epochs')
cloob = model_pt.get_pt_model(cloob_config)
# Kept separate from `checkpoint` (the diffusion model file) instead of overwriting it.
cloob_checkpoint = pretrained.download_checkpoint(cloob_config)
cloob.load_state_dict(model_pt.get_pt_params(cloob_config, cloob_checkpoint))
cloob.eval().requires_grad_(False).to(device)
# The key function: returns a list of n PIL images
def generate(n=1, prompts=('a red circle',), images=(), seed=42, steps=15,
             method='plms', eta=None):
    """Generate ``n`` images conditioned on text and/or image prompts.

    Every prompt may carry a ``:weight`` suffix (see ``parse_prompt``). An
    all-zero embedding is added as an extra condition whose weight is
    ``1 - sum(prompt weights)``. At each model evaluation, the outputs for all
    conditions are combined as a weighted sum (a classifier-free-guidance-style
    mix).

    Args:
        n: Number of images to generate.
        prompts: Text prompts.
        images: Image prompts (paths or URLs read through ``utils.fetch``).
        seed: Value passed to ``torch.manual_seed`` before the initial noise
            is drawn.
        steps: Number of sampling steps.
        method: Sampler name: 'ddpm', 'ddim', 'prk', 'plms', 'pie' or 'plms2'.
        eta: ``eta`` argument of ``sampling.sample`` for ``method='ddim'``;
            ``None`` is treated as ``0.`` ('ddpm' always uses ``1.``).

    Returns:
        A list of ``n`` PIL images.

    Raises:
        ValueError: If ``method`` is not one of the sampler names above.
    """
    if eta is None:
        # Previously None was forwarded into sampling.sample as-is.
        eta = 0.
    # Sampler dispatch. The lambdas look up the sampling functions lazily, so
    # only the chosen sampler has to exist in the installed `sampling` module.
    samplers = {
        'ddpm': lambda fn, x, sched: sampling.sample(fn, x, sched, 1., {}),
        'ddim': lambda fn, x, sched: sampling.sample(fn, x, sched, eta, {}),
        'prk': lambda fn, x, sched: sampling.prk_sample(fn, x, sched, {}),
        'plms': lambda fn, x, sched: sampling.plms_sample(fn, x, sched, {}),
        'pie': lambda fn, x, sched: sampling.pie_sample(fn, x, sched, {}),
        'plms2': lambda fn, x, sched: sampling.plms2_sample(fn, x, sched, {}),
    }
    # Validate before any model work; a bare `assert` would vanish under -O.
    if method not in samplers:
        raise ValueError(
            f'unknown sampling method {method!r}; expected one of {sorted(samplers)}')
    sample_fn = samplers[method]

    zero_embed = torch.zeros([1, cloob.config['d_embed']], device=device)
    target_embeds, weights = [zero_embed], []
    for prompt in prompts:
        txt, weight = parse_prompt(prompt)
        target_embeds.append(cloob.text_encoder(cloob.tokenize(txt).to(device)).float())
        weights.append(weight)
    clip_size = cloob.config['image_encoder']['image_size']
    for prompt in images:
        path, weight = parse_prompt(prompt)
        img = Image.open(utils.fetch(path)).convert('RGB')
        img = resize_and_center_crop(img, (clip_size, clip_size))
        batch = TF.to_tensor(img)[None].to(device)
        embed = F.normalize(cloob.image_encoder(cloob.normalize(batch)).float(), dim=-1)
        target_embeds.append(embed)
        weights.append(weight)
    # The zero embedding gets whatever weight makes all weights sum to 1.
    weights = torch.tensor([1 - sum(weights), *weights], device=device)
    n_conds = len(target_embeds)
    # Concatenated once here instead of on every model evaluation.
    cond_embeds = torch.cat(target_embeds)

    def cfg_model_fn(x, t):
        """Run the model once per condition and return the weighted sum of outputs."""
        bsz = x.shape[0]
        x_in = x.repeat([n_conds, 1, 1, 1])
        t_in = t.repeat([n_conds])
        clip_embed_in = cond_embeds.repeat_interleave(bsz, 0)
        vs = model(x_in, t_in, clip_embed_in).view([n_conds, bsz, *x.shape[1:]])
        return vs.mul(weights[:, None, None, None, None]).sum(0)

    torch.manual_seed(seed)
    x = torch.randn([n, n_ch, side_y, side_x], device=device)
    t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
    schedule = utils.get_spliced_ddpm_cosine_schedule(t)
    # max(..., 1) keeps the range step non-zero when n == 0 (then no batches run).
    batch_size = max(n, 1)
    pil_ims = []
    for i in trange(0, n, batch_size):
        cur_batch_size = min(n - i, batch_size)
        out_latents = sample_fn(cfg_model_fn, x[i:i + cur_batch_size], schedule)
        outs = ae_model.decode(out_latents * torch.tensor(2.55).to(device))
        pil_ims.extend(utils.to_pil_image(out) for out in outs)
    return pil_ims
import gradio as gr
def gen_ims(prompt, im_prompt=None, seed=None, n_steps=10, method='plms'):
    """Gradio callback: generate one image from a text and/or image prompt.

    Args:
        prompt: Text prompt; may be empty when an image prompt is given.
        im_prompt: Optional image prompt (the file path from the Gradio image
            input), or None.
        seed: Sampling seed; when None, one is drawn with
            ``random.randint(0, 10000)``.
        n_steps: Number of sampling steps.
        method: Sampler name forwarded to ``generate``.

    Returns:
        The single generated PIL image.
    """
    if seed is None:
        seed = random.randint(0, 10000)
    print(prompt, im_prompt, seed, n_steps)
    # An empty text box must not become a weighted empty-string prompt; leave it
    # out so an image-only request is conditioned on the image alone.
    prompts = [prompt] if prompt and prompt.strip() else []
    im_prompts = [im_prompt] if im_prompt is not None else []
    pil_ims = generate(n=1, prompts=prompts, images=im_prompts, seed=seed,
                       steps=n_steps, method=method)
    return pil_ims[0]
# Gradio UI. Uses top-level components (gr.Textbox / gr.Image) and
# Interface.queue(); gr.inputs / gr.outputs and launch(enable_queue=...) are
# deprecated in Gradio 3 and removed in Gradio 4.
# seed, n_steps and method are not exposed, so gen_ims uses its defaults
# (random seed, 10 steps, 'plms').
iface = gr.Interface(
    fn=gen_ims,
    inputs=[
        gr.Textbox(label="Text prompt"),
        # Components are optional by default; an empty image arrives as None.
        gr.Image(label="Image prompt", type='filepath'),
    ],
    outputs=[gr.Image(type="pil", label="Generated Image")],
    examples=[["An iceberg, oil on canvas"], ["A martian landscape, in the style of Monet"], ['A peaceful meadow, pastel crayons'], ["A painting of a vase of flowers"], ["A ship leaving the port in the summer, oil on canvas"]],
    title='Generate art from text prompts :',
    description="By typing a text prompt or providing an image prompt, and pressing submit you can generate images based on this prompt. The model was trained on images from the [WikiArt](https://huggingface.co/datasets/huggan/wikiart) dataset, comprised mostly of paintings.",
    article='The model is a distilled version of a cloob-conditioned latent diffusion model fine-tuned on the WikiArt dataset. You can find more information on this model on the [model card](https://huggingface.co/huggan/distill-ccld-wa). The student model training and this demo were done by [@gigant](https://huggingface.co/gigant). The teacher model was trained by [@johnowhitaker](https://huggingface.co/johnowhitaker)',
)
# queue() replaces the removed enable_queue launch flag; pass debug=True to
# launch() when debugging on Colab.
iface.queue().launch()