jokerbit committed on
Commit f6baeb3 · verified · 1 Parent(s): 19cacc0
Files changed (1)
  1. src/pipeline.py +3 -4
src/pipeline.py CHANGED
@@ -13,6 +13,7 @@ from transformers import T5EncoderModel, CLIPTextModel
 
 
 Pipeline: TypeAlias = FluxPipeline
+os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
 torch.backends.cudnn.benchmark = True
 
 CHECKPOINT = "jokerbit/flux.1-schnell-Robert-int8wo"
@@ -37,11 +38,9 @@ def load_pipeline() -> Pipeline:
         local_files_only=True,
         torch_dtype=torch.bfloat16,
     )
-
-    pipeline.to(memory_format=torch.channels_last)
-    pipeline.enable_vae_slicing()
+    pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune")
     pipeline.to("cuda")
-    # quantize_(pipeline.vae, int8_weight_only())
+
     for _ in range(4):
         pipeline("cat", num_inference_steps=4)
 
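For reference, a minimal sketch of how load_pipeline() reads once this commit is applied. Only the lines visible in the hunks above come from the repository; the imports, the FluxPipeline.from_pretrained(...) call opening, and the trailing return are assumptions added to make the sketch self-contained.

import os
from typing import TypeAlias

import torch
from diffusers import FluxPipeline

Pipeline: TypeAlias = FluxPipeline
# Added in this commit: set before the first CUDA allocation so the caching
# allocator can grow existing segments instead of fragmenting memory.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"
torch.backends.cudnn.benchmark = True

CHECKPOINT = "jokerbit/flux.1-schnell-Robert-int8wo"


def load_pipeline() -> Pipeline:
    # The from_pretrained() opening is assumed; only the keyword arguments
    # below appear in the diff context.
    pipeline = FluxPipeline.from_pretrained(
        CHECKPOINT,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )
    # Added in this commit: compile the VAE (replaces the removed
    # channels_last / enable_vae_slicing() calls).
    pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune")
    pipeline.to("cuda")

    # Warm-up runs, so compilation/autotuning happens before real requests.
    for _ in range(4):
        pipeline("cat", num_inference_steps=4)
    return pipeline  # assumed; the return statement lies outside the hunks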