sdxl-optimized / src /pipeline.py
Ashley Wright
Load from folder instead of cached repository
7a19876
raw
history blame
831 Bytes
import torch
from PIL.Image import Image
from diffusers import StableDiffusionXLPipeline
from pipelines.models import TextToImageRequest
from torch import Generator
def load_pipeline(
    model_path: str = "./models/newdream-sdxl-20",
    device: str = "cuda",
) -> StableDiffusionXLPipeline:
    """Load the SDXL pipeline from a local folder, move it to ``device`` and warm it up.

    Args:
        model_path: Folder loadable by ``StableDiffusionXLPipeline.from_pretrained``.
            Defaults to ``"./models/newdream-sdxl-20"`` (the original hard-coded path).
        device: Device the pipeline is moved to. Defaults to ``"cuda"``.

    Returns:
        The loaded fp16 pipeline, already invoked once with an empty prompt.
    """
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        # Load strictly from disk; never attempt to download from the Hub.
        local_files_only=True,
    ).to(device)
    # Warm-up call with an empty prompt; the result is discarded.
    # Presumably this moves first-call overhead out of the first real request.
    pipeline(prompt="")
    return pipeline
def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> Image:
    """Generate one image for ``request`` using ``pipeline``.

    Args:
        request: Supplies ``prompt``, ``negative_prompt``, ``width``, ``height``
            and ``seed``. ``seed`` is assumed to be an int or None — TODO confirm
            against ``TextToImageRequest``.
        pipeline: A loaded SDXL pipeline, e.g. one returned by ``load_pipeline``.

    Returns:
        The first image in the pipeline output's ``images`` list.
    """
    # Compare against None explicitly: 0 is a valid seed and must still yield a
    # seeded (deterministic) generator. A plain truthiness check would drop it.
    generator = (
        Generator(pipeline.device).manual_seed(request.seed)
        if request.seed is not None
        else None
    )
    return pipeline(
        prompt=request.prompt,
        negative_prompt=request.negative_prompt,
        width=request.width,
        height=request.height,
        generator=generator,
    ).images[0]