Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -7,9 +7,8 @@ subprocess.run(
     shell=True,
 )
 
-from huggingface_hub import snapshot_download
-
 os.makedirs("/home/user/app/checkpoints", exist_ok=True)
+from huggingface_hub import snapshot_download
 snapshot_download(
     repo_id="Alpha-VLLM/Lumina-Next-T2I", local_dir="/home/user/app/checkpoints"
 )
@@ -32,8 +31,7 @@ import torch.distributed as dist
 from torchvision.transforms.functional import to_pil_image
 
 from PIL import Image
-from
-from threading import Thread, Barrier
+from safetensors.torch import load_file
 
 import models
 
@@ -50,7 +48,6 @@ description = """
 #### Demo current model: `Lumina-Next-T2I`
 
 """
-
 hf_token = os.environ["HF_TOKEN"]
 
 
@@ -161,12 +158,11 @@ def load_models(args, master_port, rank):
     assert train_args.model_parallel_size == args.num_gpus
     if args.ema:
         print("Loading ema model.")
-    ckpt =
+    ckpt = load_file(
         os.path.join(
             args.ckpt,
-            f"consolidated{'_ema' if args.ema else ''}.{rank:02d}-of-{args.num_gpus:02d}.
+            f"consolidated{'_ema' if args.ema else ''}.{rank:02d}-of-{args.num_gpus:02d}.safetensors",
         ),
-        map_location="cpu",
     )
     model.load_state_dict(ckpt, strict=True)
 
@@ -179,17 +175,15 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
         args.precision
     ]
     train_args = torch.load(os.path.join(args.ckpt, "model_args.pth"))
-
-    print(args)
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     torch.cuda.set_device(0)
-
+
     # loading model to gpu
     # text_encoder = text_encoder.cuda()
     # vae = vae.cuda()
     # model = model.to("cuda", dtype=dtype)
 
-    with torch.autocast(
+    with torch.autocast(device, dtype):
         (
             cap,
             resolution,
@@ -202,18 +196,19 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
             proportional_attn,
         ) = infer_args
 
-
-
-
-
-
-
-
-
-
-
-            proportional_attn,
+        metadata = dict(
+            cap=cap,
+            resolution=resolution,
+            num_sampling_steps=num_sampling_steps,
+            cfg_scale=cfg_scale,
+            solver=solver,
+            t_shift=t_shift,
+            seed=seed,
+            ntk_scaling=ntk_scaling,
+            proportional_attn=proportional_attn,
         )
+        print("> params:", json.dumps(metadata, indent=2))
+
         try:
             # begin sampler
             transport = create_transport(
@@ -249,7 +244,7 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
             latent_w, latent_h = w // 8, h // 8
             if int(seed) != 0:
                 torch.random.manual_seed(int(seed))
-            z = torch.randn([1, 4, latent_h, latent_w], device=
+            z = torch.randn([1, 4, latent_h, latent_w], device=device).to(dtype)
             z = z.repeat(2, 1, 1, 1)
 
             with torch.no_grad():
@@ -276,13 +271,8 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
                 ntk_factor=ntk_factor,
             )
 
-            print(
-
-            print(f"cfg_scale: {cfg_scale}")
-
-            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
-                print("> [debug] start sample")
-                samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]
+            print("> start sample")
+            samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]
             samples = samples[:1]
 
             factor = 0.18215 if train_args.vae != "sdxl" else 0.13025
@@ -294,7 +284,7 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
 
             img = to_pil_image(samples[0].float())
 
-            return img
+            return img, metadata
         except Exception:
             print(traceback.format_exc())
             return ModelFailure()
@@ -505,18 +495,15 @@ def main():
             )
             with gr.Row():
                 submit_btn = gr.Button("Submit", variant="primary")
-                # reset_btn = gr.ClearButton([
-                #     cap, resolution,
-                #     num_sampling_steps, cfg_scale, solver,
-                #     t_shift, seed,
-                #     ntk_scaling, proportional_attn
-                # ])
         with gr.Column():
             output_img = gr.Image(
                 label="Lumina Generated image",
                 interactive=False,
                 format="png",
+                show_label=False
             )
+            with gr.Accordion(label="Generation Parameters", open=False):
+                gr_metadata = gr.JSON(label="metadata", show_label=False)
 
         with gr.Row():
             gr.Examples(
@@ -582,8 +569,8 @@ def main():
                 examples_per_page=22,
             )
 
-        @spaces.GPU(duration=
-        def on_submit(*infer_args):
+        @spaces.GPU(duration=80)
+        def on_submit(*infer_args, progress=gr.Progress(track_tqdm=True),):
             result = infer_ode(args, infer_args, text_encoder, tokenizer, vae, model)
             if isinstance(result, ModelFailure):
                 raise RuntimeError("Model failed to generate the image.")
@@ -602,10 +589,10 @@ def main():
                 ntk_scaling,
                 proportional_attn,
             ],
-            [output_img],
+            [output_img, gr_metadata],
         )
 
-    demo.queue(
+    demo.queue().launch()
 
 
 if __name__ == "__main__":
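
Note on the checkpoint change above: the commit replaces the old torch-based checkpoint load with safetensors' load_file and switches the shard suffix to .safetensors. A minimal sketch of that loading pattern, assuming a single model-parallel shard; the helper name and example paths below are illustrative, not part of app.py:

# Sketch only: mirrors the load_file(...) pattern from the diff above.
# load_ema_checkpoint, its arguments, and the example paths are hypothetical.
import os
from safetensors.torch import load_file  # returns a plain dict[str, torch.Tensor]

def load_ema_checkpoint(ckpt_dir: str, rank: int, num_gpus: int, ema: bool = True) -> dict:
    # Same shard naming scheme as the diff: consolidated_ema.00-of-01.safetensors, etc.
    shard = f"consolidated{'_ema' if ema else ''}.{rank:02d}-of-{num_gpus:02d}.safetensors"
    return load_file(os.path.join(ckpt_dir, shard))  # tensors are loaded on CPU by default

# Example usage (hypothetical values):
# state_dict = load_ema_checkpoint("/home/user/app/checkpoints", rank=0, num_gpus=1)
# model.load_state_dict(state_dict, strict=True)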
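
The @spaces.GPU(duration=80) decorator and demo.queue().launch() lines follow the usual ZeroGPU pattern: the Space holds no GPU at rest, and one is attached only while the decorated callback runs. A minimal, self-contained sketch of that wiring; the generate function and its body are placeholders, not the app's real inference code:

import gradio as gr
import spaces  # provided on Hugging Face ZeroGPU Spaces
import torch

@spaces.GPU(duration=80)  # request a GPU allocation of up to 80 seconds per call
def generate(prompt: str, progress=gr.Progress(track_tqdm=True)):
    # Placeholder body; the real app calls infer_ode(...) and returns (image, metadata).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return f"ran on {device}", {"prompt": prompt}

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    result = gr.Textbox(label="Result")  # the real app uses gr.Image(...)
    with gr.Accordion(label="Generation Parameters", open=False):
        meta = gr.JSON(show_label=False)  # mirrors gr_metadata in the diff
    gr.Button("Submit", variant="primary").click(generate, [prompt], [result, meta])

if __name__ == "__main__":
    demo.queue().launch()

As in the diff, any CUDA work has to happen inside the decorated function, since the GPU is only attached for its duration.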