import gc
import os
from abc import ABC, abstractmethod

import numpy as np
import PIL.Image
import torch
from controlnet_aux import (
    CannyDetector,
    LineartDetector,
    MidasDetector,
    OpenposeDetector,
    PidiNetDetector,
    ZoeDetector,
)
from diffusers import (
    AutoencoderKL,
    EulerAncestralDiscreteScheduler,
    StableDiffusionXLAdapterPipeline,
    T2IAdapter,
)

# SDXL base resolutions (width, height), keyed by the approximate width/height
# aspect ratio. All sides are multiples of 64.
SD_XL_BASE_RATIOS = {
    "0.5": (704, 1408),
    "0.52": (704, 1344),
    "0.57": (768, 1344),
    "0.6": (768, 1280),
    "0.68": (832, 1216),
    "0.72": (832, 1152),
    "0.78": (896, 1152),
    "0.82": (896, 1088),
    "0.88": (960, 1088),
    "0.94": (960, 1024),
    "1.0": (1024, 1024),
    "1.07": (1024, 960),
    "1.13": (1088, 960),
    "1.21": (1088, 896),
    "1.29": (1152, 896),
    "1.38": (1152, 832),
    "1.46": (1216, 832),
    "1.67": (1280, 768),
    "1.75": (1344, 768),
    "1.91": (1344, 704),
    "2.0": (1408, 704),
    "2.09": (1472, 704),
    "2.4": (1536, 640),
    "2.5": (1600, 640),
    "2.89": (1664, 576),
    "3.0": (1728, 576),
}


def find_closest_aspect_ratio(target_width: int, target_height: int) -> str:
    """Return the SD_XL_BASE_RATIOS key whose aspect ratio is closest to the input's."""
    target_ratio = target_width / target_height
    closest_ratio = ""
    min_difference = float("inf")

    for ratio_str, (width, height) in SD_XL_BASE_RATIOS.items():
        ratio = width / height
        difference = abs(target_ratio - ratio)

        if difference < min_difference:
            min_difference = difference
            closest_ratio = ratio_str

    return closest_ratio
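
# For example, a 1920x1080 input has aspect ratio ~1.78; the nearest bucket
# ratio is 1344/768 = 1.75, so find_closest_aspect_ratio(1920, 1080) -> "1.75".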


def resize_to_closest_aspect_ratio(image: PIL.Image.Image) -> PIL.Image.Image:
    target_width, target_height = image.size
    closest_ratio = find_closest_aspect_ratio(target_width, target_height)
    new_width, new_height = SD_XL_BASE_RATIOS[closest_ratio]
    resized_image = image.resize((new_width, new_height), PIL.Image.LANCZOS)
    return resized_image
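
# Note: the image is resized directly to the bucket dimensions, so it is
# slightly stretched whenever its ratio is not an exact match
# (e.g. 1920x1080 at ~1.78 becomes 1344x768 at 1.75).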


ADAPTER_REPO_IDS = {
    "canny": "TencentARC/t2i-adapter-canny-sdxl-1.0",
    "sketch": "TencentARC/t2i-adapter-sketch-sdxl-1.0",
    "lineart": "TencentARC/t2i-adapter-lineart-sdxl-1.0",
    "depth-midas": "TencentARC/t2i-adapter-depth-midas-sdxl-1.0",
    "depth-zoe": "TencentARC/t2i-adapter-depth-zoe-sdxl-1.0",
    "openpose": "TencentARC/t2i-adapter-openpose-sdxl-1.0",
}
# NOTE: a RecolorPreprocessor is defined below and handled by get_preprocessor,
# but "recolor" has no entry in ADAPTER_REPO_IDS, so it is not a valid adapter
# name here.
ADAPTER_NAMES = list(ADAPTER_REPO_IDS.keys())


class Preprocessor(ABC):
    @abstractmethod
    def to(self, device: torch.device | str) -> "Preprocessor":
        pass

    @abstractmethod
    def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
        pass


class CannyPreprocessor(Preprocessor):
    def __init__(self):
        self.model = CannyDetector()

    def to(self, device: torch.device | str) -> Preprocessor:
        # Canny edge detection is pure OpenCV; there is nothing to move to the GPU.
        return self

    def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
        return self.model(image, detect_resolution=384, image_resolution=1024)


class LineartPreprocessor(Preprocessor):
    def __init__(self):
        self.model = LineartDetector.from_pretrained("lllyasviel/Annotators")

    def to(self, device: torch.device | str) -> Preprocessor:
        self.model.to(device)
        return self

    def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
        return self.model(image, detect_resolution=384, image_resolution=1024)


class MidasPreprocessor(Preprocessor):
    def __init__(self):
        self.model = MidasDetector.from_pretrained(
            "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
        )

    def to(self, device: torch.device | str) -> Preprocessor:
        self.model.to(device)
        return self

    def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
        return self.model(image, detect_resolution=512, image_resolution=1024)


class OpenposePreprocessor(Preprocessor):
    def __init__(self):
        self.model = OpenposeDetector.from_pretrained("lllyasviel/Annotators")

    def to(self, device: torch.device | str) -> Preprocessor:
        self.model.to(device)
        return self

    def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
        out = self.model(image, detect_resolution=512, image_resolution=1024)
        # Reverse the channel order (RGB -> BGR) before handing the pose map
        # to the pipeline.
        out = np.array(out)[:, :, ::-1]
        out = PIL.Image.fromarray(np.uint8(out))
        return out


class PidiNetPreprocessor(Preprocessor):
    def __init__(self):
        self.model = PidiNetDetector.from_pretrained("lllyasviel/Annotators")

    def to(self, device: torch.device | str) -> Preprocessor:
        self.model.to(device)
        return self

    def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
        return self.model(image, detect_resolution=512, image_resolution=1024, apply_filter=True)


class RecolorPreprocessor(Preprocessor):
    def to(self, device: torch.device | str) -> Preprocessor:
        # Grayscale conversion runs on the CPU; nothing to move.
        return self

    def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
        # Round-trip through "L" mode to obtain a grayscale image in RGB mode.
        return image.convert("L").convert("RGB")


class ZoePreprocessor(Preprocessor):
    def __init__(self):
        self.model = ZoeDetector.from_pretrained(
            "valhalla/t2iadapter-aux-models", filename="zoed_nk.pth", model_type="zoedepth_nk"
        )

    def to(self, device: torch.device | str) -> Preprocessor:
        self.model.to(device)
        return self

    def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
        return self.model(image, gamma_corrected=True, image_resolution=1024)


PRELOAD_PREPROCESSORS_IN_GPU_MEMORY = os.getenv("PRELOAD_PREPROCESSORS_IN_GPU_MEMORY", "0") == "1"
PRELOAD_PREPROCESSORS_IN_CPU_MEMORY = os.getenv("PRELOAD_PREPROCESSORS_IN_CPU_MEMORY", "0") == "1"
if PRELOAD_PREPROCESSORS_IN_GPU_MEMORY:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    preprocessors_gpu: dict[str, Preprocessor] = {
        "canny": CannyPreprocessor().to(device),
        "sketch": PidiNetPreprocessor().to(device),
        "lineart": LineartPreprocessor().to(device),
        "depth-midas": MidasPreprocessor().to(device),
        "depth-zoe": ZoePreprocessor().to(device),
        "openpose": OpenposePreprocessor().to(device),
        "recolor": RecolorPreprocessor().to(device),
    }

    def get_preprocessor(adapter_name: str) -> Preprocessor:
        return preprocessors_gpu[adapter_name]

elif PRELOAD_PREPROCESSORS_IN_CPU_MEMORY:
    preprocessors_cpu: dict[str, Preprocessor] = {
        "canny": CannyPreprocessor(),
        "sketch": PidiNetPreprocessor(),
        "lineart": LineartPreprocessor(),
        "depth-midas": MidasPreprocessor(),
        "depth-zoe": ZoePreprocessor(),
        "openpose": OpenposePreprocessor(),
        "recolor": RecolorPreprocessor(),
    }

    def get_preprocessor(adapter_name: str) -> Preprocessor:
        return preprocessors_cpu[adapter_name]

else:

    def get_preprocessor(adapter_name: str) -> Preprocessor:
        if adapter_name == "canny":
            return CannyPreprocessor()
        elif adapter_name == "sketch":
            return PidiNetPreprocessor()
        elif adapter_name == "lineart":
            return LineartPreprocessor()
        elif adapter_name == "depth-midas":
            return MidasPreprocessor()
        elif adapter_name == "depth-zoe":
            return ZoePreprocessor()
        elif adapter_name == "openpose":
            return OpenposePreprocessor()
        elif adapter_name == "recolor":
            return RecolorPreprocessor()
        else:
            raise ValueError(f"Adapter name must be one of {ADAPTER_NAMES}")


def download_all_preprocessors():
    # Instantiating each preprocessor downloads its weights into the local cache.
    for adapter_name in ADAPTER_NAMES:
        get_preprocessor(adapter_name)
    gc.collect()


download_all_preprocessors()


def download_all_adapters():
    # Warm the local cache with every adapter's fp16 weights.
    for adapter_name in ADAPTER_NAMES:
        T2IAdapter.from_pretrained(
            ADAPTER_REPO_IDS[adapter_name],
            torch_dtype=torch.float16,
            variant="fp16",
        )
        gc.collect()


class Model:
    MAX_NUM_INFERENCE_STEPS = 50

    def __init__(self, adapter_name: str):
        if adapter_name not in ADAPTER_NAMES:
            raise ValueError(f"Adapter name must be one of {ADAPTER_NAMES}")

        self.preprocessor_name = adapter_name
        self.adapter_name = adapter_name

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if torch.cuda.is_available():
            self.preprocessor = get_preprocessor(adapter_name).to(self.device)

            model_id = "stabilityai/stable-diffusion-xl-base-1.0"
            adapter = T2IAdapter.from_pretrained(
                ADAPTER_REPO_IDS[adapter_name],
                torch_dtype=torch.float16,
                variant="fp16",
            ).to(self.device)
            self.pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
                model_id,
                vae=AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16),
                adapter=adapter,
                scheduler=EulerAncestralDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler"),
                torch_dtype=torch.float16,
                variant="fp16",
            ).to(self.device)
            self.pipe.enable_xformers_memory_efficient_attention()
            # Fuse the offset-noise example LoRA into the base weights at reduced strength.
            self.pipe.load_lora_weights(
                "stabilityai/stable-diffusion-xl-base-1.0", weight_name="sd_xl_offset_example-lora_1.0.safetensors"
            )
            self.pipe.fuse_lora(lora_scale=0.4)
        else:
            self.preprocessor = None
            self.pipe = None

    def change_preprocessor(self, adapter_name: str) -> None:
        if adapter_name not in ADAPTER_NAMES:
            raise ValueError(f"Adapter name must be one of {ADAPTER_NAMES}")
        if adapter_name == self.preprocessor_name:
            return

        if PRELOAD_PREPROCESSORS_IN_GPU_MEMORY:
            pass
        elif PRELOAD_PREPROCESSORS_IN_CPU_MEMORY:
            # Move the old preprocessor back to CPU before swapping it out.
            self.preprocessor.to("cpu")
        else:
            del self.preprocessor
        self.preprocessor = get_preprocessor(adapter_name).to(self.device)
        self.preprocessor_name = adapter_name
        gc.collect()
        torch.cuda.empty_cache()

    def change_adapter(self, adapter_name: str) -> None:
        if adapter_name not in ADAPTER_NAMES:
            raise ValueError(f"Adapter name must be one of {ADAPTER_NAMES}")
        if adapter_name == self.adapter_name:
            return
        self.pipe.adapter = T2IAdapter.from_pretrained(
            ADAPTER_REPO_IDS[adapter_name],
            torch_dtype=torch.float16,
            variant="fp16",
        ).to(self.device)
        self.adapter_name = adapter_name
        gc.collect()
        torch.cuda.empty_cache()

    def resize_image(self, image: PIL.Image.Image) -> PIL.Image.Image:
        # Scale the image so its longer side is 1024 pixels, preserving aspect ratio.
        w, h = image.size
        scale = 1024 / max(w, h)
        new_w = int(w * scale)
        new_h = int(h * scale)
        return image.resize((new_w, new_h), PIL.Image.LANCZOS)
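
    # For example, a 2048x1024 input is scaled by 1024/2048 = 0.5 to 1024x512.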

    def run(
        self,
        image: PIL.Image.Image,
        prompt: str,
        negative_prompt: str,
        adapter_name: str,
        num_inference_steps: int = 30,
        guidance_scale: float = 5.0,
        adapter_conditioning_scale: float = 1.0,
        adapter_conditioning_factor: float = 1.0,
        seed: int = 0,
        apply_preprocess: bool = True,
    ) -> list[PIL.Image.Image]:
        if not torch.cuda.is_available():
            raise RuntimeError("This demo does not work on CPU.")
        if num_inference_steps > self.MAX_NUM_INFERENCE_STEPS:
            raise ValueError(f"Number of steps must be at most {self.MAX_NUM_INFERENCE_STEPS}")

        # Downscale so the longer side is 1024 pixels before preprocessing.
        image = self.resize_image(image)

        self.change_preprocessor(adapter_name)
        self.change_adapter(adapter_name)

        if apply_preprocess:
            image = self.preprocessor(image)

        # Snap the conditioning image to the nearest SDXL resolution bucket.
        image = resize_to_closest_aspect_ratio(image)

        generator = torch.Generator(device=self.device).manual_seed(seed)
        out = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image,
            num_inference_steps=num_inference_steps,
            adapter_conditioning_scale=adapter_conditioning_scale,
            adapter_conditioning_factor=adapter_conditioning_factor,
            generator=generator,
            guidance_scale=guidance_scale,
        ).images[0]
        # Return the (possibly preprocessed) conditioning image alongside the result.
        return [image, out]
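

# A minimal usage sketch (not part of the demo app). It assumes a CUDA GPU and
# an input image at "input.png"; the path and the prompts are placeholders.
if __name__ == "__main__":
    model = Model("canny")
    result = model.run(
        image=PIL.Image.open("input.png"),
        prompt="a photo of a house in the mountains, best quality",
        negative_prompt="low quality, blurry",
        adapter_name="canny",
        seed=42,
    )
    result[0].save("condition.png")  # the Canny edge map fed to the adapter
    result[1].save("output.png")  # the generated image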