Spaces:
Running
on
Zero
Running
on
Zero
Update src/flux/pipeline_tools.py
Browse files
src/flux/pipeline_tools.py
CHANGED
|
@@ -510,7 +510,8 @@ class CustomFluxPipeline:
|
|
| 510 |
ckpt_root_condition=None,
|
| 511 |
torch_dtype=torch.bfloat16,
|
| 512 |
):
|
| 513 |
-
|
|
|
|
| 514 |
print("[CustomFluxPipeline] Loading FLUX Pipeline")
|
| 515 |
self.pipe = FluxPipeline.from_pretrained(model_path, torch_dtype=torch_dtype).to(device)
|
| 516 |
self.pipe.enable_sequential_cpu_offload()
|
|
@@ -518,8 +519,8 @@ class CustomFluxPipeline:
|
|
| 518 |
self.config = config
|
| 519 |
self.device = device
|
| 520 |
self.dtype = torch_dtype
|
| 521 |
-
if config["model"].get("dit_quant", "None") != "None":
|
| 522 |
-
|
| 523 |
|
| 524 |
self.modulation_adapters = []
|
| 525 |
self.pipe.modulation_adapters = []
|
|
|
|
| 510 |
ckpt_root_condition=None,
|
| 511 |
torch_dtype=torch.bfloat16,
|
| 512 |
):
|
| 513 |
+
|
| 514 |
+
model_path = os.getenv("FLUX_MODEL_PATH", "diffusers/FLUX.1-dev-torchao-int8" if config["model"].get("dit_quant", "None")=="int8-quanto" else "black-forest-labs/FLUX.1-dev")
|
| 515 |
print("[CustomFluxPipeline] Loading FLUX Pipeline")
|
| 516 |
self.pipe = FluxPipeline.from_pretrained(model_path, torch_dtype=torch_dtype).to(device)
|
| 517 |
self.pipe.enable_sequential_cpu_offload()
|
|
|
|
| 519 |
self.config = config
|
| 520 |
self.device = device
|
| 521 |
self.dtype = torch_dtype
|
| 522 |
+
# if config["model"].get("dit_quant", "None") != "None":
|
| 523 |
+
# quantization(self.pipe, config["model"]["dit_quant"])
|
| 524 |
|
| 525 |
self.modulation_adapters = []
|
| 526 |
self.pipe.modulation_adapters = []
|