File size: 3,421 Bytes
2a95fa7
 
 
 
 
 
7b00150
2a95fa7
 
 
 
7b00150
2a95fa7
 
7b00150
2a95fa7
7b00150
 
 
 
2a95fa7
7b00150
 
 
2a95fa7
 
 
 
 
 
 
 
7b00150
2a95fa7
 
 
 
7b00150
19cacc0
2a95fa7
 
 
 
 
 
7b00150
2a95fa7
6acd771
18611d2
7b00150
c52db22
b32e748
7b00150
 
 
 
 
2a95fa7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b00150
2a95fa7
 
 
 
 
 
7b00150
2a95fa7
 
7b00150
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import gc
import os
from typing import TypeAlias

import torch
from PIL.Image import Image
from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, AutoencoderTiny, DiffusionPipeline
from huggingface_hub.constants import HF_HUB_CACHE
from pipelines.models import TextToImageRequest
from torch import Generator
from torchao.quantization import quantize_, int8_weight_only
from transformers import T5EncoderModel, CLIPTextModel, logging 

Pipeline: TypeAlias = FluxPipeline

# Process-wide environment. NOTE(review): PYTORCH_CUDA_ALLOC_CONF is presumably
# read when the CUDA caching allocator is first initialised, so it has to be set
# here at import time, before any CUDA work happens.
os.environ.update(
    {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "TOKENIZERS_PARALLELISM": "True",
    }
)

# Backend tuning flags for the torch.compile(mode="max-autotune") calls made
# when the pipeline is loaded.
torch.backends.cudnn.benchmark = True
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.coordinate_descent_check_all_directions = True

# Hub repository (and pinned commit) holding the int8 weight-only Flux-schnell weights.
CHECKPOINT = "jokerbit/flux.1-schnell-Robert-int8wo"
REVISION = "5ef0012f11a863e5111ec56540302a023bc8587b"

# Tiny autoencoder repo and pinned commit (not referenced by the visible code).
TinyVAE = "madebyollin/taef1"
TinyVAE_REV = "2d552378e58c9c94201075708d7de4e1163b2689"


def load_pipeline() -> Pipeline:
    """Load the int8 Flux-schnell pipeline from the local HF cache, compile it, and warm it up.

    The transformer is loaded directly from the cached snapshot directory of
    ``CHECKPOINT`` at ``REVISION``. The remaining components come from the same
    checkpoint via ``FluxPipeline.from_pretrained``. Both loads use
    ``local_files_only=True``, so the snapshot must already exist under
    ``HF_HUB_CACHE``.

    Returns:
        A CUDA-resident ``FluxPipeline`` whose transformer and VAE have been
        wrapped with ``torch.compile`` and exercised by several warm-up calls.
    """
    # Hub cache layout is models--<org>--<name>/snapshots/<revision>/. Build the
    # path from the module constants so it cannot drift out of sync with them.
    repo_dir = "models--" + CHECKPOINT.replace("/", "--")
    path = os.path.join(HF_HUB_CACHE, repo_dir, "snapshots", REVISION, "transformer")
    transformer = FluxTransformer2DModel.from_pretrained(
        path,
        # NOTE(review): presumably the int8 checkpoint is not stored as
        # safetensors (e.g. a pickled .bin); confirm before changing.
        use_safetensors=False,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )

    pipeline = FluxPipeline.from_pretrained(
        CHECKPOINT,
        revision=REVISION,
        transformer=transformer,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    # DiffusionPipeline.to() only parses device/dtype and silently ignores
    # memory_format, so calling it on the pipeline was a no-op. nn.Module.to()
    # does honour memory_format, so apply it to the conv-based VAE directly.
    # Do this before quantize_() so the conversion runs on plain tensors.
    pipeline.vae.to(memory_format=torch.channels_last)

    pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
    quantize_(pipeline.vae, int8_weight_only())
    pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune", fullgraph=True)

    # torch.compile compiles lazily on first call. Run a few full generations
    # here so compilation/autotuning is not paid by the first real request.
    warmup_prompt = "onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness"
    with torch.inference_mode():
        for _ in range(4):
            pipeline(
                prompt=warmup_prompt,
                width=1024,
                height=1024,
                guidance_scale=0.0,
                num_inference_steps=4,
                max_sequence_length=256,
            )

    # Drop unreachable Python objects first so their CUDA blocks can actually
    # be returned by empty_cache().
    gc.collect()
    torch.cuda.empty_cache()
    return pipeline

@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline, generator: torch.Generator) -> Image:
    """Generate one image for ``request`` and return the first output image.

    Sampling uses a fixed guidance-free, 4-step configuration with a
    256-token text limit. ``request.height`` and ``request.width`` are passed
    through unchanged (presumably ``None`` selects the pipeline's default size
    — confirm against the diffusers version in use). Seeding is left to the
    caller: ``generator`` is forwarded as given.
    """
    sampling_kwargs = dict(
        num_inference_steps=4,
        guidance_scale=0.0,
        max_sequence_length=256,
    )
    result = pipeline(
        request.prompt,
        width=request.width,
        height=request.height,
        generator=generator,
        **sampling_kwargs,
    )
    return result.images[0]


if __name__ == "__main__":
    # Ad-hoc benchmark: time pipeline loading, then four seeded requests.
    from time import perf_counter

    PROMPT = 'martyr, semiconformity, peregrination, quip, twineless, emotionless, tawa, depickle'
    request = TextToImageRequest(prompt=PROMPT, height=None, width=None, seed=666)
    generator = torch.Generator(device="cuda")

    t0 = perf_counter()
    pipe = load_pipeline()
    t1 = perf_counter()
    print(f"Pipeline is loaded in {t1 - t0}s")

    for _attempt in range(4):
        t0 = perf_counter()
        # Re-seed on every run so each timed request is identical.
        seeded = generator.manual_seed(request.seed)
        infer(request, pipe, generator=seeded)
        t1 = perf_counter()
        print(f"Request in {t1 - t0}s")