import logging
import os
import random
from typing import List, Optional, Tuple

import fire
import numpy as np
import torch
from diffusers.utils import make_image_grid
from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import (
StableDiffusionXLControlNetImg2ImgPipeline,
)
from PIL import Image, ImageEnhance, ImageFilter
from torchvision import transforms

from asset3d_gen.data.datasets import Asset3dGenDataset
from asset3d_gen.models.texture_model import build_texture_gen_pipe

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def get_init_noise_image(image: Image.Image) -> Image.Image:
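    """Turn an image into a blurred, low-contrast grayscale init image."""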
blurred_image = image.convert("L").filter(
ImageFilter.GaussianBlur(radius=3)
)
enhancer = ImageEnhance.Contrast(blurred_image)
image_decreased_contrast = enhancer.enhance(factor=0.5)
    return image_decreased_contrast


def infer_pipe(
    index_file: str,
    controlnet_ckpt: Optional[str] = None,
    uid: Optional[str] = None,
    prompt: Optional[str] = None,
    controlnet_cond_scale: float = 0.4,
    control_guidance_end: float = 0.9,
    strength: float = 1.0,
    num_inference_steps: int = 50,
    guidance_scale: float = 10,
    ip_adapt_scale: float = 0,
    ip_img_path: Optional[str] = None,
    sub_idxs: Optional[List[List[int]]] = None,
    num_images_per_prompt: int = 3,  # Increase to get more similar samples.
    device: str = "cuda",
    save_dir: str = "infer_vis",
    seed: Optional[int] = None,
    target_hw: Tuple[int, int] = (512, 512),
    pipeline: Optional[StableDiffusionXLControlNetImg2ImgPipeline] = None,
) -> List[str]:
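    """Run ControlNet img2img texture generation for one asset.

    Saves a visualization grid and per-sample RGBA images under `save_dir`
    and returns the list of saved RGBA image paths.
    """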
    # sub_idxs like [[0, 1, 2], [3, 4, 5]] selects a 2x3 grid of views;
    # None falls back to a single random view.
    if sub_idxs is None:
        sub_idxs = [[random.randint(0, 5)]]  # One of the 6 available views.
target_hw = [2 * size for size in target_hw]
transform_list = [
transforms.Resize(
target_hw, interpolation=transforms.InterpolationMode.BILINEAR
),
transforms.CenterCrop(target_hw),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
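    # Color images go through the full stack (normalized to [-1, 1]); control
    # images drop the final Normalize and stay in [0, 1].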
image_transform = transforms.Compose(transform_list)
control_transform = transforms.Compose(transform_list[:-1])
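    # The output grid tiles len(sub_idxs) rows by len(sub_idxs[0]) columns of
    # views, each at target_hw resolution.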
grid_hw = (target_hw[0] * len(sub_idxs), target_hw[1] * len(sub_idxs[0]))
dataset = Asset3dGenDataset(
index_file, target_hw=grid_hw, sub_idxs=sub_idxs
)
if uid is None:
uid = random.choice(list(dataset.meta_info.keys()))
if prompt is None:
prompt = dataset.meta_info[uid]["capture"]
    if isinstance(prompt, (list, tuple)):
        prompt = ", ".join(map(str, prompt))
    # prompt += "high quality, ultra-clear, high resolution, best quality, 4k"
    # prompt += "高品质,清晰,细节"  # "high quality, clear, detailed"
    prompt += ", high quality, high resolution, best quality"
    # prompt += ", with diffuse lighting, showing no reflections."
logger.info(f"Inference with prompt: {prompt}")
    # Kept in Chinese for the Kolors text encoder; translation: "nsfw, facial
    # shadows, low resolution, jpeg artifacts, blurry, bad, black face, neon
    # lights, highlights, specular reflections".
    negative_prompt = (
        "nsfw,脸部阴影,低分辨率,jpeg伪影、模糊、糟糕,黑脸,霓虹灯,高光,镜面反射"
    )
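    # Fetch the stacked control maps (view normals, positions, mask) and the
    # reference color grid for this asset.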
control_image = dataset.fetch_sample_grid_images(
uid,
attrs=["image_view_normal", "image_position", "image_mask"],
sub_idxs=sub_idxs,
transform=control_transform,
)
color_image = dataset.fetch_sample_grid_images(
uid,
attrs=["image_color"],
sub_idxs=sub_idxs,
transform=image_transform,
)
normal_pil, position_pil, mask_pil, color_pil = dataset.visualize_item(
control_image,
color_image,
save_dir=save_dir,
)
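    # Build the texture-generation pipeline lazily if none was passed in.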
if pipeline is None:
pipeline = build_texture_gen_pipe(
base_ckpt_dir="./weights",
controlnet_ckpt=controlnet_ckpt,
ip_adapt_scale=ip_adapt_scale,
device=device,
)
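    # Optionally condition on a reference image via IP-Adapter.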
if ip_adapt_scale > 0 and ip_img_path is not None and len(ip_img_path) > 0:
ip_image = Image.open(ip_img_path).convert("RGB")
ip_image = ip_image.resize(target_hw[::-1])
ip_image = [ip_image]
pipeline.set_ip_adapter_scale([ip_adapt_scale])
else:
ip_image = None
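    # Seed every RNG source so sampling is reproducible.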
generator = None
if seed is not None:
generator = torch.Generator(device).manual_seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
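    # Start img2img from a blurred, low-contrast copy of the normal render
    # instead of pure noise.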
init_image = get_init_noise_image(normal_pil)
# init_image = get_init_noise_image(color_pil)
images = []
row_num, col_num = 2, 3
img_save_paths = []
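    # Keep sampling until there are at least col_num images for the grid.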
while len(images) < col_num:
image = pipeline(
prompt=prompt,
image=init_image,
controlnet_conditioning_scale=controlnet_cond_scale,
control_guidance_end=control_guidance_end,
strength=strength,
control_image=control_image[None, ...],
negative_prompt=negative_prompt,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
num_images_per_prompt=num_images_per_prompt,
ip_adapter_image=ip_image,
generator=generator,
).images
images.extend(image)
grid_image = [normal_pil, position_pil, color_pil] + images[:col_num]
# save_dir = os.path.join(save_dir, uid)
os.makedirs(save_dir, exist_ok=True)
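    # Save each sample as RGBA, using the rendered mask as the alpha channel.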
for idx in range(col_num):
rgba_image = Image.merge("RGBA", (*images[idx].split(), mask_pil))
img_save_path = os.path.join(save_dir, f"color_sample{idx}.png")
rgba_image.save(img_save_path)
img_save_paths.append(img_save_path)
    sub_idxs_tag = "_".join(
        str(item) for sublist in sub_idxs for item in sublist
    )
    save_path = os.path.join(
        save_dir, f"sample_idx{sub_idxs_tag}_ip{ip_adapt_scale}.jpg"
    )
    make_image_grid(grid_image, row_num, col_num).save(save_path)
    logger.info(f"Visualization saved to {save_path}")
    return img_save_paths


def entrypoint() -> None:
fire.Fire(infer_pipe)
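

# Example CLI usage via fire (file name and paths are illustrative):
#   python texture_infer.py --index_file=data/index.json --seed=42 \
#       --sub_idxs="[[0,1,2],[3,4,5]]"
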
if __name__ == "__main__":
entrypoint()