import os
import uuid
from typing import Union, Any, Dict

import numpy as np
import torch
import torchvision
import yaml
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from diffusers import FluxPipeline, FluxControlNetImg2ImgPipeline, FluxImg2ImgPipeline, DiffusionPipeline, EulerAncestralDiscreteScheduler
from diffusers.models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel

from pipeline.utils import logger, TMP_DIR, OUT_DIR
from pipeline.utils import lrm_reconstruct, isomer_reconstruct
from models.lrm.utils.train_util import instantiate_from_config
from models.lrm.utils.render_utils import rotate_x, rotate_y
from utils.tool import get_background

def init_wrapper_from_config(config_path):
    with open(config_path, 'r') as config_file:
        config_ = yaml.load(config_file, yaml.FullLoader)

    # load the Flux image-to-image pipeline, optionally ControlNet-guided
    logger.info('==> Loading Flux model ...')
    flux_device = config_['flux'].get('device', 'cpu')
    flux_base_model_pth = config_['flux'].get('base_model', None)
    flux_controlnet_pth = config_['flux'].get('controlnet', None)
    flux_lora_pth = config_['flux'].get('lora', None)

    if flux_controlnet_pth is not None:
        flux_controlnet = FluxControlNetModel.from_pretrained(flux_controlnet_pth)
        flux_pipe = FluxControlNetImg2ImgPipeline.from_pretrained(flux_base_model_pth,
                                                                  controlnet=[flux_controlnet],
                                                                  torch_dtype=torch.bfloat16)
    else:
        flux_pipe = FluxImg2ImgPipeline.from_pretrained(flux_base_model_pth, torch_dtype=torch.bfloat16)

    if flux_lora_pth is not None:
        flux_pipe.load_lora_weights(flux_lora_pth)
    flux_pipe.to(device=flux_device, dtype=torch.bfloat16)

    logger.info('==> Loading multiview diffusion model ...')
    multiview_device = config_['multiview'].get('device', 'cpu')
    multiview_pipeline = DiffusionPipeline.from_pretrained(
        config_['multiview']['base_model'],
        custom_pipeline=config_['multiview']['custom_pipeline'],
        torch_dtype=torch.float16,
    )
    multiview_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
        multiview_pipeline.scheduler.config, timestep_spacing='trailing'
    )

    unet_ckpt_path = config_['multiview'].get('unet', None)
    if unet_ckpt_path is not None:
        state_dict = torch.load(unet_ckpt_path, map_location='cpu')['state_dict']
        # keep only the UNet weights, stripping the 10-character 'unet.unet.' prefix
        state_dict = {k[10:]: v for k, v in state_dict.items() if k.startswith('unet.unet.')}
        multiview_pipeline.unet.load_state_dict(state_dict, strict=True)

    multiview_pipeline.to(multiview_device)

    logger.info('==> Loading caption model ...')
    caption_device = config_['caption'].get('device', 'cpu')
    caption_model = AutoModelForCausalLM.from_pretrained(config_['caption']['base_model'],
                                                         torch_dtype=torch.bfloat16,
                                                         trust_remote_code=True).to(caption_device)
    caption_processor = AutoProcessor.from_pretrained(config_['caption']['base_model'], trust_remote_code=True)

    logger.info('==> Loading reconstruction model ...')
    recon_device = config_['reconstruction'].get('device', 'cpu')
    recon_model_config = OmegaConf.load(config_['reconstruction']['model_config'])
    recon_model = instantiate_from_config(recon_model_config.model_config)

    state_dict = torch.load(config_['reconstruction']['base_model'], map_location='cpu')['state_dict']
    # keep only the LRM generator weights, stripping the 14-character 'lrm_generator.' prefix
    state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
    recon_model.load_state_dict(state_dict, strict=True)
    recon_model.to(recon_device)
    recon_model.init_flexicubes_geometry(recon_device, fovy=50.0)
    recon_model.eval()

    return kiss3d_wrapper(
        config=config_,
        flux_pipeline=flux_pipe,
        multiview_pipeline=multiview_pipeline,
        caption_processor=caption_processor,
        caption_model=caption_model,
        reconstruction_model_config=recon_model_config,
        reconstruction_model=recon_model,
    )

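# A minimal sketch of the YAML layout `init_wrapper_from_config` expects, inferred
# from the keys read above; the model paths are illustrative placeholders, not
# shipped defaults:
#
#   use_zero_gpu: false
#   flux:
#     device: cuda:0
#     base_model: <flux-base-model-path-or-repo-id>
#     controlnet: <flux-controlnet-path>         # optional
#     lora: <flux-lora-path>                     # optional
#     seed: 0
#   multiview:
#     device: cuda:0
#     base_model: <multiview-base-model-path>
#     custom_pipeline: <custom-pipeline-path>
#     unet: <unet-checkpoint-path>               # optional
#     num_inference_steps: 30
#   caption:
#     device: cuda:0
#     base_model: <caption-model-path>
#   reconstruction:
#     device: cuda:0
#     model_config: <omegaconf-model-config-yaml>
#     base_model: <lrm-checkpoint-path>

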
class kiss3d_wrapper(object):
    def __init__(self,
                 config: Dict,
                 flux_pipeline: Union[FluxPipeline, FluxControlNetImg2ImgPipeline],
                 multiview_pipeline: DiffusionPipeline,
                 caption_processor: AutoProcessor,
                 caption_model: AutoModelForCausalLM,
                 reconstruction_model_config: Any,
                 reconstruction_model: Any,
                 ):
        self.config = config
        self.flux_pipeline = flux_pipeline
        self.multiview_pipeline = multiview_pipeline
        self.caption_model = caption_model
        self.caption_processor = caption_processor
        self.recon_model_config = reconstruction_model_config
        self.recon_model = reconstruction_model

        self.renew_uuid()

    def renew_uuid(self):
        # fresh id used to name this run's intermediate and output files
        self.uuid = uuid.uuid4()

    def context(self):
        # on ZeroGPU (Hugging Face Spaces) use the `spaces` package; otherwise
        # run inference with gradient tracking disabled
        if self.config['use_zero_gpu']:
            import spaces
            return spaces.GPU()
        else:
            return torch.no_grad()

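    # The inference methods below wrap their forward passes in `with self.context():`,
    # so they run without gradient tracking (or inside the ZeroGPU allocation when
    # `use_zero_gpu` is set in the config).
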
    def get_image_caption(self, image):
        """
        image: PIL.Image, or a path to an image file
        """
        torch_dtype = torch.bfloat16
        caption_device = self.config['caption'].get('device', 'cpu')

        if isinstance(image, str):
            image = Image.open(image).convert("RGB")
        elif isinstance(image, Image.Image):  # `Image` is the PIL module; the class is `Image.Image`
            image = image.convert("RGB")
        else:
            raise NotImplementedError('unexpected image type')

        prompt = "<MORE_DETAILED_CAPTION>"
        inputs = self.caption_processor(text=prompt, images=image, return_tensors="pt").to(caption_device, torch_dtype)

        generated_ids = self.caption_model.generate(
            input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3
        )

        generated_text = self.caption_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer = self.caption_processor.post_process_generation(
            generated_text, task=prompt, image_size=(image.width, image.height)
        )
        caption_text = parsed_answer["<MORE_DETAILED_CAPTION>"].replace("The image is ", "")
        return caption_text

    def generate_multiview(self, image):
        with self.context():
            mv_image = self.multiview_pipeline(image,
                                               num_inference_steps=self.config['multiview']['num_inference_steps'],
                                               width=512*2, height=512*2).images[0]
        return mv_image

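    # Note: the multiview pipeline returns a single 1024x1024 image laid out as a
    # 2x2 grid of 512x512 views; reconstruct_from_multiview below splits that grid
    # back into a batch of four per-view tensors for the LRM.
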
    def reconstruct_from_multiview(self, mv_image):
        """
        mv_image: PIL.Image
        """
        recon_device = self.config['reconstruction'].get('device', 'cpu')

        rgb_multi_view = np.asarray(mv_image, dtype=np.float32) / 255.0
        rgb_multi_view = torch.from_numpy(rgb_multi_view).squeeze(0).permute(2, 0, 1).contiguous().float()
        rgb_multi_view = rearrange(rgb_multi_view, 'c (n h) (m w) -> (n m) c h w', n=2, m=2).unsqueeze(0).to(recon_device)

        with self.context():
            vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo = \
                lrm_reconstruct(self.recon_model, self.recon_model_config.infer_config,
                                rgb_multi_view, name=self.uuid)

        return vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo

    def generate_reference_3D_bundle_image_zero123(self, image, save_intermediate_results=True):
        """
        input: image, PIL.Image
        return: ref_3D_bundle_image, Tensor of shape (1, 3, 1024, 2048)
        """
        mv_image = self.generate_multiview(image)

        if save_intermediate_results:
            mv_image.save(os.path.join(TMP_DIR, f'{self.uuid}_mv_image.png'))

        vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo = self.reconstruct_from_multiview(mv_image)

        # tile the four RGB renders (top row) and four normal maps (bottom row) into a 2x4 bundle image
        ref_3D_bundle_image = torchvision.utils.make_grid(torch.cat([lrm_multi_view_rgb.cpu(), (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0).unsqueeze(0)

        if save_intermediate_results:
            save_path = os.path.join(TMP_DIR, f'{self.uuid}_ref_3d_bundle_image.png')
            torchvision.utils.save_image(ref_3D_bundle_image, save_path)
            logger.info(f"Save reference 3D bundle image to {save_path}")
            return ref_3D_bundle_image, save_path

        return ref_3D_bundle_image

    def generate_3d_bundle_image_controlnet(self,
                                            prompt,
                                            image=None,
                                            strength=1.0,
                                            control_image=[],
                                            control_mode=[],
                                            control_guidance_start=None,
                                            control_guidance_end=None,
                                            controlnet_conditioning_scale=None,
                                            lora_scale=1.0,
                                            save_intermediate_results=True,
                                            **kwargs):
        control_mode_dict = {
            'canny': 0,
            'tile': 1,
            'depth': 2,
            'blur': 3,
            'pose': 4,
            'gray': 5,
            'lq': 6,
        }

        flux_device = self.config['flux'].get('device', 'cpu')
        seed = self.config['flux'].get('seed', 0)

        generator = torch.Generator(device=flux_device).manual_seed(seed)

        hparam_dict = {
            'prompt': ' '.join(['A grid of 2x4 multi-view image, elevation 5. White background.', prompt]),
            # compare to None explicitly: `image or ...` raises on tensors, whose truth value is ambiguous
            'image': image if image is not None else torch.zeros((1, 3, 1024, 2048), dtype=torch.float32, device=flux_device),
            'strength': strength,
            'num_inference_steps': 30,
            'guidance_scale': 3.5,
            'num_images_per_prompt': 1,
            'width': 2048,
            'height': 1024,
            'output_type': 'np',
            'generator': generator,
            'joint_attention_kwargs': {"scale": lora_scale}
        }
        hparam_dict.update(kwargs)

        if len(control_image) > 0:
            assert isinstance(self.flux_pipeline, FluxControlNetImg2ImgPipeline)
            assert len(control_mode) == len(control_image)

            # replicate the single loaded ControlNet so each control image gets its own copy
            flux_ctrl_net = self.flux_pipeline.controlnet.nets[0]
            self.flux_pipeline.controlnet = FluxMultiControlNetModel([flux_ctrl_net for i in range(len(control_image))])

            ctrl_hparams = {
                'control_mode': [control_mode_dict[mode_] for mode_ in control_mode],
                'control_image': control_image,
                'control_guidance_start': control_guidance_start or [0.0 for i in range(len(control_image))],
                'control_guidance_end': control_guidance_end or [1.0 for i in range(len(control_image))],
                'controlnet_conditioning_scale': controlnet_conditioning_scale or [1.0 for i in range(len(control_image))],
            }

            hparam_dict.update(ctrl_hparams)

        with self.context():
            gen_3d_bundle_image = self.flux_pipeline(**hparam_dict).images

        gen_3d_bundle_image_ = torch.from_numpy(gen_3d_bundle_image).squeeze(0).permute(2, 0, 1).contiguous().float()

        if save_intermediate_results:
            save_path = os.path.join(TMP_DIR, f'{self.uuid}_gen_3d_bundle_image.png')
            torchvision.utils.save_image(gen_3d_bundle_image_, save_path)
            logger.info(f"Save generated 3D bundle image to {save_path}")
            return gen_3d_bundle_image_, save_path

        return gen_3d_bundle_image_

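    # Illustrative call (inputs hypothetical; requires the ControlNet pipeline from
    # the config): refine a reference bundle image with the tile control mode while
    # reusing its caption as the prompt:
    #
    #   gen_image, path = wrapper.generate_3d_bundle_image_controlnet(
    #       prompt=caption,
    #       image=ref_3d_bundle_image,
    #       strength=0.95,
    #       control_image=[ref_3d_bundle_image],
    #       control_mode=['tile'],
    #   )
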
    def generate_3d_bundle_image_text(self,
                                      prompt,
                                      image=None,
                                      strength=1.0,
                                      lora_scale=1.0,
                                      num_inference_steps=30,
                                      save_intermediate_results=True,
                                      **kwargs):
        """
        return: gen_3d_bundle_image, torch.Tensor of shape (3, 1024, 2048), range [0., 1.]
        """
        if isinstance(self.flux_pipeline, FluxControlNetImg2ImgPipeline):
            # reuse the already-loaded components in a plain img2img pipeline, dropping the ControlNet
            flux_pipeline = FluxImg2ImgPipeline(
                scheduler=self.flux_pipeline.scheduler,
                vae=self.flux_pipeline.vae,
                text_encoder=self.flux_pipeline.text_encoder,
                tokenizer=self.flux_pipeline.tokenizer,
                text_encoder_2=self.flux_pipeline.text_encoder_2,
                tokenizer_2=self.flux_pipeline.tokenizer_2,
                transformer=self.flux_pipeline.transformer
            )
        else:
            flux_pipeline = self.flux_pipeline

        flux_device = self.config['flux'].get('device', 'cpu')
        seed = self.config['flux'].get('seed', 0)

        generator = torch.Generator(device=flux_device).manual_seed(seed)

        hparam_dict = {
            'prompt': ' '.join(['A grid of 2x4 multi-view image, elevation 5. White background.', prompt]),
            'image': image if image is not None else torch.zeros((1, 3, 1024, 2048), dtype=torch.float32, device=flux_device),
            'strength': strength,
            'num_inference_steps': num_inference_steps,
            'guidance_scale': 3.5,
            'num_images_per_prompt': 1,
            'width': 2048,
            'height': 1024,
            'output_type': 'np',
            'generator': generator,
            'joint_attention_kwargs': {"scale": lora_scale}
        }
        hparam_dict.update(kwargs)

        with self.context():
            gen_3d_bundle_image = flux_pipeline(**hparam_dict).images

        gen_3d_bundle_image_ = torch.from_numpy(gen_3d_bundle_image).squeeze(0).permute(2, 0, 1).contiguous().float()

        if save_intermediate_results:
            save_path = os.path.join(TMP_DIR, f'{self.uuid}_gen_3d_bundle_image.png')
            torchvision.utils.save_image(gen_3d_bundle_image_, save_path)
            logger.info(f"Save generated 3D bundle image to {save_path}")
            return gen_3d_bundle_image_, save_path

        return gen_3d_bundle_image_

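    # Illustrative call (prompt is a placeholder):
    #
    #   bundle, path = wrapper.generate_3d_bundle_image_text(prompt='a wooden chair')
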
    def reconstruct_3d_bundle_image(self, image, save_intermediate_results=True):
        """
        image: torch.Tensor, range [0., 1.], (3, 1024, 2048)
        """
        recon_device = self.config['reconstruction'].get('device', 'cpu')

        # split the 2x4 bundle into 8 tiles: top row is RGB views, bottom row is normal maps
        images = rearrange(image, 'c (n h) (m w) -> (n m) c h w', n=2, m=4)
        rgb_multi_view, normal_multi_view = images.chunk(2, dim=0)
        multi_view_mask = get_background(normal_multi_view).to(recon_device)
        rgb_multi_view = rgb_multi_view.to(recon_device) * multi_view_mask + (1 - multi_view_mask)

        with self.context():
            vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo = \
                lrm_reconstruct(self.recon_model, self.recon_model_config.infer_config,
                                rgb_multi_view.unsqueeze(0).to(recon_device), name=self.uuid,
                                input_camera_type='kiss3d', render_3d_bundle_image=save_intermediate_results,
                                render_azimuths=[0, 90, 180, 270])

        if save_intermediate_results:
            recon_3D_bundle_image = torchvision.utils.make_grid(torch.cat([lrm_multi_view_rgb.cpu(), (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0).unsqueeze(0)
            torchvision.utils.save_image(recon_3D_bundle_image, os.path.join(TMP_DIR, f'{self.uuid}_lrm_recon_3d_bundle_image.png'))

        recon_mesh_path = os.path.join(TMP_DIR, f"{self.uuid}_isomer_recon_mesh.obj")

        return isomer_reconstruct(rgb_multi_view=rgb_multi_view,
                                  normal_multi_view=normal_multi_view,
                                  multi_view_mask=multi_view_mask,
                                  vertices=vertices,
                                  faces=faces,
                                  save_path=recon_mesh_path)

def run_text_to_3d(k3d_wrapper,
                   prompt,
                   init_image_path=None):
    # fresh uuid so this run's intermediate files don't collide with earlier runs
    k3d_wrapper.renew_uuid()

    init_image = None
    if init_image_path is not None:
        init_image = Image.open(init_image_path)

    gen_3d_bundle_image, gen_save_path = k3d_wrapper.generate_3d_bundle_image_text(prompt,
                                                                                   image=init_image,
                                                                                   strength=1.0,
                                                                                   save_intermediate_results=True)

    recon_mesh_path = k3d_wrapper.reconstruct_3d_bundle_image(gen_3d_bundle_image, save_intermediate_results=False)

    return gen_save_path, recon_mesh_path

def run_image_to_3d(k3d_wrapper, init_image_path):
    # fresh uuid so this run's intermediate files don't collide with earlier runs
    k3d_wrapper.renew_uuid()

    input_image = Image.open(init_image_path)
    reference_3d_bundle_image, reference_save_path = k3d_wrapper.generate_reference_3D_bundle_image_zero123(input_image)
    caption = k3d_wrapper.get_image_caption(input_image)

    # return the saved reference bundle path and the caption for downstream use
    return reference_save_path, caption

if __name__ == "__main__": |
|
k3d_wrapper = init_wrapper_from_config('/hpc2hdd/home/jlin695/code/Kiss3DGen/pipeline/pipeline_config/default.yaml') |
|
|
|
|
|
|
|
|
|
|
|
run_image_to_3d(k3d_wrapper, '/hpc2hdd/home/jlin695/code/Kiss3DGen/examples/蓝色小怪物.webp') |
|
|
|
|
|
|
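    # The text-to-3D entry point can be exercised the same way (prompt illustrative):
    # run_text_to_3d(k3d_wrapper, prompt='a blue cartoon monster')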