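"""Helpers for serving black-forest-labs/FLUX.1-schnell.

The component loaders support optional torchao int8 weight-only quantization
(quantizing in place or restoring a pre-quantized state dict), while
`load_pipeline` builds the full bf16 pipeline with a compiled VAE and
sequential CPU offload, and `infer` runs a single text-to-image request.
"""
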
import os
from diffusers import FluxPipeline, AutoencoderKL, FluxTransformer2DModel
from diffusers.image_processor import VaeImageProcessor
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel, CLIPTextConfig, T5Config
import torch
import gc
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from torch import Generator
from torchao.quantization import quantize_, int8_weight_only, int8_dynamic_activation_int8_weight
from time import perf_counter


HOME = os.environ["HOME"]
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
FLUX_CHECKPOINT = "black-forest-labs/FLUX.1-schnell"
# Components that have quantization-aware loaders below.
# QUANTIZED_MODEL = []
QUANTIZED_MODEL = ["transformer", "text_encoder_2", "text_encoder", "vae"]


QUANT_CONFIG = int8_weight_only()  # default torchao scheme: int8 weight-only
DTYPE = torch.bfloat16
NUM_STEPS = 4  # FLUX.1-schnell is distilled for few-step (~4) sampling

def get_transformer(quantize: bool = True, quant_config = int8_weight_only(), quant_ckpt: str = None):
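    """Load the FLUX transformer in `DTYPE`.

    If `quant_ckpt` is given, rebuild the model from its config and restore a
    previously saved (typically quantized) state dict with `assign=True`;
    otherwise load the pretrained weights and, if `quantize` is set, quantize
    them in place with torchao.
    """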
    if quant_ckpt is not None:
        config = FluxTransformer2DModel.load_config(FLUX_CHECKPOINT, subfolder="transformer")
        model = FluxTransformer2DModel.from_config(config).to(DTYPE)
        state_dict = torch.load(quant_ckpt, map_location="cpu")
        model.load_state_dict(state_dict, assign=True)
        print(f"Loaded {quant_ckpt}")
        return model
    
    model = FluxTransformer2DModel.from_pretrained(
        FLUX_CHECKPOINT, subfolder="transformer", torch_dtype=DTYPE
        )
    if quantize:
        quantize_(model, quant_config)
    return model


def get_text_encoder(quantize: bool = True, quant_config = int8_weight_only(), quant_ckpt: str = None):
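    """Load the CLIP text encoder; same checkpoint/quantization logic as `get_transformer`."""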
    if quant_ckpt is not None:
        config = CLIPTextConfig.from_pretrained(FLUX_CHECKPOINT, subfolder="text_encoder")
        model = CLIPTextModel(config).to(DTYPE)
        state_dict = torch.load(quant_ckpt, map_location="cpu")
        model.load_state_dict(state_dict, assign=True)
        print(f"Loaded {quant_ckpt}")
        return model
    
    model = CLIPTextModel.from_pretrained(
        FLUX_CHECKPOINT, subfolder="text_encoder", torch_dtype=DTYPE
        )
    if quantize:
        quantize_(model, quant_config)
    return model


def get_text_encoder_2(quantize: bool = True, quant_config = int8_weight_only(), quant_ckpt: str = None):
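    """Load the T5 text encoder; same checkpoint/quantization logic as `get_transformer`."""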
    if quant_ckpt is not None:
        config = T5Config.from_pretrained(FLUX_CHECKPOINT, subfolder="text_encoder_2")
        model = T5EncoderModel(config).to(DTYPE)
        state_dict = torch.load(quant_ckpt, map_location="cpu")
        model.load_state_dict(state_dict, assign=True)
        print(f"Loaded {quant_ckpt}")
        return model
    
    model = T5EncoderModel.from_pretrained(
        FLUX_CHECKPOINT, subfolder="text_encoder_2", torch_dtype=DTYPE
        )
    if quantize:
        quantize_(model, quant_config)
    return model


def get_vae(quantize: bool = True, quant_config = int8_weight_only(), quant_ckpt: str = None):
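    """Load the VAE; same checkpoint/quantization logic as `get_transformer`."""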
    if quant_ckpt is not None:
        config = AutoencoderKL.load_config(FLUX_CHECKPOINT, subfolder="vae")
        model = AutoencoderKL.from_config(config).to(DTYPE)
        state_dict = torch.load(quant_ckpt, map_location="cpu")
        model.load_state_dict(state_dict, assign=True)
        print(f"Loaded {quant_ckpt}")
        return model
    model = AutoencoderKL.from_pretrained(
        FLUX_CHECKPOINT, subfolder="vae", torch_dtype=DTYPE
        )
    if quantize:
        quantize_(model, quant_config)
    return model
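

# Hypothetical helper (not called anywhere in this module): a minimal sketch of
# how a `quant_ckpt` consumed by the loaders above could be produced, assuming
# the torchao-quantized state dict round-trips through torch.save/torch.load.
def save_quantized_transformer(path: str, quant_config = int8_weight_only()):
    model = FluxTransformer2DModel.from_pretrained(
        FLUX_CHECKPOINT, subfolder="transformer", torch_dtype=DTYPE
        )
    quantize_(model, quant_config)        # quantize weights in place
    torch.save(model.state_dict(), path)  # restore later via load_state_dict(..., assign=True)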


def empty_cache():
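    """Run the garbage collector, release cached CUDA memory, and reset memory stats."""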
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()


def load_pipeline() -> FluxPipeline:
    empty_cache()
    
    pipe = FluxPipeline.from_pretrained(FLUX_CHECKPOINT,
                                        torch_dtype=DTYPE)
    
    # Use channels_last memory format (helps the conv-heavy VAE in particular).
    pipe.text_encoder_2.to(memory_format=torch.channels_last)
    pipe.transformer.to(memory_format=torch.channels_last)
    pipe.vae.to(memory_format=torch.channels_last)
    pipe.vae = torch.compile(pipe.vae)

    # Keep the VAE resident on the GPU; the remaining components are offloaded sequentially.
    pipe._exclude_from_cpu_offload = ["vae"]

    pipe.enable_sequential_cpu_offload()

    empty_cache()
    # Warm-up call so the compiled VAE and the offloading hooks run once
    # before the first real request.
    pipe("cat", guidance_scale=0., max_sequence_length=256, num_inference_steps=4)
    return pipe

@torch.inference_mode()
def infer(request: TextToImageRequest, _pipeline: FluxPipeline) -> Image:
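    """Generate one image for `request`, seeding the CUDA generator when a seed is provided."""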
    if request.seed is None:
        generator = None
    else:
        generator = Generator(device="cuda").manual_seed(request.seed)

    empty_cache()
    image = _pipeline(prompt=request.prompt,
                      width=request.width,
                      height=request.height,
                      guidance_scale=0.0,
                      generator=generator,
                      output_type="pil",
                      max_sequence_length=256,
                      num_inference_steps=NUM_STEPS).images[0]
    return image
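

# Hypothetical local driver, not part of the serving pipeline. It assumes
# `TextToImageRequest` accepts these fields as keyword arguments and exists
# only to illustrate how `load_pipeline` and `infer` fit together.
if __name__ == "__main__":
    pipeline = load_pipeline()
    start = perf_counter()
    image = infer(
        TextToImageRequest(prompt="a cat wearing sunglasses", width=1024, height=1024, seed=0),
        pipeline,
    )
    print(f"Generated in {perf_counter() - start:.2f}s")
    image.save("output.png")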