zhiweili commited on
Commit
812f35c
Β·
1 Parent(s): b586df1

change vae

Browse files
Files changed (1) hide show
  1. app_onediff.py +4 -7
app_onediff.py CHANGED
@@ -8,11 +8,9 @@ import json
8
  from diffusers import (
9
  DDPMScheduler,
10
  AutoPipelineForText2Image,
11
- AutoencoderTiny,
12
  )
13
 
14
- import sys
15
-
16
  os.system("python3 -m pip --no-cache-dir install --pre nexfort -f https://github.com/siliconflow/nexfort_releases/releases/expanded_assets/torch2.4.1_cu121")
17
  os.system("git clone https://github.com/siliconflow/onediff.git")
18
  os.system("cd onediff && python3 -m pip install .")
@@ -33,11 +31,10 @@ def nexfort_compile(torch_module: torch.nn.Module):
33
  BASE_MODEL = "stabilityai/sdxl-turbo"
34
  device = "cuda"
35
 
36
- vae = AutoencoderTiny.from_pretrained(
37
- 'madebyollin/taesdxl',
38
- use_safetensors=True,
39
  torch_dtype=torch.float16,
40
- ).to('cuda')
41
  base_pipe = AutoPipelineForText2Image.from_pretrained(
42
  BASE_MODEL,
43
  vae=vae,
 
8
  from diffusers import (
9
  DDPMScheduler,
10
  AutoPipelineForText2Image,
11
+ AutoencoderKL,
12
  )
13
 
 
 
14
  os.system("python3 -m pip --no-cache-dir install --pre nexfort -f https://github.com/siliconflow/nexfort_releases/releases/expanded_assets/torch2.4.1_cu121")
15
  os.system("git clone https://github.com/siliconflow/onediff.git")
16
  os.system("cd onediff && python3 -m pip install .")
 
31
  BASE_MODEL = "stabilityai/sdxl-turbo"
32
  device = "cuda"
33
 
34
+ vae = AutoencoderKL.from_pretrained(
35
+ "madebyollin/sdxl-vae-fp16-fix",
 
36
  torch_dtype=torch.float16,
37
+ )
38
  base_pipe = AutoPipelineForText2Image.from_pretrained(
39
  BASE_MODEL,
40
  vae=vae,