denisp1 committed on
Commit
5c87bc4
·
1 Parent(s): 10507e1

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -201
app.py DELETED
@@ -1,201 +0,0 @@
1
- import os
2
-
# Fetch the sampling code and install its Python dependencies at start-up.
import subprocess
import sys

_REPO_DIR = "cloob-latent-diffusion"

# `git clone` refuses to write into an existing checkout, so only clone when
# the directory is missing (e.g. on the first start of the process).
if not os.path.isdir(_REPO_DIR):
    subprocess.run(
        ["git", "clone", "--recursive",
         "https://github.com/JD-P/cloob-latent-diffusion"],
        check=True,
    )

# Argument lists avoid going through a shell, and check=True makes a failed
# install stop here instead of surfacing later as an unrelated ImportError.
# `sys.executable -m pip` installs into the interpreter that runs this script.
subprocess.run(
    [sys.executable, "-m", "pip", "install",
     "omegaconf", "pillow", "pytorch-lightning", "einops", "wandb", "ftfy",
     "regex", "./CLIP"],
    cwd=_REPO_DIR,
    check=True,
)
5
-
6
- import argparse
7
- from functools import partial
8
- from pathlib import Path
9
- import sys
10
- sys.path.append('./cloob-latent-diffusion')
11
- sys.path.append('./cloob-latent-diffusion/cloob-training')
12
- sys.path.append('./cloob-latent-diffusion/latent-diffusion')
13
- sys.path.append('./cloob-latent-diffusion/taming-transformers')
14
- sys.path.append('./cloob-latent-diffusion/v-diffusion-pytorch')
15
- from omegaconf import OmegaConf
16
- from PIL import Image
17
- import torch
18
- from torch import nn
19
- from torch.nn import functional as F
20
- from torchvision import transforms
21
- from torchvision.transforms import functional as TF
22
- from tqdm import trange
23
- from CLIP import clip
24
- from cloob_training import model_pt, pretrained
25
- import ldm.models.autoencoder
26
- from diffusion import sampling, utils
27
- import train_latent_diffusion as train
28
- from huggingface_hub import hf_hub_url, cached_download
29
- import random
30
-
# Download the model files
# `cached_download(hf_hub_url(...))` is the legacy huggingface_hub download
# path; `hf_hub_download` is its documented replacement and returns the local
# path of the cached file in the same way.
from huggingface_hub import hf_hub_download

checkpoint = hf_hub_download("huggan/distill-ccld-wa", filename="model_student.ckpt")
ae_model_path = hf_hub_download("huggan/ccld_wa", filename="ae_model.ckpt")
ae_config_path = hf_hub_download("huggan/ccld_wa", filename="ae_model.yaml")
35
-
36
- # Define a few utility functions
37
-
def parse_prompt(prompt, default_weight=3.):
    """Split a ``"<prompt>:<weight>"`` string into ``(prompt, weight)``.

    The weight is whatever follows the last colon, provided that text parses
    as a float. Otherwise the whole string is the prompt and
    ``default_weight`` is used. This keeps URLs without a weight
    (``https://host/a.png``), URLs with a port, and ordinary text containing
    a colon (``"Still life: fruit"``) from raising.

    Args:
        prompt: Text prompt, URL or path, optionally suffixed with ``:<weight>``.
        default_weight: Weight used when no numeric suffix is present.

    Returns:
        A ``(prompt_without_weight, weight_as_float)`` tuple.
    """
    head, sep, tail = prompt.rpartition(':')
    try:
        weight = float(tail) if sep else None
    except ValueError:
        # The text after the last colon is not a number, so the colon belongs
        # to the prompt itself (URL scheme, port, or punctuation).
        weight = None
    if weight is None:
        return prompt, float(default_weight)
    return head, weight
46
-
47
-
def resize_and_center_crop(image, size):
    """Scale ``image`` so it covers ``size``, then crop the centre to ``size``.

    Args:
        image: Object exposing ``.size`` and ``.resize`` (a PIL image at the
            call site in this file).
        size: Target size, indexed in the same order as ``image.size``; it is
            reversed before being handed to ``TF.center_crop``.

    Returns:
        The result of ``TF.center_crop`` on the resized image.
    """
    # Use the larger of the two ratios so both sides end up >= the target.
    fac = max(size[0] / image.size[0], size[1] / image.size[1])
    # round() instead of int(): (a / b) * b can land just below `a` in floating
    # point, and truncating would leave the image one pixel short of the
    # target, making center_crop pad instead of crop.
    scaled = (round(fac * image.size[0]), round(fac * image.size[1]))
    image = image.resize(scaled, Image.LANCZOS)
    return TF.center_crop(image, size[::-1])
52
-
53
-
# Load the models
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print('loading models')

