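"""Text-to-image serving module for FLUX.1-schnell.

Loads an int8 weight-only quantized FluxTransformer2DModel from a local
Hugging Face cache snapshot, assembles a FluxPipeline on CUDA, warms it up,
and exposes `load_pipeline` / `infer` for request handling.
"""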
import gc
import os
from typing import TypeAlias

import torch
from PIL.Image import Image
from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, AutoencoderTiny, DiffusionPipeline
from huggingface_hub.constants import HF_HUB_CACHE
from pipelines.models import TextToImageRequest
from torch import Generator
from torchao.quantization import quantize_, int8_weight_only
from transformers import T5EncoderModel, CLIPTextModel, logging
import torch._dynamo
# torch._dynamo.config.suppress_errors = True

Pipeline: TypeAlias = FluxPipeline

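# Kernel-tuning knobs: cuDNN autotuning applies globally, while the
# torch._inductor flags only affect modules that are later run through
# torch.compile (not done explicitly in this file). A hedged sketch of how the
# transformer could be compiled to take advantage of them (assumption, not part
# of the original flow):
#     pipeline.transformer = torch.compile(
#         pipeline.transformer, mode="max-autotune", fullgraph=True
#     )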
torch.backends.cudnn.benchmark = True
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True


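# Allocator / tokenizer environment: expandable CUDA allocator segments reduce
# fragmentation, and tokenizer parallelism speeds up prompt encoding.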
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "True"
CHECKPOINT = "jokerbit/flux.1-schnell-Robert-int8wo"
REVISION = "5ef0012f11a863e5111ec56540302a023bc8587b"

TinyVAE = "madebyollin/taef1"
TinyVAE_REV = "2d552378e58c9c94201075708d7de4e1163b2689"
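# NOTE: the TinyVAE checkpoint is declared but never attached below. A minimal
# sketch of swapping it in (assumes the taef1 weights are compatible with this
# pipeline's latent space):
#     pipeline.vae = AutoencoderTiny.from_pretrained(
#         TinyVAE, revision=TinyVAE_REV, torch_dtype=torch.bfloat16
#     ).to("cuda")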


def load_pipeline() -> Pipeline:
    path = os.path.join(HF_HUB_CACHE, "models--jokerbit--flux.1-schnell-Robert-int8wo/snapshots/5ef0012f11a863e5111ec56540302a023bc8587b/transformer")
    # int8 weight-only quantized transformer (torchao), stored as a pickled
    # checkpoint, hence use_safetensors=False.
    transformer = FluxTransformer2DModel.from_pretrained(
        path,
        use_safetensors=False,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )

    pipeline = FluxPipeline.from_pretrained(
        CHECKPOINT,
        revision=REVISION,
        transformer=transformer,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    pipeline.to(memory_format=torch.channels_last)
    # pipeline.transformer.to(memory_format=torch.channels_last)
    # pipeline.vae.to(memory_format=torch.channels_last)
 
    # quantize_(pipeline.vae, int8_weight_only())
    
    # Warm up kernels/autotuning with a few throwaway generations so the first
    # real request does not pay the tuning cost.
    with torch.inference_mode():
        for _ in range(4):
            pipeline(prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
    torch.cuda.empty_cache()
    return pipeline

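# Handle one request: 4-step schnell sampling with guidance disabled; the caller
# supplies a seeded CUDA generator for reproducibility.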
@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline, generator: torch.Generator) -> Image:
    return pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    ).images[0]


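# Local smoke test: time the pipeline load, then four back-to-back generations
# with a fixed seed.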
if __name__ == "__main__":
    from time import perf_counter
    PROMPT = "martyr, semiconformity, peregrination, quip, twineless, emotionless, tawa, depickle"
    request = TextToImageRequest(
        prompt=PROMPT,
        height=None,
        width=None,
        seed=666,
    )
    generator = torch.Generator(device="cuda")
    start_time = perf_counter()
    pipe_ = load_pipeline()
    stop_time = perf_counter()
    print(f"Pipeline is loaded in {stop_time - start_time}s")
    for _ in range(4):
        start_time = perf_counter()
        infer(request, pipe_, generator=generator.manual_seed(request.seed))
        stop_time = perf_counter()
        print(f"Request in {stop_time - start_time}s")