# Sketcher / sd/utils/utils.py
# Last updated by Gainward777 in commit 5bdf5c2 ("Update sd/utils/utils.py").
import torch
from diffusers import (ControlNetModel,
StableDiffusionXLControlNetImg2ImgPipeline,
AutoencoderKL,
T2IAdapter,
StableDiffusionXLAdapterPipeline,
EulerAncestralDiscreteScheduler)
from controlnet_aux.pidi import PidiNetDetector
from PIL import Image
import os
def get_vae(model_name="madebyollin/sdxl-vae-fp16-fix"):
    """Load an ``AutoencoderKL`` (VAE) from ``model_name`` with float16 weights.

    Args:
        model_name: Identifier passed to ``AutoencoderKL.from_pretrained``.

    Returns:
        The loaded ``AutoencoderKL`` instance.
    """
    half_precision = torch.float16
    vae = AutoencoderKL.from_pretrained(model_name, torch_dtype=half_precision)
    return vae
def get_controlnet(model_name="diffusers/controlnet-canny-sdxl-1.0"):
    """Load a ``ControlNetModel`` from ``model_name`` with float16 weights.

    Args:
        model_name: Identifier passed to ``ControlNetModel.from_pretrained``.

    Returns:
        The loaded ``ControlNetModel`` instance.
    """
    half_precision = torch.float16
    controlnet = ControlNetModel.from_pretrained(model_name, torch_dtype=half_precision)
    return controlnet
def get_adapter(model_name="Adapter/t2iadapter", subfolder="sketch_sdxl_1.0",
                adapter_type="full_adapter_xl"):
    """Load a ``T2IAdapter`` from ``model_name``/``subfolder`` with float16 weights.

    Args:
        model_name: Identifier passed to ``T2IAdapter.from_pretrained``.
        subfolder: Subfolder of ``model_name`` holding the adapter weights.
        adapter_type: Adapter architecture; only ``"full_adapter_xl"`` is
            supported. It is also forwarded to ``from_pretrained``.

    Returns:
        The loaded ``T2IAdapter`` instance.

    Raises:
        ValueError: If ``adapter_type`` is not supported. Previously this
            silently returned ``None``, which callers mistook for "no adapter".
    """
    if adapter_type != "full_adapter_xl":
        raise ValueError(f"Unsupported adapter_type: {adapter_type!r}")
    return T2IAdapter.from_pretrained(model_name,
                                      subfolder=subfolder,
                                      torch_dtype=torch.float16,
                                      adapter_type=adapter_type)
def get_scheduler(model_name, scheduler_type="discrete"):
    """Load a scheduler from the ``scheduler`` subfolder of ``model_name``.

    Args:
        model_name: Identifier passed to the scheduler's ``from_pretrained``.
        scheduler_type: Scheduler to build; only ``"discrete"``
            (``EulerAncestralDiscreteScheduler``) is supported.

    Returns:
        The loaded scheduler instance.

    Raises:
        ValueError: If ``scheduler_type`` is not supported. Previously this
            silently returned ``None``.
    """
    if scheduler_type != "discrete":
        raise ValueError(f"Unsupported scheduler_type: {scheduler_type!r}")
    return EulerAncestralDiscreteScheduler.from_pretrained(model_name, subfolder="scheduler")
def get_detector(model_name="lllyasviel/Annotators", model_type='pidi'):
    """Load a ``controlnet_aux`` annotator from ``model_name``.

    Args:
        model_name: Identifier passed to the detector's ``from_pretrained``.
        model_type: Detector to build; only ``'pidi'`` (``PidiNetDetector``)
            is supported.

    Returns:
        The loaded detector instance.

    Raises:
        ValueError: If ``model_type`` is not supported. Previously this
            silently returned ``None``.
    """
    if model_type != 'pidi':
        raise ValueError(f"Unsupported model_type: {model_type!r}")
    return PidiNetDetector.from_pretrained(model_name)
def load_lora(pipe, lora_path=None):
    """Load LoRA weights from a local file into ``pipe`` (best-effort).

    Any exception raised while loading is printed and swallowed, so the
    pipeline stays usable without the LoRA.

    Args:
        pipe: Object exposing ``load_lora_weights(path, weight_name=...)``,
            e.g. a diffusers pipeline.
        lora_path: Path to the LoRA weight file, or ``None``/empty to skip.
            Relative paths are resolved against the current working directory;
            absolute paths are used as-is.

    Returns:
        None. ``pipe`` is modified in place.
    """
    if not lora_path:
        return
    directory, weight_name = os.path.split(lora_path)
    # Anchor relative directories at "./" (as before) so they are treated as
    # local paths. os.path.join drops the "." when `directory` is absolute,
    # which fixes the old string concatenation that turned "/abs/dir" into
    # ".//abs/dir" (a relative path).
    lora_dir = os.path.join(".", directory)
    try:
        pipe.load_lora_weights(lora_dir, weight_name=weight_name)
    except Exception as ex:  # deliberate best-effort: report and continue without the LoRA
        print(f"Failed to load LoRA weights from {lora_path!r}: {ex}")
def get_pipe(vae, model_name, controlnet=None, adapter=None, scheduler=None, lora_path=None):
    """Build an SDXL pipeline (ControlNet img2img or T2I-Adapter) in float16.

    If ``controlnet`` is given, a ``StableDiffusionXLControlNetImg2ImgPipeline``
    is built and ``adapter`` is ignored. Otherwise, if ``adapter`` is given, a
    ``StableDiffusionXLAdapterPipeline`` is built from the ``fp16`` variant.
    LoRA weights are then loaded best-effort via ``load_lora``.

    Args:
        vae: VAE component to use in the pipeline.
        model_name: Base model identifier passed to ``from_pretrained``.
        controlnet: Optional ControlNet model.
        adapter: Optional T2I-Adapter model (used only without ``controlnet``).
        scheduler: Optional scheduler. When ``None``, the model's own
            scheduler is loaded.
        lora_path: Optional path to LoRA weights (see ``load_lora``).

    Returns:
        The constructed pipeline.

    Raises:
        ValueError: If neither ``controlnet`` nor ``adapter`` is provided.
            Previously this returned ``None``.
    """
    if controlnet is None and adapter is None:
        raise ValueError("get_pipe requires either a controlnet or an adapter")

    shared_kwargs = {"vae": vae, "torch_dtype": torch.float16}
    # diffusers treats an explicitly passed `None` component as "do not load
    # it", so forwarding scheduler=None would leave the pipeline without a
    # scheduler. Only forward a scheduler that was actually supplied.
    if scheduler is not None:
        shared_kwargs["scheduler"] = scheduler

    if controlnet is not None:
        pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
            model_name, controlnet=controlnet, **shared_kwargs)
    else:
        pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
            model_name, adapter=adapter, variant="fp16", **shared_kwargs)

    load_lora(pipe, lora_path)
    return pipe