File size: 2,998 Bytes
55ed985
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4811e40
55ed985
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
from typing import Optional, Tuple

import torch
from diffusers import AutoencoderKL, DiffusionPipeline, EulerDiscreteScheduler
from huggingface_hub import snapshot_download
from kolors.models.controlnet import ControlNetModel
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from kolors.models.unet_2d_condition import UNet2DConditionModel
from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import (
    StableDiffusionXLControlNetImg2ImgPipeline,
)
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

# Names exported by ``from <module> import *``.
__all__ = ["build_texture_gen_pipe"]


def _resolve_controlnet_ckpt(controlnet_ckpt: Optional[str]) -> str:
    """Return ``controlnet_ckpt``, or download and return the default one if None."""
    if controlnet_ckpt is not None:
        return controlnet_ckpt
    suffix = "geo_cond_mv"
    # Only the ``geo_cond_mv/`` folder of the repo is fetched.
    model_path = snapshot_download(
        repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
    )
    return os.path.join(model_path, suffix)


def _load_ip_adapter_image_encoder(
    base_ckpt_dir: str,
) -> Tuple[CLIPVisionModelWithProjection, CLIPImageProcessor]:
    """Load the fp16 IP-Adapter-Plus CLIP image encoder and its 336px processor."""
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        os.path.join(base_ckpt_dir, "Kolors-IP-Adapter-Plus", "image_encoder"),
    ).to(dtype=torch.float16)
    ip_img_size = 336
    clip_image_processor = CLIPImageProcessor(
        size=ip_img_size, crop_size=ip_img_size
    )
    return image_encoder, clip_image_processor


def _attach_ip_adapter(
    pipe: StableDiffusionXLControlNetImg2ImgPipeline,
    base_ckpt_dir: str,
    ip_adapt_scale: float,
) -> None:
    """Load the Kolors IP-Adapter-Plus weights into ``pipe`` and set its scale."""
    # diffusers' load_ip_adapter installs its image projection as
    # ``unet.encoder_hid_proj``; keep the original projection reachable as
    # ``text_encoder_hid_proj`` (presumably read by the Kolors UNet for text
    # embeddings -- confirm in kolors.models.unet_2d_condition).
    if hasattr(pipe.unet, "encoder_hid_proj"):
        pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj
    pipe.load_ip_adapter(
        os.path.join(base_ckpt_dir, "Kolors-IP-Adapter-Plus"),
        subfolder="",
        weight_name=["ip_adapter_plus_general.bin"],
    )
    pipe.set_ip_adapter_scale([ip_adapt_scale])


def build_texture_gen_pipe(
    base_ckpt_dir: str,
    controlnet_ckpt: Optional[str] = None,
    ip_adapt_scale: float = 0,
    device: str = "cuda",
) -> DiffusionPipeline:
    """Build a Kolors ControlNet img2img pipeline for texture generation.

    Args:
        base_ckpt_dir: Directory that contains ``Kolors/`` (with
            ``text_encoder``, ``vae``, ``unet`` and ``scheduler`` subfolders)
            and, when the IP-Adapter is enabled, ``Kolors-IP-Adapter-Plus/``.
        controlnet_ckpt: ControlNet checkpoint path (loaded with
            ``use_safetensors=True``). If None, the ``geo_cond_mv`` folder of
            the ``xinjjj/RoboAssetGen`` Hugging Face repo is downloaded and used.
        ip_adapt_scale: IP-Adapter scale. Values > 0 load the IP-Adapter-Plus
            image encoder and weights; values <= 0 disable the IP-Adapter.
        device: Device the assembled pipeline is moved to.

    Returns:
        A ``StableDiffusionXLControlNetImg2ImgPipeline`` whose text encoder,
        VAE, UNet, ControlNet (and optional image encoder) are loaded in
        float16 and moved to ``device``.
    """
    kolors_dir = os.path.join(base_ckpt_dir, "Kolors")
    text_encoder_dir = os.path.join(kolors_dir, "text_encoder")

    # Base Kolors components; the tokenizer lives alongside the text encoder.
    tokenizer = ChatGLMTokenizer.from_pretrained(text_encoder_dir)
    text_encoder = ChatGLMModel.from_pretrained(
        text_encoder_dir, torch_dtype=torch.float16
    ).half()
    vae = AutoencoderKL.from_pretrained(os.path.join(kolors_dir, "vae")).half()
    unet = UNet2DConditionModel.from_pretrained(
        os.path.join(kolors_dir, "unet")
    ).half()
    scheduler = EulerDiscreteScheduler.from_pretrained(
        os.path.join(kolors_dir, "scheduler")
    )

    controlnet = ControlNetModel.from_pretrained(
        _resolve_controlnet_ckpt(controlnet_ckpt), use_safetensors=True
    ).half()

    use_ip_adapter = ip_adapt_scale > 0
    image_encoder = None
    clip_image_processor = None
    if use_ip_adapter:
        image_encoder, clip_image_processor = _load_ip_adapter_image_encoder(
            base_ckpt_dir
        )

    pipe = StableDiffusionXLControlNetImg2ImgPipeline(
        vae=vae,
        controlnet=controlnet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        image_encoder=image_encoder,
        feature_extractor=clip_image_processor,
        force_zeros_for_empty_prompt=False,
    )

    if use_ip_adapter:
        _attach_ip_adapter(pipe, base_ckpt_dir, ip_adapt_scale)

    return pipe.to(device)