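# Text-to-image service for FLUX.1-schnell: loads a Flux pipeline with a tiny VAE
# and a bf16 T5 text encoder, warms it up, and serves seeded generation requests.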
import gc
import os

import torch
from diffusers import AutoencoderTiny, DiffusionPipeline, FluxTransformer2DModel
from PIL.Image import Image
from torch import Generator
from transformers import T5EncoderModel

from pipelines.models import TextToImageRequest


# Let the CUDA caching allocator grow segments instead of fragmenting memory.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
HOME = os.environ["HOME"]

# Placeholder annotation for the pipeline object returned by load_pipeline().
Pipeline = None
ckpt_id = "black-forest-labs/FLUX.1-schnell"

def empty_cache():
    """Release cached CUDA memory and reset allocator statistics between runs."""
    gc.collect()
    torch.cuda.empty_cache()
    # reset_peak_memory_stats() also covers the deprecated reset_max_memory_allocated().
    torch.cuda.reset_peak_memory_stats()

def load_pipeline() -> Pipeline:
    empty_cache()
    # The tiny autoencoder (TAEF1) stands in for the full Flux VAE to cut decode memory.
    vae = AutoencoderTiny.from_pretrained("aifeifei798/taef1", torch_dtype=torch.bfloat16)
    # Transformer weights are read from a locally cached Hugging Face snapshot.
    model = FluxTransformer2DModel.from_pretrained(f"{HOME}/.cache/huggingface/hub/models--slobers--transgender/snapshots/cb99836efa0ed55856970269c42fafdaa0e44c5d", torch_dtype=torch.bfloat16, use_safetensors=False)
    text_encoder_2 = T5EncoderModel.from_pretrained("city96/t5-v1_1-xxl-encoder-bf16", torch_dtype=torch.bfloat16)
    pipeline = DiffusionPipeline.from_pretrained(ckpt_id, vae=vae, transformer=model, text_encoder_2=text_encoder_2, torch_dtype=torch.bfloat16)
    pipeline.to("cuda")

    # Run two throwaway generations so CUDA kernels compile and caches settle before serving.
    for _ in range(2):
        empty_cache()
        pipeline(prompt="insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
    return pipeline

def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    empty_cache()
    # Seed a CUDA generator so each request is reproducible.
    generator = Generator("cuda").manual_seed(request.seed)
    image = pipeline(request.prompt, generator=generator, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256, height=request.height, width=request.width, output_type="pil").images[0]
    return image
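

# Minimal usage sketch (assumes TextToImageRequest accepts prompt/seed/height/width
# keyword arguments, matching the fields accessed above; the actual constructor in
# pipelines.models may differ):
#
#     pipeline = load_pipeline()
#     request = TextToImageRequest(prompt="a watercolor fox", seed=0, height=1024, width=1024)
#     image = infer(request, pipeline)
#     image.save("output.png")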