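"""Optimized Flux.1-schnell text-to-image inference pipeline.

Loads an int8 weight-only quantized transformer from the local Hugging Face cache,
compiles the transformer and VAE with torch.compile, and runs warm-up passes so
autotuning happens at load time rather than on the first request.
"""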
import gc
import os
from typing import TypeAlias
import torch
from PIL.Image import Image
from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, AutoencoderTiny, DiffusionPipeline
from huggingface_hub.constants import HF_HUB_CACHE
from pipelines.models import TextToImageRequest
from torch import Generator
from torchao.quantization import quantize_, int8_weight_only
from transformers import T5EncoderModel, CLIPTextModel, logging

Pipeline: TypeAlias = FluxPipeline
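
# cuDNN benchmarking picks the fastest convolution algorithms for fixed input shapes;
# the TorchInductor flags below tune torch.compile codegen (lower 1x1 convs to matmul,
# coordinate-descent kernel autotuning in all directions, epilogue fusion disabled).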
torch.backends.cudnn.benchmark = True
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True
# Expandable allocator segments reduce CUDA memory fragmentation across differently sized allocations.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "True"
CHECKPOINT = "jokerbit/flux.1-schnell-Robert-int8wo"
REVISION = "5ef0012f11a863e5111ec56540302a023bc8587b"
TinyVAE = "madebyollin/taef1"
TinyVAE_REV = "2d552378e58c9c94201075708d7de4e1163b2689"


def load_pipeline() -> Pipeline:
    """Build, quantize, compile, and warm up the FluxPipeline on CUDA."""
    # Load the int8 weight-only quantized transformer straight from the local HF cache.
    path = os.path.join(
        HF_HUB_CACHE,
        "models--jokerbit--flux.1-schnell-Robert-int8wo/snapshots/5ef0012f11a863e5111ec56540302a023bc8587b/transformer",
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        path,
        use_safetensors=False,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )
    pipeline = FluxPipeline.from_pretrained(
        CHECKPOINT,
        revision=REVISION,
        transformer=transformer,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    pipeline.to(memory_format=torch.channels_last)

    # Compile the transformer, quantize the VAE to int8 weights, and compile it as well.
    pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
    quantize_(pipeline.vae, int8_weight_only())
    pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune", fullgraph=True)

    PROMPT = 'semiconformity, peregrination, quip, twineless, emotionless, tawa, depickle'  # unused; the warm-up call below passes its own prompt

    # Warm-up runs trigger torch.compile autotuning at load time instead of on the first request.
    with torch.inference_mode():
        for _ in range(4):
            pipeline(
                prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness",
                width=1024,
                height=1024,
                guidance_scale=0.0,
                num_inference_steps=4,
                max_sequence_length=256,
            )
    torch.cuda.empty_cache()
    return pipeline


@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline, generator: torch.Generator) -> Image:
    """Generate one image for the request with Flux schnell settings: 4 steps, guidance disabled."""
    return pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    ).images[0]


if __name__ == "__main__":
    from time import perf_counter

    # Simple local benchmark: load the pipeline once, then time four inference requests.
    PROMPT = 'martyr, semiconformity, peregrination, quip, twineless, emotionless, tawa, depickle'
    request = TextToImageRequest(prompt=PROMPT, height=None, width=None, seed=666)
    generator = torch.Generator(device="cuda")

    start_time = perf_counter()
    pipe_ = load_pipeline()
    stop_time = perf_counter()
    print(f"Pipeline loaded in {stop_time - start_time}s")

    for _ in range(4):
        start_time = perf_counter()
        infer(request, pipe_, generator=generator.manual_seed(request.seed))
        stop_time = perf_counter()
        print(f"Request in {stop_time - start_time}s")