# autoencoder
ae_config = OmegaConf.load(ae_config_path)
ae_model = ldm.models.autoencoder.AutoencoderKL(**ae_config.model.params)
# Load the weights on CPU first, as the diffusion model below does, so the
# checkpoint can be read on a host without CUDA; then move to `device`.
ae_model.load_state_dict(torch.load(ae_model_path, map_location='cpu'))
ae_model.eval().requires_grad_(False).to(device)
# Shape (channels, height, width) of the latent noise drawn in generate().
n_ch, side_y, side_x = 4, 32, 32

# diffusion model
model = train.DiffusionModel(192, [1,1,2,2], autoencoder_scale=torch.tensor(4.3084))
model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
model = model.to(device).eval().requires_grad_(False)

# CLOOB
cloob_config = pretrained.get_config('cloob_laion_400m_vit_b_16_16_epochs')
cloob = model_pt.get_pt_model(cloob_config)
# NOTE: rebinds `checkpoint`; the diffusion checkpoint path is no longer
# needed once `model` has been loaded above.
checkpoint = pretrained.download_checkpoint(cloob_config)
cloob.load_state_dict(model_pt.get_pt_params(cloob_config, checkpoint))
cloob.eval().requires_grad_(False).to(device)
77
-
78
-
# The key function: returns a list of n PIL images
def generate(n=1, prompts=None, images=None, seed=42, steps=15,
             method='plms', eta=None):
    """Sample ``n`` images conditioned on weighted text and image prompts.

    Args:
        n: Number of images to sample; ``n <= 0`` returns an empty list.
        prompts: Text prompts, each optionally suffixed with ``:<weight>``
            (see ``parse_prompt``). Defaults to ``['a red circle']``.
        images: Image prompts (anything ``utils.fetch`` accepts), each
            optionally suffixed with ``:<weight>``. Defaults to no images.
        seed: Value passed to ``torch.manual_seed`` before drawing the noise.
        steps: Number of timesteps in the sampling schedule.
        method: One of ``'ddpm'``, ``'ddim'``, ``'prk'``, ``'plms'``,
            ``'pie'``, ``'plms2'``.
        eta: Forwarded to ``sampling.sample`` when ``method == 'ddim'``.
            NOTE(review): the default ``None`` is passed through unchanged;
            confirm that the sampler accepts it.

    Returns:
        A list of images produced by ``utils.to_pil_image``, one per sample.

    Raises:
        ValueError: If ``method`` is not one of the supported samplers.
    """
    # method -> (function name in `sampling`, extra positional args that go
    # between the step schedule and the trailing extra-args dict).
    sampler_specs = {
        'ddpm': ('sample', (1.,)),
        'ddim': ('sample', (eta,)),
        'prk': ('prk_sample', ()),
        'plms': ('plms_sample', ()),
        'pie': ('pie_sample', ()),
        'plms2': ('plms2_sample', ()),
    }
    # Validate up front with a real exception: an `assert` would be stripped
    # under `python -O`, and failing late wastes the prompt-encoding work.
    if method not in sampler_specs:
        raise ValueError(
            f"unknown sampling method {method!r}; "
            f"expected one of {sorted(sampler_specs)}"
        )
    if n <= 0:
        return []

    prompts = ['a red circle'] if prompts is None else prompts
    images = [] if images is None else images

    # The first embedding is the all-zero one; it receives the weight left
    # over after the prompt weights are subtracted from 1 (see `weights`).
    zero_embed = torch.zeros([1, cloob.config['d_embed']], device=device)
    target_embeds, weights = [zero_embed], []

    for prompt in prompts:
        txt, weight = parse_prompt(prompt)
        target_embeds.append(cloob.text_encoder(cloob.tokenize(txt).to(device)).float())
        weights.append(weight)

    for prompt in images:
        path, weight = parse_prompt(prompt)
        img = Image.open(utils.fetch(path)).convert('RGB')
        clip_size = cloob.config['image_encoder']['image_size']
        img = resize_and_center_crop(img, (clip_size, clip_size))
        batch = TF.to_tensor(img)[None].to(device)
        embed = F.normalize(cloob.image_encoder(cloob.normalize(batch)).float(), dim=-1)
        target_embeds.append(embed)
        weights.append(weight)

    weights = torch.tensor([1 - sum(weights), *weights], device=device)
    # Invariant across sampler steps, so concatenate once instead of per call.
    cond_embeds = torch.cat(target_embeds)

    torch.manual_seed(seed)

    def cfg_model_fn(x, t):
        """Run the model once per condition and blend the outputs by weight."""
        n_samples = x.shape[0]
        n_conds = len(target_embeds)
        x_in = x.repeat([n_conds, 1, 1, 1])
        t_in = t.repeat([n_conds])
        clip_embed_in = cond_embeds.repeat_interleave(n_samples, 0)
        vs = model(x_in, t_in, clip_embed_in).view([n_conds, n_samples, *x.shape[1:]])
        return vs.mul(weights[:, None, None, None, None]).sum(0)

    def run(x, schedule):
        """Dispatch to the sampler selected by ``method``."""
        fn_name, extra = sampler_specs[method]
        return getattr(sampling, fn_name)(cfg_model_fn, x, schedule, *extra, {})

    # All samples are drawn as a single batch.
    batch_size = n
    x = torch.randn([n, n_ch, side_y, side_x], device=device)
    t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
    schedule = utils.get_spliced_ddpm_cosine_schedule(t)
    latent_scale = torch.tensor(2.55).to(device)
    pil_ims = []
    for i in trange(0, n, batch_size):
        cur_batch_size = min(n - i, batch_size)
        out_latents = run(x[i:i + cur_batch_size], schedule)
        outs = ae_model.decode(out_latents * latent_scale)
        pil_ims.extend(utils.to_pil_image(out) for out in outs)

    return pil_ims
