Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,919 Bytes
5bdf5c2 9b843da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import torch
from diffusers import (ControlNetModel,
StableDiffusionXLControlNetImg2ImgPipeline,
AutoencoderKL,
T2IAdapter,
StableDiffusionXLAdapterPipeline,
EulerAncestralDiscreteScheduler)
from controlnet_aux.pidi import PidiNetDetector
from PIL import Image
import os
def get_vae(model_name="madebyollin/sdxl-vae-fp16-fix"):
    """Return an ``AutoencoderKL`` loaded from *model_name* with float16 weights.

    *model_name* is forwarded unchanged to ``AutoencoderKL.from_pretrained``
    (a Hub repo id or a local directory).
    """
    dtype = torch.float16
    vae = AutoencoderKL.from_pretrained(model_name, torch_dtype=dtype)
    return vae
def get_controlnet(model_name="diffusers/controlnet-canny-sdxl-1.0"):
    """Return a ``ControlNetModel`` loaded from *model_name* with float16 weights.

    *model_name* is forwarded unchanged to ``ControlNetModel.from_pretrained``
    (a Hub repo id or a local directory).
    """
    dtype = torch.float16
    controlnet = ControlNetModel.from_pretrained(model_name, torch_dtype=dtype)
    return controlnet
def get_adapter(model_name="Adapter/t2iadapter", subfolder="sketch_sdxl_1.0",
                adapter_type="full_adapter_xl"):
    """Load a ``T2IAdapter`` with float16 weights.

    Args:
        model_name: Repo id or path passed to ``T2IAdapter.from_pretrained``.
        subfolder: Subfolder of *model_name* that holds the adapter weights.
        adapter_type: Adapter architecture. Only ``"full_adapter_xl"`` is
            supported; the value is also forwarded to ``from_pretrained``.

    Returns:
        The loaded ``T2IAdapter``.

    Raises:
        ValueError: If *adapter_type* is not supported.
    """
    if adapter_type != "full_adapter_xl":
        # Fail loudly instead of returning None, which would otherwise only
        # surface later as a confusing error when the pipeline is built.
        raise ValueError(f"Unsupported adapter_type: {adapter_type!r}")
    return T2IAdapter.from_pretrained(model_name,
                                      subfolder=subfolder,
                                      torch_dtype=torch.float16,
                                      adapter_type=adapter_type)
def get_scheduler(model_name, scheduler_type="discrete"):
    """Load a scheduler from the ``scheduler`` subfolder of *model_name*.

    Args:
        model_name: Repo id or path of the base model.
        scheduler_type: Which scheduler class to use. Only ``"discrete"``
            (``EulerAncestralDiscreteScheduler``) is supported.

    Returns:
        The loaded scheduler.

    Raises:
        ValueError: If *scheduler_type* is not supported.
    """
    if scheduler_type != "discrete":
        # Returning None here would make a caller-built pipeline end up
        # without a usable scheduler; report the bad argument instead.
        raise ValueError(f"Unsupported scheduler_type: {scheduler_type!r}")
    return EulerAncestralDiscreteScheduler.from_pretrained(model_name, subfolder="scheduler")
def get_detector(model_name="lllyasviel/Annotators", model_type='pidi'):
    """Load an edge/sketch detector used to build conditioning images.

    Args:
        model_name: Repo id or path passed to the detector's ``from_pretrained``.
        model_type: Detector kind. Only ``'pidi'`` (``PidiNetDetector``) is
            supported.

    Returns:
        The loaded detector.

    Raises:
        ValueError: If *model_type* is not supported.
    """
    if model_type != 'pidi':
        raise ValueError(f"Unsupported detector model_type: {model_type!r}")
    return PidiNetDetector.from_pretrained(model_name)
def load_lora(pipe, lora_path=None):
    """Best-effort load of LoRA weights into *pipe*, in place.

    Args:
        pipe: Object exposing ``load_lora_weights(directory, weight_name=...)``
            (e.g. a diffusers pipeline).
        lora_path: Path to a LoRA weights file. ``None`` means "no LoRA" and
            makes this a no-op.

    Any failure is printed and otherwise ignored, so a missing or broken LoRA
    does not prevent the base pipeline from being used. Returns None.
    """
    if lora_path is None:
        return
    try:
        # os.path.split keeps absolute paths absolute (the old manual
        # './' + '/'.join(...) turned '/abs/dir' into the relative './/abs/dir')
        # and handles the platform separator; an empty directory means the CWD.
        lora_dir, lora_name = os.path.split(lora_path)
        pipe.load_lora_weights(lora_dir or ".", weight_name=lora_name)
    except Exception as ex:  # deliberate best-effort: report and continue
        print(f"Failed to load LoRA weights from {lora_path!r}: {ex}")
def get_pipe(vae, model_name, controlnet=None, adapter=None, scheduler=None, lora_path=None):
    """Build an SDXL pipeline conditioned on a ControlNet or a T2I-Adapter.

    If *controlnet* is given, a ``StableDiffusionXLControlNetImg2ImgPipeline``
    is built (this takes precedence over *adapter*). Otherwise, if *adapter*
    is given, a ``StableDiffusionXLAdapterPipeline`` is built from the
    ``"fp16"`` weight variant. Both use float16 weights.

    Args:
        vae: VAE component passed to the pipeline.
        model_name: Base model repo id or path passed to ``from_pretrained``.
        controlnet: Optional ControlNet model.
        adapter: Optional T2I-Adapter, used when *controlnet* is None.
        scheduler: Optional scheduler overriding the model's own; forwarded
            only when not None.
        lora_path: Optional LoRA weights file, loaded via ``load_lora``.

    Returns:
        The constructed pipeline.

    Raises:
        ValueError: If neither *controlnet* nor *adapter* is given.
    """
    if controlnet is None and adapter is None:
        raise ValueError("get_pipe requires either a controlnet or an adapter")

    common = {"vae": vae, "torch_dtype": torch.float16}
    # diffusers treats an explicitly passed ``scheduler=None`` as "no
    # scheduler" rather than "use the model's default", so only forward a
    # scheduler that was actually supplied.
    if scheduler is not None:
        common["scheduler"] = scheduler

    if controlnet is not None:
        pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
            model_name, controlnet=controlnet, **common)
    else:
        pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
            model_name, adapter=adapter, variant="fp16", **common)

    load_lora(pipe, lora_path)
    return pipe