SimianLuo patrickvonplaten commited on
Commit
c144567
1 Parent(s): 1128191

[Don't merge yet] Simplify inference (#1)

Browse files

- [Don't merge yet] Simplify inference (3ddb077c18b666e1bba594087dc887b87824472c)


Co-authored-by: Patrick von Platen <[email protected]>

Files changed (1) hide show
  1. app.py +4 -44
app.py CHANGED
@@ -9,12 +9,9 @@ import gradio as gr
9
  import numpy as np
10
  import PIL.Image
11
  import torch
12
- from lcm_pipeline import LatentConsistencyModelPipeline
13
- from lcm_scheduler import LCMScheduler
14
 
15
- from diffusers import AutoencoderKL, UNet2DConditionModel
16
- from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
17
- from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor
18
 
19
  import os
20
  import torch
@@ -34,45 +31,8 @@ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "768"))
34
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
35
  DTYPE = torch.float32 # torch.float16 works as well, but pictures seem to be a bit worse
36
 
37
- model_id = "digiplay/DreamShaper_7"
38
-
39
-
40
- # Initalize Diffusers Model:
41
- vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
42
- text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
43
- tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
44
- config = UNet2DConditionModel.load_config(model_id, subfolder="unet")
45
- config["time_cond_proj_dim"] = 256
46
-
47
- unet = UNet2DConditionModel.from_config(config)
48
- safety_checker = StableDiffusionSafetyChecker.from_pretrained(model_id, subfolder="safety_checker")
49
- feature_extractor = CLIPImageProcessor.from_pretrained(model_id, subfolder="feature_extractor")
50
-
51
- # Initalize Scheduler:
52
- scheduler = LCMScheduler(beta_start=0.00085, beta_end=0.0120, beta_schedule="scaled_linear", prediction_type="epsilon")
53
-
54
- HF_TOKEN = os.environ.get("HF_TOKEN", None)
55
-
56
- if torch.cuda.is_available():
57
- # Replace the unet with LCM:
58
- # lcm_unet_ckpt = hf_hub_download("SimianLuo/LCM_Dreamshaper_v7", filename="LCM_Dreamshaper_v7_4k.safetensors", token=HF_TOKEN)
59
- lcm_unet_ckpt = "./LCM_Dreamshaper_v7_4k.safetensors"
60
- ckpt = load_file(lcm_unet_ckpt)
61
- m, u = unet.load_state_dict(ckpt, strict=False)
62
- if len(m) > 0:
63
- print("missing keys:")
64
- print(m)
65
- if len(u) > 0:
66
- print("unexpected keys:")
67
- print(u)
68
-
69
-
70
- # LCM Pipeline:
71
- pipe = LatentConsistencyModelPipeline(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor)
72
- pipe = pipe.to(torch_device="cuda", torch_dtype=DTYPE)
73
-
74
- if USE_TORCH_COMPILE:
75
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
76
 
77
 
78
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
 
9
  import numpy as np
10
  import PIL.Image
11
  import torch
 
 
12
 
13
+ from diffusers import DiffusionPipeline
14
+ import torch
 
15
 
16
  import os
17
  import torch
 
31
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
32
  DTYPE = torch.float32 # torch.float16 works as well, but pictures seem to be a bit worse
33
 
34
+ pipe = DiffusionPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", custom_pipeline="latent_consistency_txt2img")
35
+ pipe.to(torch_device="cuda", torch_dtype=DTYPE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
 
38
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: