from diffusers import FluxPipeline, AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
import torch
import gc
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from torch import Generator
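
# Serving entry points are load_pipeline() and infer(); the surrounding harness
# is inferred from the pipelines.models import. No pipeline object is kept
# resident (Pipeline stays None). Each request runs in three stages: prompt
# encoding, latent denoising, and VAE decoding. Every stage loads only the
# weights it needs and frees them before the next one, keeping peak GPU memory low.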

Pipeline = None

CHECKPOINT = "black-forest-labs/FLUX.1-schnell"

def empty_cache():
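    # Release Python objects and cached CUDA blocks between stages; the memory
    # stat resets presumably keep per-stage peak-VRAM reporting accurate.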
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()

def load_pipeline() -> Pipeline:
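    # Warm-up: run one dummy request end to end so model weights are downloaded
    # and CUDA is initialized before the first real request arrives.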
    infer(TextToImageRequest(prompt=""), Pipeline)

    return Pipeline


def encode_prompt(prompt: str):
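    # Stage 1: build a FluxPipeline holding only the tokenizers and text encoders
    # (transformer and VAE omitted); returns (prompt_embeds, pooled_prompt_embeds, text_ids).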
    text_encoder = CLIPTextModel.from_pretrained(
        CHECKPOINT,
        subfolder="text_encoder",
        torch_dtype=torch.bfloat16,
    )

    text_encoder_2 = T5EncoderModel.from_pretrained(
        CHECKPOINT,
        subfolder="text_encoder_2",
        torch_dtype=torch.bfloat16,
    )

    tokenizer = CLIPTokenizer.from_pretrained(CHECKPOINT, subfolder="tokenizer")
    tokenizer_2 = T5TokenizerFast.from_pretrained(CHECKPOINT, subfolder="tokenizer_2")

    pipeline = FluxPipeline.from_pretrained(
        CHECKPOINT,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        transformer=None,
        vae=None,
    ).to("cuda")

    with torch.no_grad():
        return pipeline.encode_prompt(
            prompt=prompt,
            prompt_2=None,
            max_sequence_length=256,
        )


def infer_latents(prompt_embeds, pooled_prompt_embeds, width: int | None, height: int | None, seed: int | None):
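    # Stage 2: load only the transformer (text encoders and VAE omitted) and run
    # the short, guidance-free denoising schedule FLUX.1-schnell is distilled for,
    # returning packed latents instead of decoded images.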
    pipeline = FluxPipeline.from_pretrained(
        CHECKPOINT,
        text_encoder=None,
        text_encoder_2=None,
        tokenizer=None,
        tokenizer_2=None,
        vae=None,
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    if seed is None:
        generator = None
    else:
        generator = Generator(pipeline.device).manual_seed(seed)

    return pipeline(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        num_inference_steps=4,
        guidance_scale=0.0,
        width=width,
        height=height,
        generator=generator,
        output_type="latent",
    ).images


def infer(request: TextToImageRequest, _pipeline: Pipeline) -> Image:
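    # Per-request entry point: encode the prompt, denoise to latents, then decode
    # with the VAE, clearing GPU memory between stages.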
    empty_cache()

    prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(request.prompt)

    empty_cache()

    latents = infer_latents(prompt_embeds, pooled_prompt_embeds, request.width, request.height, request.seed)

    empty_cache()
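
    # Stage 3: load only the VAE and mirror FluxPipeline's internal decode path:
    # unpack the 2x2 latent patches (note the reliance on the private
    # _unpack_latents helper, which ties this to the installed diffusers version),
    # undo the VAE scaling/shift, then decode to a PIL image. 64 * vae_scale_factor
    # is FLUX's default edge length when the request omits a size.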

    vae = AutoencoderKL.from_pretrained(
        CHECKPOINT,
        subfolder="vae",
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

    height = request.height or 64 * vae_scale_factor
    width = request.width or 64 * vae_scale_factor

    with torch.no_grad():
        latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
        latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor

        image = vae.decode(latents, return_dict=False)[0]
        return image_processor.postprocess(image, output_type="pil")[0]
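

# Minimal local smoke test; not part of the serving entry points. This sketch
# assumes TextToImageRequest accepts `prompt` plus optional `width`, `height`,
# and `seed` fields, as the attribute accesses above suggest; adjust to the
# actual pipelines.models schema if it differs.
if __name__ == "__main__":
    image = infer(TextToImageRequest(prompt="a red fox in fresh snow", seed=0), Pipeline)
    image.save("sample.png")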