import logging

import torch
from diffusers import (
    AutoencoderKL,
    EulerDiscreteScheduler,
    UNet2DConditionModel,
)
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from kolors.models.unet_2d_condition import (
    UNet2DConditionModel as UNet2DConditionModelIP,
)
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import (
    StableDiffusionXLPipeline,
)
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import (  # noqa
    StableDiffusionXLPipeline as StableDiffusionXLPipelineIP,
)
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

__all__ = [
    "build_text2img_ip_pipeline",
    "build_text2img_pipeline",
    "text2img_gen",
]


def build_text2img_ip_pipeline(
    ckpt_dir: str,
    ref_scale: float,
    device: str = "cuda",
) -> StableDiffusionXLPipelineIP:
    """Build a Kolors SDXL pipeline with IP-Adapter image conditioning.

    Args:
        ckpt_dir: Directory holding the Kolors checkpoint; the
            Kolors-IP-Adapter-Plus weights are expected in a sibling directory.
        ref_scale: Strength of the IP-Adapter (reference image) conditioning.
        device: Device the assembled pipeline is moved to.
    """
    # Kolors swaps SDXL's CLIP text encoders for ChatGLM; all weights load in fp16.
    text_encoder = ChatGLMModel.from_pretrained(
        f"{ckpt_dir}/text_encoder", torch_dtype=torch.float16
    ).half()
    tokenizer = ChatGLMTokenizer.from_pretrained(f"{ckpt_dir}/text_encoder")
    vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half()
    scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
    unet = UNet2DConditionModelIP.from_pretrained(
        f"{ckpt_dir}/unet", revision=None
    ).half()
    # The IP-Adapter's CLIP image encoder lives next to the base checkpoint.
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        f"{ckpt_dir}/../Kolors-IP-Adapter-Plus/image_encoder",
        ignore_mismatched_sizes=True,
    ).to(dtype=torch.float16)
    clip_image_processor = CLIPImageProcessor(size=336, crop_size=336)
    pipe = StableDiffusionXLPipelineIP(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        image_encoder=image_encoder,
        feature_extractor=clip_image_processor,
        force_zeros_for_empty_prompt=False,
    )
    if hasattr(pipe.unet, "encoder_hid_proj"):
        # Preserve the text-side projection: load_ip_adapter() replaces
        # encoder_hid_proj with the image-projection layers.
        pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj
    pipe.load_ip_adapter(
        f"{ckpt_dir}/../Kolors-IP-Adapter-Plus",
        subfolder="",
        weight_name=["ip_adapter_plus_general.bin"],
    )
    pipe.set_ip_adapter_scale([ref_scale])
    pipe = pipe.to(device)
    # pipe.enable_model_cpu_offload()
    # pipe.enable_xformers_memory_efficient_attention()
    # pipe.enable_vae_slicing()
    return pipe
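

# A minimal usage sketch (assumption: Kolors weights in `weights/Kolors` with
# `weights/Kolors-IP-Adapter-Plus` as a sibling directory, matching the
# relative paths above; both path names are hypothetical):
#
#   pipe_ip = build_text2img_ip_pipeline("weights/Kolors", ref_scale=0.3)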


def build_text2img_pipeline(
    ckpt_dir: str,
    device: str = "cuda",
) -> StableDiffusionXLPipeline:
    """Build the plain Kolors text-to-image pipeline (no image conditioning)."""
    text_encoder = ChatGLMModel.from_pretrained(
        f"{ckpt_dir}/text_encoder", torch_dtype=torch.float16
    ).half()
    tokenizer = ChatGLMTokenizer.from_pretrained(f"{ckpt_dir}/text_encoder")
    vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half()
    scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
    unet = UNet2DConditionModel.from_pretrained(
        f"{ckpt_dir}/unet", revision=None
    ).half()
    pipe = StableDiffusionXLPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        force_zeros_for_empty_prompt=False,
    )
    pipe = pipe.to(device)
    # pipe.enable_model_cpu_offload()
    # pipe.enable_xformers_memory_efficient_attention()
    return pipe
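

# This builder mirrors the IP variant minus the image encoder, feature
# extractor, and IP-Adapter weights; a sketch with the same hypothetical
# checkpoint path:
#
#   pipe = build_text2img_pipeline("weights/Kolors")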


def text2img_gen(
    prompt: str,
    n_sample: int,
    guidance_scale: float,
    pipeline: StableDiffusionXLPipeline | StableDiffusionXLPipelineIP,
    ip_image: Image.Image | str | None = None,
    image_wh: tuple[int, int] = (1024, 1024),
    infer_step: int = 50,
    ip_image_size: int = 512,
) -> list[Image.Image]:
    """Generate `n_sample` images for a prompt, optionally conditioned on a
    reference image. `ip_image` is only honored when `pipeline` was built
    with `build_text2img_ip_pipeline`.
    """
    # Steer generations toward a single, centered object on a white background.
    prompt = "Single " + prompt + ", in the center of the image"
    prompt += ", high quality, high resolution, best quality, white background, 3D style,"  # noqa
    logger.info(f"Processing prompt: {prompt}")

    kwargs = dict(
        prompt=prompt,
        height=image_wh[1],
        width=image_wh[0],
        num_inference_steps=infer_step,
        guidance_scale=guidance_scale,
        num_images_per_prompt=n_sample,
    )
    if ip_image is not None:
        if isinstance(ip_image, str):
            ip_image = Image.open(ip_image)
        ip_image = ip_image.resize((ip_image_size, ip_image_size))
        kwargs.update(ip_adapter_image=[ip_image])
    return pipeline(**kwargs).images
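

# End-to-end sketch: assumes the hypothetical `weights/Kolors` checkpoint
# layout used above and an available CUDA device; adjust paths to your setup.
if __name__ == "__main__":
    pipe = build_text2img_pipeline("weights/Kolors")
    images = text2img_gen(
        prompt="wooden rocking chair",
        n_sample=2,
        guidance_scale=7.5,
        pipeline=pipe,
    )
    for i, img in enumerate(images):
        img.save(f"sample_{i}.png")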