File size: 3,110 Bytes
de1bc64 945d0d7 de1bc64 560e4bf afb0eb8 560e4bf 8c32573 a6270be 8c32573 e13e27c 3a690f4 afb0eb8 de1bc64 1346e32 de1bc64 1346e32 de1bc64 e6fe518 d15e6e3 de1bc64 f0c9b97 69e3350 a060357 f0c9b97 fe28baa 8c32573 1346e32 1fb3c80 de1bc64 1346e32 de1bc64 8c32573 de1bc64 8c32573 de1bc64 560e4bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import gc
import os
from typing import TypeAlias
import torch
from PIL.Image import Image
from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, AutoencoderTiny, DiffusionPipeline
from huggingface_hub.constants import HF_HUB_CACHE
from pipelines.models import TextToImageRequest
from torch import Generator
from torchao.quantization import quantize_, int8_weight_only
from transformers import T5EncoderModel, CLIPTextModel, logging
# Process-wide runtime settings, exported through environment variables.
os.environ.update(
    {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "TOKENIZERS_PARALLELISM": "True",
    }
)
# Turn on the cuDNN benchmark flag.
torch.backends.cudnn.benchmark = True

# Alias for the concrete pipeline class this module loads and runs.
Pipeline: TypeAlias = FluxPipeline

# Hub repository and pinned revision of the Flux checkpoint.
CHECKPOINT = "jokerbit/flux.1-schnell-Robert-int8wo"
REVISION = "5ef0012f11a863e5111ec56540302a023bc8587b"

# Hub repository and pinned revision of a tiny VAE; not referenced in the
# visible part of this file.
TinyVAE = "madebyollin/taef1"
TinyVAE_REV = "2d552378e58c9c94201075708d7de4e1163b2689"
def load_pipeline() -> Pipeline:
    """Load the Flux pipeline onto CUDA, quantize its VAE and warm it up.

    The transformer is read from the ``transformer`` sub-folder of the local
    Hugging Face cache snapshot for ``CHECKPOINT`` at ``REVISION``; the rest
    of the pipeline is loaded from the same checkpoint and revision. Both
    loads pass ``local_files_only=True``.

    Returns:
        The pipeline after it has been moved to ``"cuda"`` and run for four
        warm-up generations at 1024x1024.
    """
    # Build the snapshot path from the module constants instead of repeating
    # the repository id and revision as a second hard-coded string.
    transformer_dir = os.path.join(
        HF_HUB_CACHE,
        "models--" + CHECKPOINT.replace("/", "--"),
        "snapshots",
        REVISION,
        "transformer",
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        transformer_dir,
        use_safetensors=False,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )
    pipeline = FluxPipeline.from_pretrained(
        CHECKPOINT,
        revision=REVISION,
        transformer=transformer,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    # Apply int8 weight-only quantization to the VAE in place.
    quantize_(pipeline.vae, int8_weight_only())
    # NOTE(review): the object returned by torch.compile is not assigned
    # anywhere, so the pipeline keeps using the original `pipeline.vae`.
    # Confirm whether the compiled VAE was meant to replace it.
    torch.compile(pipeline.vae, mode="max-autotune-no-cudagraphs", fullgraph=True)
    pipeline.transformer.to(memory_format=torch.channels_last)

    # Warm-up: run a few full generations with the same step count, guidance
    # scale and sequence length that `infer` uses, discarding the outputs.
    with torch.inference_mode():
        for _ in range(4):
            pipeline(
                prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness",
                width=1024,
                height=1024,
                guidance_scale=0.0,
                num_inference_steps=4,
                max_sequence_length=256,
            )
    torch.cuda.empty_cache()
    return pipeline
@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline, generator: torch.Generator) -> Image:
    """Run one 4-step generation for *request* and return the first image.

    Args:
        request: Supplies ``prompt``, ``height`` and ``width`` for the call.
        pipeline: Callable pipeline, invoked with the prompt positionally.
        generator: Random generator forwarded to the pipeline unchanged.

    Returns:
        The first element of the ``images`` attribute of the pipeline output.
    """
    output = pipeline(
        request.prompt,
        height=request.height,
        width=request.width,
        num_inference_steps=4,
        guidance_scale=0.0,
        max_sequence_length=256,
        generator=generator,
    )
    return output.images[0]
if __name__ == "__main__":
    # Ad-hoc benchmark: time the pipeline load, then four identical requests.
    from time import perf_counter

    PROMPT = 'martyr, semiconformity, peregrination, quip, twineless, emotionless, tawa, depickle'
    # Height and width are passed as None here.
    request = TextToImageRequest(prompt=PROMPT, height=None, width=None, seed=666)
    generator = torch.Generator(device="cuda")

    load_began = perf_counter()
    pipe_ = load_pipeline()
    load_elapsed = perf_counter() - load_began
    print(f"Pipeline is loaded in {load_elapsed}s")

    for _ in range(4):
        request_began = perf_counter()
        # Re-seed before every call so each run starts from the same seed.
        infer(request, pipe_, generator=generator.manual_seed(request.seed))
        request_elapsed = perf_counter() - request_began
        print(f"Request in {request_elapsed}s")
|