Spaces:
Runtime error
Runtime error
from typing import Literal, Union, Optional | |
import torch | |
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection | |
from diffusers import ( | |
UNet2DConditionModel, | |
SchedulerMixin, | |
StableDiffusionPipeline, | |
StableDiffusionXLPipeline, | |
) | |
from diffusers.schedulers import ( | |
DDIMScheduler, | |
DDPMScheduler, | |
LMSDiscreteScheduler, | |
EulerAncestralDiscreteScheduler, | |
) | |
# Hub repositories whose "tokenizer" subfolder is used for v1 / v2 models.
TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4"
TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1"

# Scheduler names accepted by create_noise_scheduler.
AVAILABLE_SCHEDULERS = Literal[
    "ddim",
    "ddpm",
    "lms",
    "euler_a",
]

# The two text-encoder classes loaded for SDXL models.
SDXL_TEXT_ENCODER_TYPE = Union[
    CLIPTextModel,
    CLIPTextModelWithProjection,
]

# Forwarded as `cache_dir` to the loaders in this module.
# Change this value if you want to use a different cache directory.
DIFFUSERS_CACHE_DIR = None
def load_diffusers_model(
    pretrained_model_name_or_path: str,
    v2: bool = False,
    clip_skip: Optional[int] = None,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel]:
    """Load tokenizer, text encoder and UNet from a diffusers-format model.

    The VAE is not needed, so it is not loaded.

    Args:
        pretrained_model_name_or_path: Model identifier or path whose
            ``text_encoder`` and ``unet`` subfolders are loaded.
        v2: Selects the v2 tokenizer repository and the 24-layer arithmetic
            for clip skip; otherwise the v1 tokenizer and 12 layers are used.
        clip_skip: When given, the text encoder is loaded with
            ``total - (clip_skip - 1)`` hidden layers. When ``None`` the
            encoder gets 12 layers (v1) or 23 layers (v2, i.e. clip skip 2).
        weight_dtype: Passed as ``torch_dtype`` to every ``from_pretrained``
            call.

    Returns:
        ``(tokenizer, text_encoder, unet)``.
    """
    # Keyword arguments common to all three from_pretrained calls.
    shared = {"torch_dtype": weight_dtype, "cache_dir": DIFFUSERS_CACHE_DIR}

    if v2:
        tokenizer_repo = TOKENIZER_V2_MODEL_NAME
        # default is clip skip 2
        layer_count = 23 if clip_skip is None else 24 - (clip_skip - 1)
    else:
        tokenizer_repo = TOKENIZER_V1_MODEL_NAME
        layer_count = 12 if clip_skip is None else 12 - (clip_skip - 1)

    # The tokenizer always comes from the fixed v1/v2 repository, not from
    # the model being loaded.
    tokenizer = CLIPTokenizer.from_pretrained(
        tokenizer_repo,
        subfolder="tokenizer",
        **shared,
    )
    text_encoder = CLIPTextModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        num_hidden_layers=layer_count,
        **shared,
    )
    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="unet",
        **shared,
    )
    return tokenizer, text_encoder, unet
def load_checkpoint_model(
    checkpoint_path: str,
    v2: bool = False,
    clip_skip: Optional[int] = None,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel]:
    """Load tokenizer, text encoder and UNet from a single-file checkpoint.

    Args:
        checkpoint_path: Path to the checkpoint file handed to
            ``StableDiffusionPipeline.from_single_file``.
        v2: Enables ``upcast_attention`` and makes clip skip count from 24
            layers instead of 12.
        clip_skip: When given, only the first ``total - (clip_skip - 1)``
            encoder layers are kept. ``None`` leaves the encoder untouched.
        weight_dtype: Passed as ``torch_dtype`` to the pipeline loader.

    Returns:
        ``(tokenizer, text_encoder, unet)`` taken from the loaded pipeline.

    Raises:
        ValueError: If ``clip_skip`` is outside ``1..total_layers``.
    """
    total_layers = 24 if v2 else 12
    # Reject impossible values before paying for the checkpoint load.
    if clip_skip is not None and not 1 <= clip_skip <= total_layers:
        raise ValueError(
            f"clip_skip must be between 1 and {total_layers}, got {clip_skip}"
        )

    # from_single_file replaces the deprecated from_ckpt and matches the
    # loader used by load_checkpoint_model_xl.
    pipe = StableDiffusionPipeline.from_single_file(
        checkpoint_path,
        upcast_attention=v2,
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )
    unet = pipe.unet
    tokenizer = pipe.tokenizer
    text_encoder = pipe.text_encoder
    del pipe

    if clip_skip is not None:
        kept_layers = total_layers - (clip_skip - 1)
        # Editing config.num_hidden_layers on an already-built model does not
        # remove any layer, so the layer stack itself is truncated.
        encoder = text_encoder.text_model.encoder
        encoder.layers = encoder.layers[:kept_layers]
        # Keep the config in sync with the layers that actually remain.
        text_encoder.config.num_hidden_layers = len(encoder.layers)

    return tokenizer, text_encoder, unet
