manbeast3b committed
Commit e5fb4b3 (verified) · 1 Parent(s): 5cd3d1d

Update src/pipeline.py

Files changed (1): src/pipeline.py (+3, -2)
src/pipeline.py CHANGED
@@ -15,7 +15,7 @@ from diffusers import (
 )
 from transformers import T5EncoderModel
 from huggingface_hub.constants import HF_HUB_CACHE
-from torchao.quantization import quantize_, int8_weight_only
+from torchao.quantization import quantize_, int8_weight_only, float8_weight_only
 from first_block_cache.diffusers_adapters import apply_cache_on_pipe
 from pipelines.models import TextToImageRequest
 from torch import Generator
@@ -100,7 +100,7 @@ class PipelineManager:
 torch_dtype=Config.DTYPE
 ).to(memory_format=torch.channels_last)
 vae = AutoencoderTiny.from_pretrained("RobertML/FLUX.1-schnell-vae_e3m2", revision="da0d2cd7815792fb40d084dbd8ed32b63f153d8d", torch_dtype=Config.DTYPE)
-vae.encoder=_load(vae.encoder, "E", dtype=torch.bfloat16); vae.decoder=_load(vae.decoder, "D", dtype=torch.bfloat16)
+# vae.encoder=_load(vae.encoder, "E", dtype=torch.bfloat16); vae.decoder=_load(vae.decoder, "D", dtype=torch.bfloat16)

 path = os.path.join(HF_HUB_CACHE, "models--RobertML--FLUX.1-schnell-int8wo/snapshots/307e0777d92df966a3c0f99f31a6ee8957a9857a")
 model = FluxTransformer2DModel.from_pretrained(
@@ -122,6 +122,7 @@ class PipelineManager:
 pipeline.to(memory_format=torch.channels_last)
 pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune")
 quantize_(pipeline.vae, int8_weight_only())
+quantize_(pipeline.vae, float8_weight_only())
 PipelineManager._warmup(pipeline)

 return pipeline
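
For context, the torchao weight-only quantization used in this diff follows a simple in-place pattern. Below is a minimal sketch on a toy module; the toy_vae module and its shapes are illustrative stand-ins (not from this repo), and it assumes a recent torchao build plus a CUDA GPU. float8_weight_only() is invoked with the same call pattern as int8_weight_only(), but generally needs fp8-capable hardware (e.g. Hopper/Ada).

import torch
from torch import nn
from torchao.quantization import quantize_, int8_weight_only

# Toy stand-in for pipeline.vae: any nn.Module containing nn.Linear layers.
toy_vae = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
toy_vae = toy_vae.to(device="cuda", dtype=torch.bfloat16)

# quantize_ mutates the module in place, replacing eligible Linear weights
# with int8 weight-only quantized tensors; activations stay in bf16.
quantize_(toy_vae, int8_weight_only())

x = torch.randn(1, 64, device="cuda", dtype=torch.bfloat16)
print(toy_vae(x).shape)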