Spaces:
Runtime error
Runtime error
Commit
·
755e849
1
Parent(s):
b2f2472
Request to update and reboot the app 5.31.22
Browse files
app.py
CHANGED
@@ -1,175 +1,4 @@
|
|
1 |
-
|
2 |
-
import streamlit as st
|
3 |
-
|
4 |
-
st.title("
|
5 |
-
stapp.py --installer cloob_latent_diffusion=https://github.com/JD-P/cloob-latent-diffusion/archive/master.zip --installer megatron=https://github.com/NVIDIA/MegatronLM/archive/master.zip --installer huggingface=https://github.com/huggingface/transformers
|
6 |
-
# `os` is used below but was never imported in this file; import it here so
# the bootstrap step does not fail with NameError.
import os

# Fetch the CLOOB latent diffusion sources together with their submodules.
# (The original issued this clone twice; the second call could only fail
# because the target directory already existed.)
os.system("git clone --recursive https://github.com/JD-P/cloob-latent-diffusion")
# Install the Python dependencies, including the CLIP package vendored in the
# cloned repository (hence the `cd` before `pip install ... ./CLIP`).
os.system("cd cloob-latent-diffusion;pip install omegaconf pillow pytorch-lightning einops wandb ftfy regex ./CLIP")
|
10 |
-
|
11 |
-
import argparse
|
12 |
-
from functools import partial
|
13 |
-
from pathlib import Path
|
14 |
-
import sys
|
15 |
-
sys.path.append('./cloob-latent-diffusion')
|
16 |
-
sys.path.append('./cloob-latent-diffusion/cloob-training')
|
17 |
-
sys.path.append('./cloob-latent-diffusion/latent-diffusion')
|
18 |
-
sys.path.append('./cloob-latent-diffusion/taming-transformers')
|
19 |
-
sys.path.append('./cloob-latent-diffusion/v-diffusion-pytorch')
|
20 |
-
from omegaconf import OmegaConf
|
21 |
-
from PIL import Image
|
22 |
-
import torch
|
23 |
-
from torch import nn
|
24 |
-
from torch.nn import functional as F
|
25 |
-
from torchvision import transforms
|
26 |
-
from torchvision.transforms import functional as TF
|
27 |
-
from tqdm import trange
|
28 |
-
from CLIP import clip
|
29 |
-
from cloob_training import model_pt, pretrained
|
30 |
-
import ldm.models.autoencoder
|
31 |
-
from diffusion import sampling, utils
|
32 |
-
import train_latent_diffusion as train
|
33 |
-
from huggingface_hub import hf_hub_url, cached_download
|
34 |
-
import random
|
35 |
-
|
36 |
-
# Download the model files from the Hugging Face Hub.
# `cached_download(hf_hub_url(...))` is deprecated in huggingface_hub;
# `hf_hub_download` resolves the same repo file and returns its local path.
from huggingface_hub import hf_hub_download

# Distilled (student) diffusion model weights.
checkpoint = hf_hub_download("huggan/distill-ccld-wa", filename="model_student.ckpt")
# Autoencoder weights and the YAML config describing its architecture.
ae_model_path = hf_hub_download("huggan/ccld_wa", filename="ae_model.ckpt")
ae_config_path = hf_hub_download("huggan/ccld_wa", filename="ae_model.yaml")
|
40 |
-
|
41 |
-
# Define a few utility functions
|
42 |
-
|
43 |
-
def parse_prompt(prompt, default_weight=3.):
    """Split a ``"text:weight"`` prompt into its text and numeric weight.

    The weight is whatever follows the last ``:`` when that suffix parses as a
    float. Otherwise the whole string is the prompt and ``default_weight`` is
    used, so URLs (``https://...``) and prompts that merely contain a colon
    (``"Still life: flowers"``) are returned intact instead of raising
    ``ValueError``.

    Args:
        prompt: Prompt text or URL, optionally suffixed with ``:weight``.
        default_weight: Weight used when no numeric suffix is present.

    Returns:
        A ``(prompt, weight)`` tuple with ``weight`` as a float.
    """
    text, colon, suffix = prompt.rpartition(':')
    if colon:
        try:
            return text, float(suffix)
        except ValueError:
            # Not a weight: the colon belongs to the prompt itself
            # (URL scheme, port, or ordinary punctuation).
            pass
    return prompt, float(default_weight)
|
51 |
-
|
52 |
-
|
53 |
-
def resize_and_center_crop(image, size):
    """Scale ``image`` to cover ``size`` and crop the centre to that size.

    Args:
        image: Image exposing ``.size`` as ``(width, height)`` and ``.resize``.
        size: Target ``(width, height)``.

    Returns:
        The result of ``TF.center_crop`` on the resized image.
    """
    width, height = image.size
    # Use the larger of the two ratios so both target dimensions are covered.
    scale = max(size[0] / width, size[1] / height)
    resized = image.resize((int(scale * width), int(scale * height)), Image.LANCZOS)
    # center_crop takes (height, width), hence the reversed size.
    return TF.center_crop(resized, size[::-1])
|
57 |
-
|
58 |
-
|
59 |
-
# Load the models
|
60 |
-
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print('loading models')

# Autoencoder: built from the downloaded YAML config, then weights are loaded.
ae_config = OmegaConf.load(ae_config_path)
ae_model = ldm.models.autoencoder.AutoencoderKL(**ae_config.model.params)
ae_model.eval().requires_grad_(False).to(device)
# map_location='cpu' (as for the diffusion checkpoint below) lets a checkpoint
# saved from a GPU deserialize on a CPU-only host; load_state_dict then copies
# the values into the parameters already placed on `device`.
ae_model.load_state_dict(torch.load(ae_model_path, map_location='cpu'))
# Per-sample dimensions of the initial noise tensor drawn in generate().
n_ch, side_y, side_x = 4, 32, 32

# Diffusion model
model = train.DiffusionModel(192, [1,1,2,2], autoencoder_scale=torch.tensor(4.3084))
model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
model = model.to(device).eval().requires_grad_(False)

# CLOOB
cloob_config = pretrained.get_config('cloob_laion_400m_vit_b_16_16_epochs')
cloob = model_pt.get_pt_model(cloob_config)
# NOTE: this rebinds `checkpoint`, which until here held the diffusion model
# checkpoint path; the diffusion weights have already been loaded above.
checkpoint = pretrained.download_checkpoint(cloob_config)
cloob.load_state_dict(model_pt.get_pt_params(cloob_config, checkpoint))
cloob.eval().requires_grad_(False).to(device)
|
81 |
-
|
82 |
-
|
83 |
-
# The key function: returns a list of n PIL images
|
84 |
-
def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15,
             method='plms', eta=None):
    """Sample ``n`` images conditioned on text and/or image prompts.

    Args:
        n: Number of images to sample (all drawn in a single batch).
        prompts: Text prompts, each optionally suffixed with ``:weight``
            (parsed by ``parse_prompt``).
        images: Image prompts (anything ``utils.fetch`` accepts), each
            optionally suffixed with ``:weight``.
        seed: Value passed to ``torch.manual_seed`` before drawing the noise.
        steps: Number of sampling timesteps.
        method: One of 'ddpm', 'ddim', 'prk', 'plms', 'pie', 'plms2'.
        eta: Forwarded to ``sampling.sample`` when ``method == 'ddim'``.

    Returns:
        A list of PIL images, one per sample.

    Raises:
        ValueError: If ``method`` is not one of the supported samplers.
    """
    # The list defaults above are shared between calls; they are only
    # iterated here, never mutated.

    # The first embedding is all zeros; it receives the weight left over
    # after the prompt weights (see `weights` below).
    zero_embed = torch.zeros([1, cloob.config['d_embed']], device=device)
    target_embeds, weights = [zero_embed], []

    for prompt in prompts:
        txt, weight = parse_prompt(prompt)
        target_embeds.append(cloob.text_encoder(cloob.tokenize(txt).to(device)).float())
        weights.append(weight)

    for prompt in images:
        path, weight = parse_prompt(prompt)
        img = Image.open(utils.fetch(path)).convert('RGB')
        clip_size = cloob.config['image_encoder']['image_size']
        img = resize_and_center_crop(img, (clip_size, clip_size))
        batch = TF.to_tensor(img)[None].to(device)
        embed = F.normalize(cloob.image_encoder(cloob.normalize(batch)).float(), dim=-1)
        target_embeds.append(embed)
        weights.append(weight)

    # Weight of the zero embedding is chosen so that all weights sum to 1.
    weights = torch.tensor([1 - sum(weights), *weights], device=device)

    torch.manual_seed(seed)

    def cfg_model_fn(x, t):
        # Evaluate the model once per conditioning embedding (in one stacked
        # batch) and return the weighted sum of the outputs.
        n_samples = x.shape[0]
        n_conds = len(target_embeds)
        x_in = x.repeat([n_conds, 1, 1, 1])
        t_in = t.repeat([n_conds])
        clip_embed_in = torch.cat([*target_embeds]).repeat_interleave(n_samples, 0)
        vs = model(x_in, t_in, clip_embed_in).view([n_conds, n_samples, *x.shape[1:]])
        v = vs.mul(weights[:, None, None, None, None]).sum(0)
        return v

    def run(x, steps):
        if method == 'ddpm':
            return sampling.sample(cfg_model_fn, x, steps, 1., {})
        if method == 'ddim':
            return sampling.sample(cfg_model_fn, x, steps, eta, {})
        if method == 'prk':
            return sampling.prk_sample(cfg_model_fn, x, steps, {})
        if method == 'plms':
            return sampling.plms_sample(cfg_model_fn, x, steps, {})
        if method == 'pie':
            return sampling.pie_sample(cfg_model_fn, x, steps, {})
        if method == 'plms2':
            return sampling.plms2_sample(cfg_model_fn, x, steps, {})
        # An explicit exception (rather than `assert False`) still fires when
        # Python runs with -O, where asserts are stripped.
        raise ValueError('unknown sampling method: {!r}'.format(method))

    batch_size = n
    x = torch.randn([n, n_ch, side_y, side_x], device=device)
    t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
    steps = utils.get_spliced_ddpm_cosine_schedule(t)
    pil_ims = []
    for i in trange(0, n, batch_size):
        cur_batch_size = min(n - i, batch_size)
        out_latents = run(x[i:i+cur_batch_size], steps)
        # Rescale the latents before decoding them with the autoencoder.
        outs = ae_model.decode(out_latents * torch.tensor(2.55).to(device))
        for out in outs:
            pil_ims.append(utils.to_pil_image(out))

    return pil_ims
