jokerbit committed
Commit 1f060f6 · verified · 1 Parent(s): 62d547b

Update src/pipeline.py

Files changed (1):
  1. src/pipeline.py +3 -3
src/pipeline.py CHANGED
@@ -11,7 +11,7 @@ from torch import Generator
 from torchao.quantization import quantize_, int8_weight_only
 from transformers import T5EncoderModel, CLIPTextModel, logging
 from functools import partial
-
+from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
 my_partial_compile = partial(torch.compile, mode="max-autotune")
 
 Pipeline: TypeAlias = FluxPipeline
@@ -52,9 +52,9 @@ def load_pipeline() -> Pipeline:
     ).to("cuda")
 
     pipeline.to(memory_format=torch.channels_last)
-    quantize_(pipeline.vae, int8_weight_only())
+    # quantize_(pipeline.vae, int8_weight_only())
     pipeline.vae = my_partial_compile(pipeline.vae)
-    pipeline.transformer = torch.compile(pipeline.transformer)
+    apply_cache_on_pipe(pipeline, residual_diff_threshold=0.25)
     with torch.inference_mode():
         for _ in range(2):
             pipeline("cats running on a road with a dog chasing", num_inference_steps=4)