def load_models(
    pretrained_model_name_or_path: str,
    scheduler_name: AVAILABLE_SCHEDULERS,
    v2: bool = False,
    v_pred: bool = False,
    weight_dtype: torch.dtype = torch.float32,
    clip_skip: Optional[int] = None,
) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, SchedulerMixin]:
    """Load tokenizer, text encoder, UNet and a noise scheduler.

    The VAE is not needed, so it is not loaded.

    Args:
        pretrained_model_name_or_path: A path ending in ``.ckpt`` or
            ``.safetensors`` is loaded with ``load_checkpoint_model``;
            anything else is loaded with ``load_diffusers_model``.
        scheduler_name: Name forwarded to ``create_noise_scheduler``.
        v2: Forwarded to the model loader.
        v_pred: Selects ``"v_prediction"`` as the scheduler prediction type;
            otherwise ``"epsilon"`` is used.
        weight_dtype: Forwarded to the model loader.
        clip_skip: Forwarded to the model loader. The default ``None`` keeps
            each loader's own default behaviour.

    Returns:
        ``(tokenizer, text_encoder, unet, scheduler)``.
    """
    is_single_file = pretrained_model_name_or_path.endswith(
        (".ckpt", ".safetensors")
    )
    loader = load_checkpoint_model if is_single_file else load_diffusers_model
    tokenizer, text_encoder, unet = loader(
        pretrained_model_name_or_path,
        v2=v2,
        clip_skip=clip_skip,
        weight_dtype=weight_dtype,
    )

    scheduler = create_noise_scheduler(
        scheduler_name,
        prediction_type="v_prediction" if v_pred else "epsilon",
    )
    return tokenizer, text_encoder, unet, scheduler
def load_diffusers_model_xl(
    pretrained_model_name_or_path: str,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel]:
    """Load both tokenizers, both text encoders and the UNet of an SDXL model.

    Args:
        pretrained_model_name_or_path: Diffusers-format model identifier or
            path; the ``tokenizer``, ``tokenizer_2``, ``text_encoder``,
            ``text_encoder_2`` and ``unet`` subfolders are read from it.
        weight_dtype: Passed as ``torch_dtype`` to every ``from_pretrained``
            call.

    Returns:
        ``([tokenizer, tokenizer_2], [text_encoder, text_encoder_2], unet)``.
    """
    source = pretrained_model_name_or_path
    # Keyword arguments common to every from_pretrained call below.
    shared = {"torch_dtype": weight_dtype, "cache_dir": DIFFUSERS_CACHE_DIR}

    first_tokenizer = CLIPTokenizer.from_pretrained(
        source, subfolder="tokenizer", **shared
    )
    second_tokenizer = CLIPTokenizer.from_pretrained(
        source,
        subfolder="tokenizer_2",
        pad_token_id=0,  # same as open clip
        **shared,
    )

    first_encoder = CLIPTextModel.from_pretrained(
        source, subfolder="text_encoder", **shared
    )
    second_encoder = CLIPTextModelWithProjection.from_pretrained(
        source, subfolder="text_encoder_2", **shared
    )

    unet = UNet2DConditionModel.from_pretrained(source, subfolder="unet", **shared)

    tokenizers = [first_tokenizer, second_tokenizer]
    text_encoders = [first_encoder, second_encoder]
    return tokenizers, text_encoders, unet