|
146 |
-
|
147 |
-
|
148 |
-
import gradio as gr
|
149 |
-
|
150 |
-
def gen_ims(prompt, im_prompt=None, seed=None, n_steps=10, method='plms'):
    """Generate a single image for the web interface.

    Args:
        prompt: Text prompt.
        im_prompt: Optional image prompt; omitted from generation when None.
        seed: Random seed; when None, one is drawn from ``[0, 10000]``.
        n_steps: Number of sampling steps passed to ``generate``.
        method: Sampler name passed to ``generate``.

    Returns:
        The first (and only) image returned by ``generate``.
    """
    if seed is None:
        seed = random.randint(0, 10000)
    print(prompt, im_prompt, seed, n_steps)
    im_prompts = [] if im_prompt is None else [im_prompt]
    pil_ims = generate(n=1, prompts=[prompt], images=im_prompts, seed=seed,
                       steps=n_steps, method=method)
    return pil_ims[0]
|
160 |
-
|
161 |
-
# Web front end: a text prompt and an optional image prompt in, one image out.
# Sliders for image count, seed and step count are disabled, so gen_ims always
# runs with its default seed and n_steps.
iface = gr.Interface(
    fn=gen_ims,
    inputs=[
        gr.inputs.Textbox(label="Text prompt"),
        gr.inputs.Image(optional=True, label="Image prompt", type='filepath'),
    ],
    outputs=[gr.outputs.Image(type="pil", label="Generated Image")],
    examples=[
        ["An iceberg, oil on canvas"],
        ["A martian landscape, in the style of Monet"],
        ['A peaceful meadow, pastel crayons'],
        ["A painting of a vase of flowers"],
        ["A ship leaving the port in the summer, oil on canvas"],
    ],
    title='Generate art from text prompts :',
    description="By typing a text prompt or providing an image prompt, and pressing submit you can generate images based on this prompt. The model was trained on images from the [WikiArt](https://huggingface.co/datasets/huggan/wikiart) dataset, comprised mostly of paintings.",
    article='The model is a distilled version of a cloob-conditioned latent diffusion model fine-tuned on the WikiArt dataset. You can find more information on this model on the [model card](https://huggingface.co/huggan/distill-ccld-wa). The student model training and this demo were done by [@gigant](https://huggingface.co/gigant). The teacher model was trained by [@johnowhitaker](https://huggingface.co/johnowhitaker)',
)
# Launch with the request queue enabled (add debug=True when debugging in Colab).
iface.launch(enable_queue=True)
|
|
|
1 |
+
## make a streamlit app gui that says hello world
|
2 |
+
import streamlit as st
|
3 |
+
# `st.streamlit_version()` is not part of the Streamlit API and raises
# AttributeError at startup; the installed version is exposed as
# `st.__version__`.
st.write("Streamlit version:", st.__version__)
st.title("FLAMESTOPIA.AI SYSTEMS (C) 2022-20XX - Version 0.0.1")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|