multimodalart committed
Commit aa5a24b (parent: 61bc6a3)

Update app.py

Files changed (1): app.py (+31 -7)
app.py CHANGED
@@ -5,7 +5,11 @@ from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
 
 ### SDXL Turbo ####
-pipe_turbo = StableDiffusionXLPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
+pipe_turbo = StableDiffusionXLPipeline.from_pretrained("stabilityai/sdxl-turbo",
+                                                       vae=vae,
+                                                       torch_dtype=torch.float16,
+                                                       variant="fp16"
+                                                       )
 pipe_turbo.to("cuda")
 
 ### SDXL Lightning ###
@@ -13,9 +17,19 @@ base = "stabilityai/stable-diffusion-xl-base-1.0"
 repo = "ByteDance/SDXL-Lightning"
 ckpt = "sdxl_lightning_1step_unet_x0.safetensors"
 
-unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
-unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
-pipe_lightning = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
+unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(torch.float16)
+unet.load_state_dict(load_file(hf_hub_download(repo, ckpt)))
+pipe_lightning = StableDiffusionXLPipeline.from_pretrained(base,
+                                                           unet=unet,
+                                                           vae=vae,
+                                                           text_encoder=pipe_turbo.text_encoder,
+                                                           text_encoder_2=pipe_turbo.text_encoder_2,
+                                                           tokenizer=pipe_turbo.tokenizer,
+                                                           tokenizer_2=pipe_turbo.tokenizer_2,
+                                                           torch_dtype=torch.float16,
+                                                           variant="fp16"
+                                                           )#.to("cuda")
+del unet
 pipe_lightning.scheduler = EulerDiscreteScheduler.from_config(pipe_lightning.scheduler.config, timestep_spacing="trailing", prediction_type="sample")
 pipe_lightning.to("cuda")
 
@@ -23,11 +37,21 @@ pipe_lightning.to("cuda")
 repo_name = "ByteDance/Hyper-SD"
 ckpt_name = "Hyper-SDXL-1step-Unet.safetensors"
 
-unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
-unet.load_state_dict(load_file(hf_hub_download(repo_name, ckpt_name), device="cuda"))
-pipe_hyper = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
+unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(torch.float16)
+unet.load_state_dict(load_file(hf_hub_download(repo_name, ckpt_name)))
+pipe_hyper = StableDiffusionXLPipeline.from_pretrained(base,
+                                                       unet=unet,
+                                                       vae=vae,
+                                                       text_encoder=pipe_turbo.text_encoder,
+                                                       text_encoder_2=pipe_turbo.text_encoder_2,
+                                                       tokenizer=pipe_turbo.tokenizer,
+                                                       tokenizer_2=pipe_turbo.tokenizer_2,
+                                                       torch_dtype=torch.float16,
+                                                       variant="fp16"
+                                                       )#.to("cuda")
 pipe_hyper.scheduler = LCMScheduler.from_config(pipe_hyper.scheduler.config)
 pipe_hyper.to("cuda")
+del unet
 
 def run_comparison(prompt):
     image_turbo=pipe_turbo(prompt=prompt, num_inference_steps=1, guidance_scale=0).images[0]
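The hunks above reference names defined elsewhere in app.py that do not appear in this diff: the torch/diffusers imports and the shared vae object now passed to all three pipelines (the Lightning and Hyper-SD pipelines also reuse pipe_turbo's text encoders and tokenizers instead of loading their own copies, which is what saves memory here). A minimal sketch of that preamble, assuming an fp16-safe SDXL VAE such as madebyollin/sdxl-vae-fp16-fix; the exact checkpoint is an assumption and is not shown in the diff:

import torch
from diffusers import (
    AutoencoderKL,
    EulerDiscreteScheduler,
    LCMScheduler,
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Shared fp16-friendly VAE, reused via `vae=vae` by all three pipelines.
# The checkpoint name is an assumption; the diff only shows the variable being passed in.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)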