Update src/pipeline.py
src/pipeline.py  CHANGED  (+3 -2)
@@ -15,7 +15,7 @@ from diffusers import (
 )
 from transformers import T5EncoderModel
 from huggingface_hub.constants import HF_HUB_CACHE
-from torchao.quantization import quantize_, int8_weight_only
+from torchao.quantization import quantize_, int8_weight_only, float8_weight_only
 from first_block_cache.diffusers_adapters import apply_cache_on_pipe
 from pipelines.models import TextToImageRequest
 from torch import Generator
@@ -100,7 +100,7 @@ class PipelineManager:
             torch_dtype=Config.DTYPE
         ).to(memory_format=torch.channels_last)
         vae = AutoencoderTiny.from_pretrained("RobertML/FLUX.1-schnell-vae_e3m2", revision="da0d2cd7815792fb40d084dbd8ed32b63f153d8d", torch_dtype=Config.DTYPE)
-        vae.encoder=_load(vae.encoder, "E", dtype=torch.bfloat16); vae.decoder=_load(vae.decoder, "D", dtype=torch.bfloat16)
+        # vae.encoder=_load(vae.encoder, "E", dtype=torch.bfloat16); vae.decoder=_load(vae.decoder, "D", dtype=torch.bfloat16)
 
         path = os.path.join(HF_HUB_CACHE, "models--RobertML--FLUX.1-schnell-int8wo/snapshots/307e0777d92df966a3c0f99f31a6ee8957a9857a")
         model = FluxTransformer2DModel.from_pretrained(
@@ -122,6 +122,7 @@ class PipelineManager:
         pipeline.to(memory_format=torch.channels_last)
         pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune")
         quantize_(pipeline.vae, int8_weight_only())
+        quantize_(pipeline.vae, float8_weight_only())
         PipelineManager._warmup(pipeline)
 
         return pipeline
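For reference, torchao's quantize_ mutates the target module in place: eligible nn.Linear weights are swapped for quantized tensor subclasses that dequantize on the fly during matmul, so the call needs no reassignment. Below is a minimal sketch of the pattern this commit adds, assuming a recent PyTorch and a torchao build that exports float8_weight_only; the toy module is a stand-in for pipeline.vae, not the real AutoencoderTiny.

import torch
from torchao.quantization import quantize_, float8_weight_only

# Stand-in for pipeline.vae: any module containing nn.Linear layers.
toy_vae = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 64),
)

# In-place weight-only quantization: weights are stored in float8 and
# dequantized during the forward pass; activations stay in high precision.
quantize_(toy_vae, float8_weight_only())

out = toy_vae(torch.randn(1, 64))

Note that the diff issues the float8 call right after quantize_(pipeline.vae, int8_weight_only()); whether a second weight-only pass over already-quantized Linear weights is a no-op, a re-quantization, or an error depends on the torchao version, so the sketch above shows a single pass only.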