import os
import gc
from time import perf_counter
from typing import Optional

import torch
from torch import Generator
from PIL.Image import Image
from diffusers import FluxPipeline, AutoencoderKL, FluxTransformer2DModel
from diffusers.image_processor import VaeImageProcessor
from transformers import (
    T5EncoderModel,
    T5TokenizerFast,
    CLIPTokenizer,
    CLIPTextModel,
    CLIPTextConfig,
    T5Config,
)
from torchao.quantization import quantize_, int8_weight_only, int8_dynamic_activation_int8_weight

from pipelines.models import TextToImageRequest

HOME = os.environ["HOME"]
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

FLUX_CHECKPOINT = "black-forest-labs/FLUX.1-schnell"
# QUANTIZED_MODEL = []
QUANTIZED_MODEL = ["transformer", "text_encoder_2", "text_encoder", "vae"]
QUANT_CONFIG = int8_weight_only()
DTYPE = torch.bfloat16
NUM_STEPS = 4


def get_transformer(quantize: bool = True, quant_config=int8_weight_only(), quant_ckpt: Optional[str] = None):
    """Load the FLUX transformer, either from a pre-quantized checkpoint or from the hub."""
    if quant_ckpt is not None:
        # Rebuild the architecture from its config, then load the quantized weights directly.
        config = FluxTransformer2DModel.load_config(FLUX_CHECKPOINT, subfolder="transformer")
        model = FluxTransformer2DModel.from_config(config).to(DTYPE)
        state_dict = torch.load(quant_ckpt, map_location="cpu")
        model.load_state_dict(state_dict, assign=True)
        print(f"Loaded {quant_ckpt}")
        return model

    model = FluxTransformer2DModel.from_pretrained(
        FLUX_CHECKPOINT, subfolder="transformer", torch_dtype=DTYPE
    )
    if quantize:
        quantize_(model, quant_config)
    return model


def get_text_encoder(quantize: bool = True, quant_config=int8_weight_only(), quant_ckpt: Optional[str] = None):
    """Load the CLIP text encoder, either from a pre-quantized checkpoint or from the hub."""
    if quant_ckpt is not None:
        config = CLIPTextConfig.from_pretrained(FLUX_CHECKPOINT, subfolder="text_encoder")
        model = CLIPTextModel(config).to(DTYPE)
        state_dict = torch.load(quant_ckpt, map_location="cpu")
        model.load_state_dict(state_dict, assign=True)
        print(f"Loaded {quant_ckpt}")
        return model

    model = CLIPTextModel.from_pretrained(
        FLUX_CHECKPOINT, subfolder="text_encoder", torch_dtype=DTYPE
    )
    if quantize:
        quantize_(model, quant_config)
    return model


def get_text_encoder_2(quantize: bool = True, quant_config=int8_weight_only(), quant_ckpt: Optional[str] = None):
    """Load the T5 text encoder, either from a pre-quantized checkpoint or from the hub."""
    if quant_ckpt is not None:
        config = T5Config.from_pretrained(FLUX_CHECKPOINT, subfolder="text_encoder_2")
        model = T5EncoderModel(config).to(DTYPE)
        state_dict = torch.load(quant_ckpt, map_location="cpu")
        model.load_state_dict(state_dict, assign=True)
        print(f"Loaded {quant_ckpt}")
        return model

    model = T5EncoderModel.from_pretrained(
        FLUX_CHECKPOINT, subfolder="text_encoder_2", torch_dtype=DTYPE
    )
    if quantize:
        quantize_(model, quant_config)
    return model


def get_vae(quantize: bool = True, quant_config=int8_weight_only(), quant_ckpt: Optional[str] = None):
    """Load the VAE, either from a pre-quantized checkpoint or from the hub."""
    if quant_ckpt is not None:
        config = AutoencoderKL.load_config(FLUX_CHECKPOINT, subfolder="vae")
        model = AutoencoderKL.from_config(config).to(DTYPE)
        state_dict = torch.load(quant_ckpt, map_location="cpu")
        model.load_state_dict(state_dict, assign=True)
        print(f"Loaded {quant_ckpt}")
        return model

    model = AutoencoderKL.from_pretrained(
        FLUX_CHECKPOINT, subfolder="vae", torch_dtype=DTYPE
    )
    if quantize:
        quantize_(model, quant_config)
    return model


def empty_cache():
    """Release Python and CUDA memory and reset the allocator's peak statistics."""
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()


def load_pipeline() -> FluxPipeline:
    empty_cache()

    pipe = FluxPipeline.from_pretrained(FLUX_CHECKPOINT, torch_dtype=DTYPE)

    # channels_last memory format tends to speed up CUDA kernels for these modules.
    pipe.text_encoder_2.to(memory_format=torch.channels_last)
    pipe.transformer.to(memory_format=torch.channels_last)
    pipe.vae.to(memory_format=torch.channels_last)

    pipe.vae = torch.compile(pipe.vae)

    # Keep the compiled VAE resident on the GPU; offload the remaining modules sequentially.
    pipe._exclude_from_cpu_offload = ["vae"]
    pipe.enable_sequential_cpu_offload()
    empty_cache()

    # Warm-up run to trigger torch.compile and kernel caching before serving requests.
    pipe("cat", guidance_scale=0.0, max_sequence_length=256, num_inference_steps=4)

    return pipe


@torch.inference_mode()
def infer(request: TextToImageRequest, _pipeline: FluxPipeline) -> Image:
    if request.seed is None:
        generator = None
    else:
        generator = Generator(device="cuda").manual_seed(request.seed)

    empty_cache()

    image = _pipeline(
        prompt=request.prompt,
        width=request.width,
        height=request.height,
        guidance_scale=0.0,
        generator=generator,
        output_type="pil",
        max_sequence_length=256,
        num_inference_steps=NUM_STEPS,
    ).images[0]

    return image