---
library_name: diffusers
license: apache-2.0
---

An int8 weight-only (int8-wo) quantized version of the [FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) transformer. The checkpoint was created with the script below:

```python
from diffusers import FluxTransformer2DModel
from torchao.quantization import quantize_, int8_weight_only
import torch

ckpt_id = "black-forest-labs/FLUX.1-schnell"

# Load only the transformer of the original pipeline in bfloat16.
transformer = FluxTransformer2DModel.from_pretrained(
    ckpt_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
# Quantize the weights to int8 in place; activations are left untouched.
quantize_(transformer, int8_weight_only())

# The quantized weights are saved with pickle-based serialization
# instead of safetensors (safe_serialization=False).
output_dir = "./flux-schnell-int8wo"
transformer.save_pretrained(output_dir, safe_serialization=False)

save_to = "sayakpaul/flux.1-schell-int8wo-improved"
transformer.push_to_hub(save_to, safe_serialization=False)
```
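
For intuition about what int8 weight-only quantization does, here is a conceptual pure-Python sketch. It is not torchao's implementation: it assumes symmetric per-row scales, and the helper names (`quantize_row`, `linear_int8_wo`) are made up for illustration. Weights are stored as int8 values plus one float scale per output row, while activations stay in floating point:

```python
def quantize_row(row):
    """Symmetric int8 quantization of one weight row (assumed scheme).

    Returns the integer values and the float scale used to recover them.
    """
    max_abs = max(abs(w) for w in row)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(w / scale))) for w in row]
    return q, scale


def linear_int8_wo(x, q_weight, scales):
    """Compute y = x @ W^T where W is stored as int8 rows plus per-row scales.

    The input x is never quantized; only the weights are.
    """
    return [
        sum(xi * qi for xi, qi in zip(x, q_row)) * scale
        for q_row, scale in zip(q_weight, scales)
    ]


if __name__ == "__main__":
    weights = [[0.5, -1.0, 0.25], [2.0, 0.0, -0.5]]
    quantized = [quantize_row(row) for row in weights]
    q_weight = [q for q, _ in quantized]
    scales = [s for _, s in quantized]
    # The result is close to, but not exactly, the full-precision output.
    print(linear_int8_wo([1.0, 2.0, 3.0], q_weight, scales))
```

Each weight then takes one byte instead of the two bytes used by bfloat16, at the cost of a small rounding error per weight.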

Install `diffusers`, `huggingface_hub`, and `torchao` (the `ao` repository) from source.
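
Before running the snippets, it can help to confirm which versions are actually installed. A minimal sketch using only the standard library; the helper name `installed_versions` is hypothetical, and it assumes the `ao` repository installs under the distribution name `torchao`:

```python
from importlib import metadata


def installed_versions(packages):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    required = ["diffusers", "huggingface_hub", "torchao"]
    for name, version in installed_versions(required).items():
        print(f"{name}: {version or 'not installed'}")
```

Source installs usually report a development version string, which makes them easy to tell apart from a release installed from PyPI.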

Inference:

```python
import torch
from diffusers import FluxTransformer2DModel, DiffusionPipeline

dtype, device = torch.bfloat16, "cuda"
ckpt_id = "black-forest-labs/FLUX.1-schnell"

# Load the int8 weight-only transformer. The checkpoint was saved without
# safetensors, so use_safetensors=False is needed here.
model = FluxTransformer2DModel.from_pretrained(
    "sayakpaul/flux.1-schell-int8wo-improved", torch_dtype=dtype, use_safetensors=False
)
# Take the remaining pipeline components from the original checkpoint.
pipeline = DiffusionPipeline.from_pretrained(ckpt_id, transformer=model, torch_dtype=dtype).to(device)
image = pipeline(
    "cat", guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256
).images[0]
image.save("flux_schnell_int8.png")
```