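"""Serving script for an int8 weight-only quantized FLUX.1-schnell text-to-image pipeline.

The quantized transformer is read from the local Hugging Face cache, the VAE is
quantized and compiled, the pipeline is warmed up, and infer() then produces
4-step generations for incoming TextToImageRequest objects.
"""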
import os
from typing import TypeAlias

import torch
from PIL.Image import Image
from diffusers import FluxPipeline, FluxTransformer2DModel
from huggingface_hub.constants import HF_HUB_CACHE
from pipelines.models import TextToImageRequest
from torchao.quantization import quantize_, int8_weight_only


Pipeline: TypeAlias = FluxPipeline

# Let cuDNN autotune kernels for the fixed image sizes used here.
torch.backends.cudnn.benchmark = True

# Reduce CUDA allocator fragmentation and keep tokenizer parallelism on.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "True"

# int8 weight-only quantized FLUX.1-schnell checkpoint, pinned to a fixed revision.
CHECKPOINT = "jokerbit/flux.1-schnell-Robert-int8wo"
REVISION = "5ef0012f11a863e5111ec56540302a023bc8587b"

# Tiny VAE (taef1) reference; declared here but not loaded by this script.
TinyVAE = "madebyollin/taef1"
TinyVAE_REV = "2d552378e58c9c94201075708d7de4e1163b2689"


def load_pipeline() -> Pipeline:
    # Load the quantized transformer directly from the local HF Hub cache snapshot.
    path = os.path.join(
        HF_HUB_CACHE,
        f"models--jokerbit--flux.1-schnell-Robert-int8wo/snapshots/{REVISION}/transformer",
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        path,
        use_safetensors=False,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )

    pipeline = FluxPipeline.from_pretrained(
        CHECKPOINT,
        revision=REVISION,
        transformer=transformer,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    # int8 weight-only quantization of the VAE (the transformer checkpoint is already quantized).
    quantize_(pipeline.vae, int8_weight_only())
    # torch.compile returns the compiled module; assign it back so the compiled VAE is actually used.
    pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune-no-cudagraphs", fullgraph=True)
    pipeline.transformer.to(memory_format=torch.channels_last)

    # Warm-up passes to trigger compilation/autotuning before serving real requests.
    warmup_prompt = "onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness"
    with torch.inference_mode():
        for _ in range(4):
            pipeline(prompt=warmup_prompt, width=1024, height=1024, guidance_scale=0.0,
                     num_inference_steps=4, max_sequence_length=256)
    torch.cuda.empty_cache()
    return pipeline

@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline, generator: torch.Generator) -> Image:
    # Single 4-step FLUX.1-schnell generation; classifier-free guidance is disabled (scale 0.0).
    return pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    ).images[0]


if __name__ == "__main__":
    from time import perf_counter
    PROMPT = "martyr, semiconformity, peregrination, quip, twineless, emotionless, tawa, depickle"
    request = TextToImageRequest(prompt=PROMPT, height=None, width=None, seed=666)
    generator = torch.Generator(device="cuda")
    start_time = perf_counter()
    pipe_ = load_pipeline()
    stop_time = perf_counter()
    print(f"Pipeline is loaded in {stop_time - start_time}s")
    for _ in range(4):
        start_time = perf_counter()
        infer(request, pipe_, generator=generator.manual_seed(request.seed))
        stop_time = perf_counter()
        print(f"Request in {stop_time - start_time}s")