import torch
from PIL.Image import Image
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler, AutoencoderTiny
from pipelines.models import TextToImageRequest
from torch import Generator
from sfast.compilers.diffusion_pipeline_compiler import (
    compile,
    CompilationConfig,
)



def load_pipeline() -> StableDiffusionXLPipeline:
    # Load the fp16 SDXL checkpoint from the local model directory.
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        "./models/newdream-sdxl-20/",
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
        local_files_only=True,
    )

    # Swap in the tiny autoencoder (TAESD for SDXL) to speed up latent decoding.
    pipeline.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16)

    # Multistep DPM-Solver keeps quality reasonable at low step counts.
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)

    # FreeU re-weights the UNet backbone and skip features to improve sample quality.
    pipeline.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
    pipeline.to("cuda")

    # Compile the pipeline with stable-fast, enabling xformers and Triton when available.
    config = CompilationConfig.Default()
    try:
        import xformers
        config.enable_xformers = True
    except ImportError:
        print('xformers not installed, skip')
    try:
        import triton
        config.enable_triton = True
    except ImportError:
        print('Triton not installed, skip')

    pipeline = compile(pipeline, config)

    # Warm-up passes: the first calls trigger tracing/compilation so later requests run fast.
    for _ in range(4):
        pipeline(prompt="", num_inference_steps=15)

    return pipeline

def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> Image:
    # Seed the generator only when a seed was supplied; compare against None so that
    # an explicit seed of 0 is still honored.
    generator = Generator(pipeline.device).manual_seed(request.seed) if request.seed is not None else None

    return pipeline(
        prompt=request.prompt,
        negative_prompt=request.negative_prompt,
        width=request.width,
        height=request.height,
        generator=generator,
        num_inference_steps=8,
    ).images[0]
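

if __name__ == "__main__":
    # Minimal local smoke test, not part of the serving path. The TextToImageRequest
    # keyword arguments below (prompt, negative_prompt, width, height, seed) are assumed
    # from how the fields are accessed in infer(); adjust them if the model in
    # pipelines.models defines a different constructor.
    pipeline = load_pipeline()
    request = TextToImageRequest(
        prompt="a watercolor painting of a lighthouse at sunrise",
        negative_prompt=None,
        width=1024,
        height=1024,
        seed=42,
    )
    image = infer(request, pipeline)
    image.save("sample.png")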