def load_checkpoint_model_xl(
    checkpoint_path: str,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel]:
    """Load SDXL tokenizers, text encoders and UNet from a single-file checkpoint.

    Args:
        checkpoint_path: Path to the checkpoint file handed to
            ``StableDiffusionXLPipeline.from_single_file``.
        weight_dtype: Passed as ``torch_dtype`` to the pipeline loader.

    Returns:
        ``([tokenizer, tokenizer_2], [text_encoder, text_encoder_2], unet)``.
    """
    pipe = StableDiffusionXLPipeline.from_single_file(
        checkpoint_path,
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    unet = pipe.unet
    tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
    text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
    del pipe

    # The pad token id belongs to the second tokenizer (same as open clip),
    # matching what load_diffusers_model_xl requests for tokenizer_2. It was
    # previously assigned to the second text encoder, where it had no effect
    # on tokenization.
    tokenizers[1].pad_token_id = 0

    return tokenizers, text_encoders, unet
def load_models_xl(
    pretrained_model_name_or_path: str,
    scheduler_name: AVAILABLE_SCHEDULERS,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[
    list[CLIPTokenizer],
    list[SDXL_TEXT_ENCODER_TYPE],
    UNet2DConditionModel,
    SchedulerMixin,
]:
    """Load SDXL tokenizers, text encoders, UNet and a noise scheduler.

    Args:
        pretrained_model_name_or_path: A path ending in ``.ckpt`` or
            ``.safetensors`` is loaded with ``load_checkpoint_model_xl``;
            anything else is loaded with ``load_diffusers_model_xl``.
        scheduler_name: Name forwarded to ``create_noise_scheduler``.
        weight_dtype: Forwarded to the model loader.

    Returns:
        ``(tokenizers, text_encoders, unet, scheduler)``.
    """
    single_file_suffixes = (".ckpt", ".safetensors")
    if pretrained_model_name_or_path.endswith(single_file_suffixes):
        loaded = load_checkpoint_model_xl(pretrained_model_name_or_path, weight_dtype)
    else:
        # diffusers
        loaded = load_diffusers_model_xl(pretrained_model_name_or_path, weight_dtype)
    tokenizers, text_encoders, unet = loaded

    scheduler = create_noise_scheduler(scheduler_name)
    return tokenizers, text_encoders, unet, scheduler
def create_noise_scheduler(
    scheduler_name: AVAILABLE_SCHEDULERS = "ddpm",
    prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
) -> SchedulerMixin:
    """Build a noise scheduler by name.

    Author's note (translated): honestly unsure which scheduler is best; the
    original implementation allowed choosing between DDIM, DDPM and LMS.

    Args:
        scheduler_name: One of ``"ddim"``, ``"ddpm"``, ``"lms"`` or
            ``"euler_a"``. The name is lower-cased and spaces are replaced
            with underscores before matching.
        prediction_type: Passed to the scheduler constructor unchanged.

    Returns:
        The newly constructed scheduler.

    Raises:
        ValueError: If the normalised name matches no known scheduler.
    """
    key = scheduler_name.lower().replace(" ", "_")

    # Settings shared by every scheduler constructed here.
    common = {
        "beta_start": 0.00085,
        "beta_end": 0.012,
        "beta_schedule": "scaled_linear",
        "num_train_timesteps": 1000,
        "prediction_type": prediction_type,
    }

    if key == "ddim":
        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddim
        # NOTE (translated from the original author): is passing
        # prediction_type like this correct?
        return DDIMScheduler(clip_sample=False, **common)
    if key == "ddpm":
        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddpm
        return DDPMScheduler(clip_sample=False, **common)
    if key == "lms":
        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/lms_discrete
        return LMSDiscreteScheduler(**common)
    if key == "euler_a":
        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral
        return EulerAncestralDiscreteScheduler(**common)

    raise ValueError(f"Unknown scheduler name: {key}")