#6
from huggingface_hub.constants import HF_HUB_CACHE
import os

import torch
import torch._dynamo
from torch import Generator

from diffusers import FluxPipeline, FluxTransformer2DModel
from PIL.Image import Image
from torchao.quantization import quantize_, int8_weight_only

from pipelines.models import TextToImageRequest

# Let the CUDA caching allocator grow segments instead of fragmenting memory,
# and silence the tokenizers fork warning.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "True"
# Fall back to eager execution if torch.compile hits an unsupported op.
torch._dynamo.config.suppress_errors = True

# Concrete type returned by load_pipeline(), used in the annotations below.
Pipeline = FluxPipeline

# Checkpoint (and pinned revision) of the FLUX.1-schnell weights to load.
ckpt_id = "manbeast3b/flux.1-schnell-full1"
ckpt_revision = "cb1b599b0d712b9aab2c4df3ad27b050a27ec146"


def load_pipeline() -> Pipeline:
    # Load the transformer in bf16 straight from the local HF cache snapshot.
    path = os.path.join(
        HF_HUB_CACHE,
        "models--manbeast3b--flux.1-schnell-full1/snapshots/cb1b599b0d712b9aab2c4df3ad27b050a27ec146/transformer",
    )
    transformer = FluxTransformer2DModel.from_pretrained(path, torch_dtype=torch.bfloat16, use_safetensors=False)
    pipeline = FluxPipeline.from_pretrained(
        ckpt_id,
        revision=ckpt_revision,
        transformer=transformer,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )

    # Compile the transformer: mode="max-autotune" tunes kernels during the
    # warm-up calls below, and fullgraph=True requires a single graph capture.
    pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
    # basepath = os.path.join(HF_HUB_CACHE, "models--manbeast3b--Flux.1.schnell_eagle5_1_0.1_unst_7_2k/snapshots/b7a5ce1313327009093d3178220267d0cf669b76")
    # basepath = os.path.join(HF_HUB_CACHE, "models--manbeast3b--Flux.1.schnell_eagle5_1_0.1_unst_8/snapshots/3666a458a53e7dc83adfecb0bf955a0b4d575843")
    # basepath = os.path.join(HF_HUB_CACHE, "models--manbeast3b--Flux.1.schnell_eagle5_1_0.1_unst_13/snapshots/b3bdda899cd1961ec9b97bffde3ded31afa73ce3")
    # basepath = os.path.join(HF_HUB_CACHE, "models--manbeast3b--Flux.1.schnell_eagle5_1_0.1_unst_10/snapshots/20e4cf6ce3cc658237dfd6aae1d5f14bc6b3d1a4")
    basepath = os.path.join(HF_HUB_CACHE, "models--manbeast3b--Flux.1.schnell_eagle5_1_0.1_unst_8_try/snapshots/0d3ce1d07195ccfe8eafe821ee80b34d74a3c2d7")
    # Swap in replacement VAE encoder/decoder weights from a separate snapshot
    # (strict=False tolerates keys missing from the .pth files), then quantize
    # the VAE to int8 weight-only.
    pipeline.vae.encoder.load_state_dict(torch.load(os.path.join(basepath, "encoder.pth")), strict=False)
    pipeline.vae.decoder.load_state_dict(torch.load(os.path.join(basepath, "decoder.pth")), strict=False)
    quantize_(pipeline.vae, int8_weight_only())
    pipeline.to("cuda")
    # Warm-up passes: trigger torch.compile autotuning and CUDA allocator
    # growth before the first real request arrives.
    for _ in range(3):
        pipeline(
            prompt="insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline

@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    # Seed a per-request generator on the pipeline's device for reproducibility.
    generator = Generator(pipeline.device).manual_seed(request.seed)

    return pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    ).images[0]
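

if __name__ == "__main__":
    # Minimal local smoke test (a sketch, not part of the serving path).
    # The TextToImageRequest fields used here (prompt, seed, width, height)
    # are assumed from how infer() reads them above.
    pipe = load_pipeline()
    sample_request = TextToImageRequest(
        prompt="a watercolor painting of a lighthouse at dawn",
        seed=0,
        width=1024,
        height=1024,
    )
    infer(sample_request, pipe).save("sample.png")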