import torch
from diffusers import (ControlNetModel,
                       StableDiffusionXLControlNetImg2ImgPipeline,
                       AutoencoderKL,
                       T2IAdapter,
                       StableDiffusionXLAdapterPipeline,
                       EulerAncestralDiscreteScheduler)
from controlnet_aux.pidi import PidiNetDetector
from PIL import Image
import os
def get_vae(model_name="madebyollin/sdxl-vae-fp16-fix"):
    # fp16-safe SDXL VAE
    return AutoencoderKL.from_pretrained(model_name, torch_dtype=torch.float16)
def get_controlnet(model_name="diffusers/controlnet-canny-sdxl-1.0"):
    # Canny-edge ControlNet for SDXL
    return ControlNetModel.from_pretrained(model_name, torch_dtype=torch.float16)
def get_adapter(model_name="Adapter/t2iadapter", subfolder="sketch_sdxl_1.0",
                adapter_type="full_adapter_xl"):
    # Sketch T2I-Adapter for SDXL
    if adapter_type == "full_adapter_xl":
        return T2IAdapter.from_pretrained(model_name,
                                          subfolder=subfolder,
                                          torch_dtype=torch.float16,
                                          adapter_type=adapter_type)
def get_scheduler(model_name, scheduler_type="discrete"):
    # Euler Ancestral scheduler loaded from the base model's scheduler config
    if scheduler_type == "discrete":
        return EulerAncestralDiscreteScheduler.from_pretrained(model_name, subfolder="scheduler")
def get_detector(model_name="lllyasviel/Annotators", model_type='pidi'):
    # PidiNet edge/sketch detector used to build conditioning images
    if model_type == 'pidi':
        return PidiNetDetector.from_pretrained(model_name)
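
# Illustrative note (not in the original file): the detector is typically called
# directly on a PIL image to produce a sketch-style conditioning image, e.g.
#   detector = get_detector()
#   sketch = detector(Image.open("photo.png"))
# Additional keyword arguments (detect_resolution, image_resolution, ...) depend
# on the installed controlnet_aux version.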
def load_lora(pipe, lora_path=None):
    # Attach LoRA weights to the pipeline in place; the pipe is modified, so nothing is returned.
    if lora_path is not None:
        try:
            lora_dir = os.path.dirname(lora_path) or "."
            lora_name = os.path.basename(lora_path)
            pipe.load_lora_weights(lora_dir, weight_name=lora_name)
        except Exception as ex:
            print(ex)
def get_pipe(vae, model_name, controlnet=None, adapter=None, scheduler=None, lora_path=None):
    # Build either a ControlNet img2img pipeline or a T2I-Adapter pipeline for SDXL,
    # optionally attaching LoRA weights.
    if controlnet is not None:
        pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(model_name,
                                                                          controlnet=controlnet,
                                                                          vae=vae,
                                                                          torch_dtype=torch.float16)
        load_lora(pipe, lora_path)
        return pipe
    elif adapter is not None:
        pipe = StableDiffusionXLAdapterPipeline.from_pretrained(model_name,
                                                                adapter=adapter,
                                                                vae=vae,
                                                                scheduler=scheduler,
                                                                torch_dtype=torch.float16,
                                                                variant="fp16")
        load_lora(pipe, lora_path)
        return pipe
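
# Minimal usage sketch (not part of the original Space): the model name, file
# paths, prompt, and parameter values below are assumptions for illustration only.
if __name__ == "__main__":
    base_model = "stabilityai/stable-diffusion-xl-base-1.0"  # assumed SDXL base checkpoint

    vae = get_vae()
    adapter = get_adapter()
    scheduler = get_scheduler(base_model)
    detector = get_detector()

    pipe = get_pipe(vae, base_model, adapter=adapter, scheduler=scheduler)
    pipe.to("cuda")

    # Turn an input photo into a sketch-style conditioning image.
    sketch = detector(Image.open("input.png").convert("RGB"))

    image = pipe(
        prompt="a detailed illustration",  # hypothetical prompt
        image=sketch,
        num_inference_steps=30,
        adapter_conditioning_scale=0.9,
        guidance_scale=7.5,
    ).images[0]
    image.save("output.png")

    # The ControlNet branch is used analogously: build the pipe with
    # controlnet=get_controlnet() and pass both `image` (the init image) and
    # `control_image` (e.g. a Canny edge map) when calling it.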