import os
from typing import TypeAlias

import torch
import torch._dynamo
from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderTiny
from huggingface_hub.constants import HF_HUB_CACHE
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from torchao.quantization import quantize_, int8_weight_only

# Let torch.compile fall back to eager execution instead of raising on graph breaks.
torch._dynamo.config.suppress_errors = True

Pipeline: TypeAlias = FluxPipeline

# cuDNN autotuning plus Inductor flags tuned for this diffusion workload.
torch.backends.cudnn.benchmark = True
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True


os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "True"

# int8 weight-only quantized Flux.1-schnell checkpoint, pinned to a specific revision.
CHECKPOINT = "jokerbit/flux.1-schnell-Robert-int8wo"
REVISION = "5ef0012f11a863e5111ec56540302a023bc8587b"

# Tiny autoencoder (TAESD for Flux) used as a faster drop-in VAE.
TinyVAE = "madebyollin/taef1"
TinyVAE_REV = "2d552378e58c9c94201075708d7de4e1163b2689"


def load_pipeline() -> Pipeline:
    # Load the pre-quantized transformer directly from the local Hugging Face cache.
    path = os.path.join(
        HF_HUB_CACHE,
        "models--jokerbit--flux.1-schnell-Robert-int8wo/snapshots/5ef0012f11a863e5111ec56540302a023bc8587b/transformer",
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        path,
        use_safetensors=False,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )
    # The tiny autoencoder replaces the full Flux VAE for faster decoding.
    vae = AutoencoderTiny.from_pretrained(
        TinyVAE,
        revision=TinyVAE_REV,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )

    pipeline = FluxPipeline.from_pretrained(
        CHECKPOINT,
        revision=REVISION,
        transformer=transformer,
        vae=vae,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    pipeline.to(memory_format=torch.channels_last)
    pipeline.transformer = torch.compile(pipeline.transformer)
    quantize_(pipeline.vae, int8_weight_only())
    # pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune")
    # pipeline.set_progress_bar_config(disable=True)

    # Warm-up passes trigger torch.compile and autotuning before the first real request.
    with torch.no_grad():
        for _ in range(4):
            pipeline(
                prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness",
                width=1024,
                height=1024,
                guidance_scale=0.0,
                num_inference_steps=4,
                max_sequence_length=256,
            )
    torch.cuda.empty_cache()
    return pipeline

@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline, generator: torch.Generator) -> Image:
    # Single 4-step, guidance-free Flux.1-schnell generation for one request.
    return pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    ).images[0]


if __name__ == "__main__":
    from time import perf_counter

    # Simple local benchmark: time pipeline loading, then four seeded generations.
    PROMPT = "martyr, semiconformity, peregrination, quip, twineless, emotionless, tawa, depickle"
    request = TextToImageRequest(
        prompt=PROMPT,
        height=None,
        width=None,
        seed=666,
    )
    generator = torch.Generator(device="cuda")

    start_time = perf_counter()
    pipe_ = load_pipeline()
    stop_time = perf_counter()
    print(f"Pipeline loaded in {stop_time - start_time}s")

    for _ in range(4):
        start_time = perf_counter()
        infer(request, pipe_, generator=generator.manual_seed(request.seed))
        stop_time = perf_counter()
        print(f"Request completed in {stop_time - start_time}s")