142
-
143
-
144
- import gradio as gr
145
-
def gen_ims(prompt, im_prompt=None, seed=None, n_steps=10, method='plms'):
    """Generate a single image for the Gradio interface.

    Args:
        prompt: Text prompt, passed to ``generate`` as the only text prompt.
        im_prompt: Optional image prompt; omitted from the call when ``None``.
        seed: Seed for ``generate``; a random integer in ``[0, 10000]`` is
            drawn when ``None``.
        n_steps: Forwarded to ``generate`` as ``steps``.
        method: Forwarded to ``generate`` as ``method``.

    Returns:
        The first (and only) image returned by ``generate``.
    """
    # Identity checks: `== None` can be hijacked by objects overriding __eq__.
    if seed is None:
        seed = random.randint(0, 10000)
    print(prompt, im_prompt, seed, n_steps)
    im_prompts = [] if im_prompt is None else [im_prompt]
    pil_ims = generate(n=1, prompts=[prompt], images=im_prompts, seed=seed,
                       steps=n_steps, method=method)
    return pil_ims[0]
156
-
# Gradio UI: one text prompt plus an optional image prompt. The seed, step
# count and sampling method are not exposed, so gen_ims() uses its defaults.
# NOTE(review): `gr.inputs` / `gr.outputs` and `launch(enable_queue=...)` are
# legacy Gradio APIs; confirm against the pinned Gradio version before upgrading.
iface = gr.Interface(
    fn=gen_ims,
    inputs=[
        gr.inputs.Textbox(label="Text prompt"),
        gr.inputs.Image(optional=True, label="Image prompt", type='filepath'),
    ],
    outputs=[gr.outputs.Image(type="pil", label="Generated Image")],
    # Each example fills only the text prompt. The original list contained
    # "Futurism, in the style of Wassily Kandinsky" twice; it is kept once.
    examples=[
        ["Impressionism, oil on canvas"],
        ["Futurism, in the style of Wassily Kandinsky"],
        ["Art Nouveau, in the style of John Singer Sargent"],
        ["Surrealism, in the style of Edgar Degas"],
        ["Expressionism, in the style of Wassily Kandinsky"],
        ["Futurism, in the style of Egon Schiele"],
        ["Neoclassicism, in the style of Gustav Klimt"],
        ["Cubism, in the style of Gustav Klimt"],
        ["Op Art, in the style of Marc Chagall"],
        ["Romanticism, in the style of M.C. Escher"],
        ["Futurism, in the style of M.C. Escher"],
        ["Abstract Art, in the style of M.C. Escher"],
        ["Mannerism, in the style of Paul Klee"],
        ["Romanesque Art, in the style of Leonardo da Vinci"],
        ["High Renaissance, in the style of Rembrandt"],
        ["Magic Realism, in the style of Gustave Dore"],
        ["Realism, in the style of Jean-Michel Basquiat"],
        ["Art Nouveau, in the style of Paul Gauguin"],
        ["Avant-garde, in the style of Pierre-Auguste Renoir"],
        ["Baroque, in the style of Edward Hopper"],
        ["Post-Impressionism, in the style of Wassily Kandinsky"],
        ["Naturalism, in the style of Rene Magritte"],
        ["Constructivism, in the style of Paul Cezanne"],
        ["Abstract Expressionism, in the style of Henri Matisse"],
        ["Pop Art, in the style of Vincent van Gogh"],
        ["Futurism, in the style of Zdzislaw Beksinski"],
        ['Surrealism, in the style of Salvador Dali'],
        ["Aaron Wacker, oil on canvas"],
    ],
    title='Art Generator and Style Mixer from 🧠 Cloob and 🎨 WikiArt - Visual Art Encyclopedia:',
    description="Trained on images from the [WikiArt](https://www.wikiart.org/) dataset, comprised of visual arts",
    article = 'Model used is: [model card](https://huggingface.co/huggan/distill-ccld-wa)..',
)
iface.launch(enable_queue=True) # , debug=True for colab debugging