apolinario committed on
Commit c4c0d2b · 1 Parent(s): 1a9130d

Initial application

Files changed (1)
  1. app.py +137 -8
app.py CHANGED
@@ -1,11 +1,140 @@
  import gradio as gr
- import torch
-
- is_cuda = torch.cuda.is_available()
- def greet(name):
-     if is_cuda:
-         return "Hello cuda" + name + "!!"
      else:
-         return "Hello ooops" + name + "!!"
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()
+ from pydoc import describe
  import gradio as gr
+ import torch
+ from omegaconf import OmegaConf
+ import sys
+ sys.path.append(".")
+ sys.path.append('./taming-transformers')
+ sys.path.append('./latent-diffusion')
+ from taming.models import vqgan
+ from ldm.util import instantiate_from_config
+
+ torch.hub.download_url_to_file('https://ommer-lab.com/files/latent-diffusion/nitro/txt2img-f8-large/model.ckpt','txt2img-f8-large.ckpt')
+
+ #@title Import stuff
+ import argparse, os, sys, glob
+ import numpy as np
+ from PIL import Image
+ from einops import rearrange
+ from torchvision.utils import make_grid
+ import transformers
+ import gc
+ from ldm.util import instantiate_from_config
+ from ldm.models.diffusion.ddim import DDIMSampler
+ from ldm.models.diffusion.plms import PLMSSampler
+
+
+ def load_model_from_config(config, ckpt, verbose=False):
+     print(f"Loading model from {ckpt}")
+     pl_sd = torch.load(ckpt, map_location="cuda:0")
+     sd = pl_sd["state_dict"]
+     model = instantiate_from_config(config.model)
+     m, u = model.load_state_dict(sd, strict=False)
+     if len(m) > 0 and verbose:
+         print("missing keys:")
+         print(m)
+     if len(u) > 0 and verbose:
+         print("unexpected keys:")
+         print(u)
+
+     model = model.half().cuda()
+     model.eval()
+     return model
+
+ config = OmegaConf.load("latent-diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml")  # TODO: Optionally download from same location as ckpt and change this logic
+ model = load_model_from_config(config, "txt2img-f8-large.ckpt")  # checkpoint downloaded above
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ model = model.to(device)
+
+ def run(prompt, steps, width, height, images, scale, eta):
+     if images == 6:
+         images = 3
+         n_iter = 2
+     else:
+         n_iter = 1
+     opt = argparse.Namespace(
+         prompt=prompt,
+         outdir='latent-diffusion/outputs',
+         ddim_steps=int(steps),
+         ddim_eta=eta,
+         n_iter=n_iter,
+         W=int(width),
+         H=int(height),
+         n_samples=int(images),
+         scale=scale,
+         plms=True
+     )
+     if opt.plms:
+         opt.ddim_eta = 0
+         sampler = PLMSSampler(model)
      else:
+         sampler = DDIMSampler(model)
+
+     os.makedirs(opt.outdir, exist_ok=True)
+     outpath = opt.outdir
+
+     prompt = opt.prompt
+
+
+     sample_path = os.path.join(outpath, "samples")
+     os.makedirs(sample_path, exist_ok=True)
+     base_count = len(os.listdir(sample_path))
+
+     all_samples = list()
+     all_samples_images = list()
+     with torch.no_grad():
+         with torch.cuda.amp.autocast():
+             with model.ema_scope():
+                 uc = None
+                 if opt.scale > 0:
+                     uc = model.get_learned_conditioning(opt.n_samples * [""])
+                 for n in range(opt.n_iter):
+                     c = model.get_learned_conditioning(opt.n_samples * [prompt])
+                     shape = [4, opt.H//8, opt.W//8]
+                     samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
+                                                      conditioning=c,
+                                                      batch_size=opt.n_samples,
+                                                      shape=shape,
+                                                      verbose=False,
+                                                      unconditional_guidance_scale=opt.scale,
+                                                      unconditional_conditioning=uc,
+                                                      eta=opt.ddim_eta)
+
+                     x_samples_ddim = model.decode_first_stage(samples_ddim)
+                     x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0)
+
+                     for x_sample in x_samples_ddim:
+                         x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+                         all_samples_images.append(Image.fromarray(x_sample.astype(np.uint8)))
+                         # Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(sample_path, f"{base_count:04}.png"))
+                         base_count += 1
+                     all_samples.append(x_samples_ddim)
+
+
+     # additionally, save as grid
+     grid = torch.stack(all_samples, 0)
+     grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+     grid = make_grid(grid, nrow=2)
+     # to image
+     grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
+
+     Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'{prompt.replace(" ", "-")}.png'))
+     return(Image.fromarray(grid.astype(np.uint8)), all_samples_images)
+
+ image = gr.outputs.Image(type="pil", label="Your result")
+ css = ".output-image{height: 528px !important} .output-carousel .output-image{height:272px !important}"
+ iface = gr.Interface(fn=run, inputs=[
+     gr.inputs.Textbox(label="Prompt", default="A drawing of a cute dog with a funny hat"),
+     gr.inputs.Slider(label="Steps - more steps can increase quality but will take longer to generate", default=50, maximum=250, minimum=1, step=1),
+     gr.inputs.Slider(label="Width", minimum=64, maximum=256, default=256, step=64),
+     gr.inputs.Slider(label="Height", minimum=64, maximum=256, default=256, step=64),
+     gr.inputs.Slider(label="Images - How many images you wish to generate", default=4, step=2, minimum=2, maximum=6),
+     gr.inputs.Slider(label="Diversity scale - How different from one another you wish the images to be", default=5.0, minimum=1),
+     gr.inputs.Slider(label="ETA - between 0 and 1. Lower values can provide better quality, higher values can be more diverse", default=0.0, minimum=0.0, maximum=1.0, step=0.1),
+
+     ],
+     outputs=[image, gr.outputs.Carousel(label="Individual images", components=["image"])],
+     css=css,
+     title="Generate images from text with Latent Diffusion LAION-400M",
+     description="<div>By typing a text and clicking submit you can generate images based on this text. This is a text-to-image model created by CompVis, trained on the LAION-400M dataset.<br>For more multimodal ai art check us out <a style='color: rgb(245, 158, 11);font-weight:bold' href='https://twitter.com/multimodalart' target='_blank'>@multimodalart</a></div>")
+ iface.launch(enable_queue=True)
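
Note on running this version of app.py: the script assumes the taming-transformers and latent-diffusion repositories are already checked out next to it (it appends ./taming-transformers and ./latent-diffusion to sys.path and reads the config from latent-diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml), that the packages it imports (gradio, torch, torchvision, omegaconf, einops, transformers, Pillow, numpy) are installed, and that a CUDA GPU is available, since the checkpoint is loaded with map_location="cuda:0" and the model is moved to .half().cuda(). A minimal sketch of that setup is below; the GitHub clone URLs are an assumption for illustration, not something recorded in this commit.

# Sketch only: prepare the working directory the committed app.py expects,
# then start the Space locally. The clone URLs are assumed upstream repos;
# swap in whichever forks the Space actually bundles.
import subprocess

for repo in (
    "https://github.com/CompVis/taming-transformers",  # provides taming.models.vqgan
    "https://github.com/CompVis/latent-diffusion",     # provides ldm.* and the txt2img config
):
    subprocess.run(["git", "clone", repo], check=True)

# app.py downloads the txt2img-f8-large checkpoint itself via
# torch.hub.download_url_to_file and then launches the Gradio interface.
subprocess.run(["python", "app.py"], check=True)

Once launched, Gradio passes the slider values to run(prompt, steps, width, height, images, scale, eta), which returns the saved grid image plus the individual samples shown in the carousel.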