"""
-----------------------------------------------------------------------------
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
-----------------------------------------------------------------------------
"""

import argparse
import glob
import importlib
import os
from datetime import datetime

import cv2
import kiui
import numpy as np
import rembg
import torch
import trimesh

from flow.model import Model
from flow.utils import get_random_color, recenter_foreground
from vae.utils import postprocess_mesh

# Usage: PYTHONPATH=. python flow/scripts/infer.py
#
# For every input image: preprocess it, sample latent(s) with the flow model,
# decode them to meshes with the model's VAE and export the results as GLB files.

parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    type=str,
    help="config file path",
    default="flow.configs.big_parts_strict_pvae",
)
parser.add_argument(
    "--ckpt_path",
    type=str,
    help="checkpoint path",
    default="pretrained/flow.pt",
)
parser.add_argument("--input", type=str, help="input directory", default="assets/images/")
parser.add_argument("--limit", type=int, help="limit number of images", default=-1)
parser.add_argument("--output_dir", type=str, help="output directory", default="output/")
parser.add_argument("--grid_res", type=int, help="grid resolution", default=384)
parser.add_argument("--num_steps", type=int, help="number of cfg steps", default=30)
parser.add_argument("--cfg_scale", type=float, help="cfg scale", default=7.0)
parser.add_argument("--num_repeats", type=int, help="number of repeats per image", default=1)
parser.add_argument("--seed", type=int, help="seed", default=42)
args = parser.parse_args()

# Vertex axis permutation applied before GLB export: (x, y, z) -> (y, z, x).
TRIMESH_GLB_EXPORT = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]).astype(np.float32)

# One rembg session, created once and reused for every image that needs background removal.
bg_remover = rembg.new_session()


def preprocess_image(path):
    """Load an image and turn it into the flow model's conditioning image.

    If the file has no alpha channel, the background is removed with rembg.
    The foreground is then recentered with ``recenter_foreground`` (border
    ratio 0.1), resized to 518x518 and alpha-composited onto white.

    Args:
        path: Path of the image file.

    Returns:
        float32 array of shape [518, 518, 3] with values in [0, 1].
    """
    input_image = kiui.read_image(path, mode="uint8", order="RGBA")

    # bg removal if there is no alpha channel
    if input_image.shape[-1] == 3:
        input_image = rembg.remove(input_image, session=bg_remover)  # [H, W, 4]

    mask = input_image[..., -1] > 0
    # assumes recenter_foreground keeps all 4 RGBA channels (alpha is used below)
    image = recenter_foreground(input_image, mask, border_ratio=0.1)
    image = cv2.resize(image, (518, 518), interpolation=cv2.INTER_LINEAR)
    image = image.astype(np.float32) / 255.0
    # composite RGB over a white background using alpha
    image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])
    return image


def _decode_latent(latent):
    """Decode ``latent`` with the model's VAE at ``--grid_res`` and return the first mesh.

    Builds its own input dict so the image-conditioning dict used for
    sampling is never overwritten (that broke ``--num_repeats > 1``).
    """
    with torch.inference_mode():
        results = model.vae({"latent": latent}, resolution=args.grid_res)
    vertices, faces = results["meshes"][0]
    return trimesh.Trimesh(vertices, faces)


def _to_glb_axes(mesh):
    """Permute ``mesh``'s vertex axes with TRIMESH_GLB_EXPORT (in place) and return it."""
    mesh.vertices = mesh.vertices @ TRIMESH_GLB_EXPORT.T
    return mesh


print(f"Loading checkpoint from {args.ckpt_path}")
# Load onto the CPU so this does not depend on the device the checkpoint was saved
# from; load_state_dict copies the weights onto the model's device afterwards.
ckpt_dict = torch.load(args.ckpt_path, map_location="cpu", weights_only=True)
# unwrap the state dict if it is stored under the "model" key
if "model" in ckpt_dict:
    ckpt_dict = ckpt_dict["model"]

# instantiate model
print(f"Instantiating model from {args.config}")
model_config = importlib.import_module(args.config).make_config()
model = Model(model_config).eval().cuda().bfloat16()

# load weight
print(f"Loading weights from {args.ckpt_path}")
model.load_state_dict(ckpt_dict, strict=True)

# output folder
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
workspace = os.path.join(args.output_dir, f"flow_{args.config.split('.')[-1]}_{timestamp}")
os.makedirs(workspace, exist_ok=True)
# The timestamped directory normally only pre-exists on a rerun within the same
# second; clear its files without going through a shell (paths may contain spaces).
for stale in glob.glob(os.path.join(workspace, "*")):
    if os.path.isfile(stale):
        os.remove(stale)
print(f"Output directory: {workspace}")

# load test images
if os.path.isdir(args.input):
    # only regular files are candidate images; skip sub-directories
    paths = sorted(p for p in glob.glob(os.path.join(args.input, "*")) if os.path.isfile(p))
    if args.limit > 0:
        paths = paths[: args.limit]
else:  # single file
    paths = [args.input]

for path in paths:
    name = os.path.splitext(os.path.basename(path))[0]
    print(f"Processing {name}")

    image = preprocess_image(path)
    kiui.write_image(os.path.join(workspace, name + ".jpg"), image)
    # [H, W, 3] -> [1, 3, H, W] float32 on the GPU
    image = torch.from_numpy(image).permute(2, 0, 1).contiguous().unsqueeze(0).float().cuda()

    # conditioning input for the flow model; never reassigned inside the repeat loop
    data = {"cond_images": image}

    for i in range(args.num_repeats):
        kiui.seed_everything(args.seed + i)
        prefix = os.path.join(workspace, f"{name}_{i}")

        # run model
        with torch.inference_mode():
            results = model(data, num_steps=args.num_steps, cfg_scale=args.cfg_scale)
        latent = results["latent"]

        # query mesh
        if model.config.use_parts:
            # the first latent_size entries along dim 1 decode to volume 0, the rest to volume 1
            latent_size = model.config.latent_size
            volumes = []
            parts = []
            for vol_latent in (latent[:, :latent_size, :], latent[:, latent_size:, :]):
                vol_mesh = postprocess_mesh(_to_glb_axes(_decode_latent(vol_latent)), 5e4)
                volumes.append(vol_mesh)
                # split into connected components; extend() accepts a list or an array
                parts.extend(vol_mesh.split(only_watertight=False))

            # each connected component uses a random color
            for j, part in enumerate(parts):
                part.visual.vertex_colors = get_random_color(j, use_float=True)

            # export the whole mesh
            trimesh.Scene(parts).export(prefix + ".glb")
            # export each part
            for j, part in enumerate(parts):
                part.export(f"{prefix}_part{j}.glb")
            # export dual volumes
            for k, vol_mesh in enumerate(volumes):
                vol_mesh.export(f"{prefix}_vol{k}.glb")
        else:
            mesh = postprocess_mesh(_decode_latent(latent), 5e4)
            _to_glb_axes(mesh).export(prefix + ".glb")