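"""Factory for the Kolors SDXL ControlNet img2img texture-generation pipeline.

Builds a ``StableDiffusionXLControlNetImg2ImgPipeline`` from local Kolors
checkpoints, with an optional geometry-conditioned ControlNet (downloaded
from the Hugging Face Hub when no checkpoint path is given) and optional
IP-Adapter-Plus image conditioning.
"""
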
import os
from typing import Optional
import torch
from diffusers import AutoencoderKL, DiffusionPipeline, EulerDiscreteScheduler
from huggingface_hub import snapshot_download
from kolors.models.controlnet import ControlNetModel
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from kolors.models.unet_2d_condition import UNet2DConditionModel
from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import (
StableDiffusionXLControlNetImg2ImgPipeline,
)
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

__all__ = [
    "build_texture_gen_pipe",
]


def build_texture_gen_pipe(
    base_ckpt_dir: str,
    controlnet_ckpt: Optional[str] = None,
    ip_adapt_scale: float = 0.0,
    device: str = "cuda",
) -> DiffusionPipeline:
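    """Assemble the texture-generation pipeline from Kolors checkpoints.

    Args:
        base_ckpt_dir: Directory containing the ``Kolors`` (and, when the
            IP-Adapter is used, ``Kolors-IP-Adapter-Plus``) checkpoint folders.
        controlnet_ckpt: Path to a ControlNet checkpoint. If ``None``, the
            ``geo_cond_mv`` weights are downloaded from the
            ``xinjjj/RoboAssetGen`` Hub repo.
        ip_adapt_scale: IP-Adapter conditioning strength; the adapter is
            loaded only when this is greater than 0.
        device: Device the assembled pipeline is moved to.

    Returns:
        The assembled ``StableDiffusionXLControlNetImg2ImgPipeline``.
    """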
tokenizer = ChatGLMTokenizer.from_pretrained(
f"{base_ckpt_dir}/Kolors/text_encoder"
)
text_encoder = ChatGLMModel.from_pretrained(
f"{base_ckpt_dir}/Kolors/text_encoder", torch_dtype=torch.float16
).half()
vae = AutoencoderKL.from_pretrained(
f"{base_ckpt_dir}/Kolors/vae", revision=None
).half()
unet = UNet2DConditionModel.from_pretrained(
f"{base_ckpt_dir}/Kolors/unet", revision=None
).half()
scheduler = EulerDiscreteScheduler.from_pretrained(
f"{base_ckpt_dir}/Kolors/scheduler"
)
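    # No checkpoint given: fetch the default geometry-conditioned multi-view
    # ControlNet weights from the Hugging Face Hub.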
if controlnet_ckpt is None:
suffix = "geo_cond_mv"
model_path = snapshot_download(
repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
)
controlnet_ckpt = os.path.join(model_path, suffix)
controlnet = ControlNetModel.from_pretrained(
controlnet_ckpt, use_safetensors=True
).half()
    # IP-Adapter components (populated only when ip_adapt_scale > 0).
    image_encoder = None
    clip_image_processor = None
if ip_adapt_scale > 0:
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
f"{base_ckpt_dir}/Kolors-IP-Adapter-Plus/image_encoder",
# ignore_mismatched_sizes=True,
).to(dtype=torch.float16)
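        # The processor is configured for the 336x336 input resolution the
        # bundled Kolors-IP-Adapter-Plus image encoder expects.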
ip_img_size = 336
clip_image_processor = CLIPImageProcessor(
size=ip_img_size, crop_size=ip_img_size
)
pipe = StableDiffusionXLControlNetImg2ImgPipeline(
vae=vae,
controlnet=controlnet,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
image_encoder=image_encoder,
feature_extractor=clip_image_processor,
force_zeros_for_empty_prompt=False,
)
if ip_adapt_scale > 0:
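        # Kolors keeps its text-projection layer in ``unet.encoder_hid_proj``,
        # which ``load_ip_adapter`` overwrites with the image-projection
        # layers; alias it first so the text path keeps working.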
if hasattr(pipe.unet, "encoder_hid_proj"):
pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj
pipe.load_ip_adapter(
f"{base_ckpt_dir}/Kolors-IP-Adapter-Plus",
subfolder="",
weight_name=["ip_adapter_plus_general.bin"],
)
pipe.set_ip_adapter_scale([ip_adapt_scale])
pipe = pipe.to(device)
# pipe.enable_model_cpu_offload()
return pipe
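

# Minimal usage sketch. Assumptions (not part of this module): the Kolors
# checkpoints have already been downloaded under "weights/", and a CUDA
# device is available. The returned object is a standard diffusers pipeline;
# see StableDiffusionXLControlNetImg2ImgPipeline for its call signature.
if __name__ == "__main__":
    pipe = build_texture_gen_pipe(
        base_ckpt_dir="weights",
        controlnet_ckpt=None,  # falls back to the Hub download above
        ip_adapt_scale=0.4,  # > 0 also loads Kolors-IP-Adapter-Plus
        device="cuda",
    )
    print(type(pipe).__name__)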