LouisLi committed
Commit 2757a07 · verified · 1 Parent(s): 1c15d66

Update app.py

Files changed (1)
  1. app.py +201 -154
app.py CHANGED
@@ -34,13 +34,19 @@ import asyncio
 
- import uuid
- from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
- from diffusers.utils import export_to_video
- from safetensors.torch import load_file
  #from diffusers.models.modeling_outputs import Transformer2DModelOutput
 
- # import spaces #
 
  import imageio
@@ -292,82 +298,106 @@ def make3d(images):
  ###############################################################################
 
  ###############################################################################
- ############# this part is for text to video #############
  ###############################################################################
 
- MORE = """ ## TRY Other Models
- ### JARVIS: Your VOICE Assistant -> https://huggingface.co/spaces/KingNish/JARVIS
- ### Instant Image: 4k images in 5 Second -> https://huggingface.co/spaces/KingNish/Instant-Image
- """
-
- # Constants
- bases = {
-     "Cartoon": "frankjoshua/toonyou_beta6",
-     "Realistic": "emilianJR/epiCRealism",
-     "3d": "Lykon/DreamShaper",
-     "Anime": "Yntec/mistoonAnime2"
- }
- step_loaded = None
- base_loaded = "Realistic"
- motion_loaded = None
-
- # Ensure model and scheduler are initialized in GPU-enabled function
- if not torch.cuda.is_available():
-     raise NotImplementedError("No GPU detected!")
-
- device = "cuda"
- dtype = torch.float16
- pipe = AnimateDiffPipeline.from_pretrained(bases[base_loaded], torch_dtype=dtype).to(device)
- pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
-
- # Safety checkers
- from transformers import CLIPFeatureExtractor
-
- feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
-
- # Function
- #@spaces.GPU(duration=60,queue=False)
- def generate_image(prompt, base="Realistic", motion="", step=8, progress=gr.Progress()):
-     global step_loaded
-     global base_loaded
-     global motion_loaded
-     print(prompt, base, step)
-
-     if step_loaded != step:
-         repo = "ByteDance/AnimateDiff-Lightning"
-         ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
-         pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
-         step_loaded = step
-
-     if base_loaded != base:
-         pipe.unet.load_state_dict(torch.load(hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"), map_location=device), strict=False)
-         base_loaded = base
-
-     if motion_loaded != motion:
-         pipe.unload_lora_weights()
-         if motion != "":
-             pipe.load_lora_weights(motion, adapter_name="motion")
-             pipe.set_adapters(["motion"], [0.7])
-         motion_loaded = motion
-
-     progress((0, step))
-     def progress_callback(i, t, z):
-         progress((i+1, step))
-
-     output = pipe(prompt=prompt, guidance_scale=1.2, num_inference_steps=step, callback=progress_callback, callback_steps=1)
-
-     name = str(uuid.uuid4()).replace("-", "")
-     path = f"/tmp/{name}.mp4"
-     export_to_video(output.frames[0], path, fps=10)
-     return path
-
 
  ###############################################################################
- ############# above part is for text to video #############
  ###############################################################################
 
@@ -1353,99 +1383,116 @@ def create_ui():
  ###############################################################################
 
  ###############################################################################
- ############# this part is for text to video #############
  ###############################################################################
 
- with gr.Row(variant="panel") as text2video_model:
-     with gr.Column():
-         with gr.Row():
-             prompt = gr.Textbox(
-                 label='Prompt'
              )
          with gr.Row():
-             select_base = gr.Dropdown(
-                 label='Base model',
-                 choices=[
-                     "Cartoon",
-                     "Realistic",
-                     "3d",
-                     "Anime",
-                 ],
-                 value=base_loaded,
-                 interactive=True
-             )
-             select_motion = gr.Dropdown(
-                 label='Motion',
-                 choices=[
-                     ("Default", ""),
-                     ("Zoom in", "guoyww/animatediff-motion-lora-zoom-in"),
-                     ("Zoom out", "guoyww/animatediff-motion-lora-zoom-out"),
-                     ("Tilt up", "guoyww/animatediff-motion-lora-tilt-up"),
-                     ("Tilt down", "guoyww/animatediff-motion-lora-tilt-down"),
-                     ("Pan left", "guoyww/animatediff-motion-lora-pan-left"),
-                     ("Pan right", "guoyww/animatediff-motion-lora-pan-right"),
-                     ("Roll left", "guoyww/animatediff-motion-lora-rolling-anticlockwise"),
-                     ("Roll right", "guoyww/animatediff-motion-lora-rolling-clockwise"),
-                 ],
-                 value="guoyww/animatediff-motion-lora-zoom-in",
-                 interactive=True
-             )
-             select_step = gr.Dropdown(
-                 label='Inference steps',
-                 choices=[
-                     ('1-Step', 1),
-                     ('2-Step', 2),
-                     ('4-Step', 4),
-                     ('8-Step', 8),
-                 ],
-                 value=4,
-                 interactive=True
-             )
-             submit = gr.Button(
-                 scale=1,
-                 variant='primary'
              )
-     with gr.Column():
-         with gr.Row():
-             video = gr.Video(
-                 label='AnimateDiff-Lightning',
-                 autoplay=True,
-                 height=512,
-                 width=512,
-                 elem_id="video_output"
              )
-
-     prompt.submit(
-         fn=generate_image,
-         inputs=[prompt, select_base, select_motion, select_step],
-         outputs=video,
      )
-     submit.click(
-         fn=generate_image,
-         inputs=[prompt, select_base, select_motion, select_step],
-         outputs=video,
      )
 
-     gr.Examples(
-         examples=[
-             ["Focus: Eiffel Tower (Animate: Clouds moving)"], #Atmosphere Movement Example
-             ["Focus: Trees In forest (Animate: Lion running)"], #Object Movement Example
-             ["Focus: Astronaut in Space"], #Normal
-             ["Focus: Group of Birds in sky (Animate: Birds Moving) (Shot From distance)"], #Camera distance
-             ["Focus: Statue of liberty (Shot from Drone) (Animate: Drone coming toward statue)"], #Camera Movement
-             ["Focus: Panda in Forest (Animate: Drinking Tea)"], #Doing Something
-             ["Focus: Kids Playing (Season: Winter)"], #Atmosphere or Season
-             {"Focus: Cars in Street (Season: Rain, Daytime) (Shot from Distance) (Movement: Cars running)"} #Mixture
-         ],
-         fn=generate_image,
-         inputs=[prompt],
-         outputs=video,
-         cache_examples=True,
-     )
 
 
  ###############################################################################
- ############# above part is for text to video #############
  ###############################################################################
  def clear_tts_fields():
      return [gr.update(value=""), gr.update(value=""), None, None, gr.update(value=False), gr.update(value=True), None, None]
 
@@ -34,13 +34,19 @@ import asyncio
 
+ # import uuid
+ # from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
+ # from diffusers.utils import export_to_video
+ # from safetensors.torch import load_file
  #from diffusers.models.modeling_outputs import Transformer2DModelOutput
 
+
+ import random
+ import uuid
+ import json
+ from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
+
+
 
  import imageio
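The new imports above swap the AnimateDiff text-to-video stack for SDXL text-to-image. As a minimal, illustrative sketch of what they are used for (not part of app.py; it assumes a CUDA GPU and the same sd-community/sdxl-flash checkpoint that the rest of this commit loads):

import torch
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler

# Load the SDXL checkpoint and switch to the Euler-ancestral sampler,
# mirroring the setup added later in this commit.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "sd-community/sdxl-flash",
    torch_dtype=torch.float16,
    use_safetensors=True,
).to("cuda")
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

# Few-step generation, which the "flash" distillation is intended to allow.
image = pipe(
    "a cat eating a piece of cheese",
    num_inference_steps=8,
    guidance_scale=3.0,
).images[0]
image.save("cat.png")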
 
@@ -292,82 +298,106 @@ def make3d(images):
  ###############################################################################
 
  ###############################################################################
+ ############# this part is for text to image #############
  ###############################################################################
 
+ # Use environment variables for flexibility
+ MODEL_ID = os.getenv("MODEL_ID", "sd-community/sdxl-flash")
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
+ BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))  # Allow generating multiple images at once
+
+ # Determine device and load model outside of function for efficiency
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ pipe = StableDiffusionXLPipeline.from_pretrained(
+     MODEL_ID,
+     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+     use_safetensors=True,
+     add_watermarker=False,
+ ).to(device)
+ pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
+
+ # Torch compile for potential speedup (experimental)
+ if USE_TORCH_COMPILE:
+     pipe.compile()
+
+ # CPU offloading for larger RAM capacity (experimental)
+ if ENABLE_CPU_OFFLOAD:
+     pipe.enable_model_cpu_offload()
+
+ MAX_SEED = np.iinfo(np.int32).max
+
+ def save_image(img):
+     unique_name = str(uuid.uuid4()) + ".png"
+     img.save(unique_name)
+     return unique_name
+
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     return seed
+
+ # @spaces.GPU(duration=30, queue=False)
+ def generate(
+     prompt: str,
+     negative_prompt: str = "",
+     use_negative_prompt: bool = False,
+     seed: int = 1,
+     width: int = 1024,
+     height: int = 1024,
+     guidance_scale: float = 3,
+     num_inference_steps: int = 30,
+     randomize_seed: bool = False,
+     use_resolution_binning: bool = True,
+     num_images: int = 1,  # Number of images to generate
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     seed = int(randomize_seed_fn(seed, randomize_seed))
+     generator = torch.Generator(device=device).manual_seed(seed)
+
+     # Improved options handling
+     options = {
+         "prompt": [prompt] * num_images,
+         "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
+         "width": width,
+         "height": height,
+         "guidance_scale": guidance_scale,
+         "num_inference_steps": num_inference_steps,
+         "generator": generator,
+         "output_type": "pil",
+     }
 
+     # Use resolution binning for faster generation with less VRAM usage
+     if use_resolution_binning:
+         options["use_resolution_binning"] = True
+
+     # Generate images potentially in batches
+     images = []
+     for i in range(0, num_images, BATCH_SIZE):
+         batch_options = options.copy()
+         batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
+         if "negative_prompt" in batch_options:
+             batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
+         images.extend(pipe(**batch_options).images)
+
+     image_paths = [save_image(img) for img in images]
+     return image_paths, seed
+
+ examples = [
+     "a cat eating a piece of cheese",
+     "a ROBOT riding a BLUE horse on Mars, photorealistic, 4k",
+     "Ironman VS Hulk, ultrarealistic",
+     "Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",
+     "An alien holding a sign board containing the word 'Flash', futuristic, neonpunk",
+     "Kids going to school, Anime style"
+ ]
 
 
  ###############################################################################
+ ############# above part is for text to image #############
  ###############################################################################
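For reference, a sketch of calling the new generate() entry point directly, outside the Gradio UI (not part of the commit; it assumes the pipeline above loaded successfully, and the argument values mirror the UI defaults wired up further down):

# Direct call to generate(); it returns the saved image paths and the seed used.
# A negative prompt is passed here because, as written above, use_negative_prompt=False
# leaves options["negative_prompt"] as None, which the batching loop would then slice.
image_paths, used_seed = generate(
    prompt="Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",
    negative_prompt="blurry, low quality",
    use_negative_prompt=True,
    seed=0,
    width=1024,
    height=1024,
    guidance_scale=3.0,
    num_inference_steps=8,
    randomize_seed=True,
    use_resolution_binning=False,  # SDXL's __call__ may not recognize this kwarg, so skip it here
    num_images=1,
)
print(used_seed, image_paths)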
 
 
@@ -1353,99 +1383,116 @@ def create_ui():
  ###############################################################################
 
  ###############################################################################
+ ############# this part is for text to image #############
  ###############################################################################
 
+ with gr.Row(variant="panel") as text2image_model:
+     with gr.Row():
+         prompt = gr.Text(
+             label="Prompt",
+             show_label=False,
+             max_lines=1,
+             placeholder="Enter your prompt",
+             container=False,
+         )
+         run_button = gr.Button("Run", scale=0)
+     result = gr.Gallery(label="Result", columns=1, show_label=False)
+     with gr.Accordion("Advanced options", open=False):
+         num_images = gr.Slider(
+             label="Number of Images",
+             minimum=1,
+             maximum=4,
+             step=1,
+             value=1,
+         )
+         with gr.Row():
+             use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True)
+             negative_prompt = gr.Text(
+                 label="Negative prompt",
+                 max_lines=5,
+                 lines=4,
+                 placeholder="Enter a negative prompt",
+                 value="(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, NSFW",
+                 visible=True,
+             )
+         seed = gr.Slider(
+             label="Seed",
+             minimum=0,
+             maximum=MAX_SEED,
+             step=1,
+             value=0,
+         )
+         randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+         with gr.Row(visible=True):
+             width = gr.Slider(
+                 label="Width",
+                 minimum=512,
+                 maximum=MAX_IMAGE_SIZE,
+                 step=64,
+                 value=1024,
+             )
+             height = gr.Slider(
+                 label="Height",
+                 minimum=512,
+                 maximum=MAX_IMAGE_SIZE,
+                 step=64,
+                 value=1024,
              )
          with gr.Row():
+             guidance_scale = gr.Slider(
+                 label="Guidance Scale",
+                 minimum=0.1,
+                 maximum=6,
+                 step=0.1,
+                 value=3.0,
              )
+             num_inference_steps = gr.Slider(
+                 label="Number of inference steps",
+                 minimum=1,
+                 maximum=15,
+                 step=1,
+                 value=8,
              )
+
+     gr.Examples(
+         examples=examples,
+         inputs=prompt,
+         cache_examples=False
      )
+
+     use_negative_prompt.change(
+         fn=lambda x: gr.update(visible=x),
+         inputs=use_negative_prompt,
+         outputs=negative_prompt,
+         api_name=False,
      )
 
+     gr.on(
+         triggers=[
+             prompt.submit,
+             negative_prompt.submit,
+             run_button.click,
+         ],
+         fn=generate,
+         inputs=[
+             prompt,
+             negative_prompt,
+             use_negative_prompt,
+             seed,
+             width,
+             height,
+             guidance_scale,
+             num_inference_steps,
+             randomize_seed,
+             num_images
+         ],
+         outputs=[result, seed],
+         api_name="run",
+     )
 
 
  ###############################################################################
+ ############# above part is for text to image #############
  ###############################################################################
  def clear_tts_fields():
      return [gr.update(value=""), gr.update(value=""), None, None, gr.update(value=False), gr.update(value=True), None, None]
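The gr.on(...) block above binds three triggers (prompt submit, negative-prompt submit, and the Run button click) to the single generate handler. A self-contained sketch of that Gradio pattern, using illustrative names rather than the ones in app.py:

import gradio as gr

def echo(text):
    return f"You typed: {text}"

with gr.Blocks() as demo:
    box = gr.Textbox(label="Prompt")
    btn = gr.Button("Run")
    out = gr.Textbox(label="Result")

    # Pressing Enter in the textbox and clicking the button both call echo().
    gr.on(
        triggers=[box.submit, btn.click],
        fn=echo,
        inputs=box,
        outputs=out,
    )

if __name__ == "__main__":
    demo.launch()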