# ashawkey's picture
# init
# daa6779
# raw
# history blame
# 6.64 kB
"""
-----------------------------------------------------------------------------
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
-----------------------------------------------------------------------------
"""
import argparse
import glob
import importlib
import os
from datetime import datetime
import cv2
import kiui
import numpy as np
import rembg
import torch
import trimesh
from flow.model import Model
from flow.utils import get_random_color, recenter_foreground
from vae.utils import postprocess_mesh
# Usage (from the repository root): PYTHONPATH=. python flow/scripts/infer.py
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    type=str,
    help="config module (dotted import path, must define make_config())",
    default="flow.configs.big_parts_strict_pvae",
)
parser.add_argument(
    "--ckpt_path",
    type=str,
    help="checkpoint path",
    default="pretrained/flow.pt",
)
parser.add_argument("--input", type=str, help="input image file or directory of images", default="assets/images/")
parser.add_argument("--limit", type=int, help="limit number of images (directory input only; <= 0 means no limit)", default=-1)
parser.add_argument("--output_dir", type=str, help="output directory", default="output/")
parser.add_argument("--grid_res", type=int, help="grid resolution", default=384)
parser.add_argument("--num_steps", type=int, help="number of sampling steps", default=30)
parser.add_argument("--cfg_scale", type=float, help="cfg scale", default=7.0)
parser.add_argument("--num_repeats", type=int, help="number of repeats per image", default=1)
parser.add_argument("--seed", type=int, help="seed", default=42)
args = parser.parse_args()

# Axis permutation applied as `vertices @ TRIMESH_GLB_EXPORT.T`, which maps each
# vertex (x, y, z) to (y, z, x) before the meshes are exported as .glb.
TRIMESH_GLB_EXPORT = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]).astype(np.float32)

# rembg session created once and reused for every input image lacking an alpha channel.
bg_remover = rembg.new_session()
def preprocess_image(path):
    """Load an image file and turn it into the model's conditioning image.

    Steps: read as uint8 (RGBA order), run rembg background removal when the
    image has no alpha channel, recenter the foreground (alpha > 0) with a 10%
    border, resize to 518x518, and composite the RGB onto a white background.

    Args:
        path: path to an image file readable by ``kiui.read_image``.

    Returns:
        float32 array of shape (518, 518, 3). Values lie in [0, 1] assuming
        ``recenter_foreground`` returns uint8 RGBA (TODO confirm).
    """
    input_image = kiui.read_image(path, mode="uint8", order="RGBA")
    # NOTE(review): a grayscale file can come back as a 2-D (H, W) array. Expand
    # it to 3 channels so it is treated like an RGB image below. Otherwise
    # `[..., -1]` would slice image columns instead of an alpha channel.
    if input_image.ndim == 2:
        input_image = np.stack([input_image] * 3, axis=-1)
    # bg removal if there is no alpha channel
    if input_image.shape[-1] == 3:
        input_image = rembg.remove(input_image, session=bg_remover)  # [H, W, 4]
    mask = input_image[..., -1] > 0
    image = recenter_foreground(input_image, mask, border_ratio=0.1)
    image = cv2.resize(image, (518, 518), interpolation=cv2.INTER_LINEAR)
    image = image.astype(np.float32) / 255.0
    # alpha-composite onto white: rgb * a + 1 * (1 - a)
    image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])
    return image
print(f"Loading checkpoint from {args.ckpt_path}")
# Load onto the CPU. load_state_dict below copies the weights into the model's
# own GPU/bf16 parameters, so a second GPU-resident copy of the raw checkpoint
# would only waste device memory.
ckpt_dict = torch.load(args.ckpt_path, map_location="cpu", weights_only=True)
# if the checkpoint nests the weights under a "model" key, keep only that state dict
if "model" in ckpt_dict:
    ckpt_dict = ckpt_dict["model"]

# instantiate model: the config module must expose make_config()
print(f"Instantiating model from {args.config}")
model_config = importlib.import_module(args.config).make_config()
model = Model(model_config).eval().cuda().bfloat16()

# load weight (strict: every key must match the model exactly)
print(f"Loading weights from {args.ckpt_path}")
model.load_state_dict(ckpt_dict, strict=True)
# output folder: one directory per run, named after the config module and a timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
workspace = os.path.join(args.output_dir, "flow_" + args.config.split(".")[-1] + "_" + timestamp)
if os.path.isdir(workspace):
    # Clear stale top-level files left by a run with the same timestamp. This is
    # done in Python rather than a shell `rm`, so paths with spaces or shell
    # metacharacters are handled safely. Subdirectories are left in place.
    for stale in glob.glob(os.path.join(glob.escape(workspace), "*")):
        if os.path.isfile(stale) or os.path.islink(stale):
            os.remove(stale)
os.makedirs(workspace, exist_ok=True)
print(f"Output directory: {workspace}")
# load test images: every regular file in a directory (sorted, optionally
# truncated to --limit), or the single file given as --input
if os.path.isdir(args.input):
    # glob.escape keeps directory names containing [ ] * ? literal. Only regular
    # files are kept, because subdirectories are not images and preprocess_image
    # cannot read them.
    candidates = glob.glob(os.path.join(glob.escape(args.input), "*"))
    paths = sorted(p for p in candidates if os.path.isfile(p))
    if args.limit > 0:
        paths = paths[: args.limit]
else:  # single file
    paths = [args.input]
def _decoded_to_glb_mesh(decoded):
    """Build one postprocessed mesh from a ``model.vae`` output.

    Takes the first ``(vertices, faces)`` pair of ``decoded["meshes"]`` and
    permutes the vertex axes with TRIMESH_GLB_EXPORT. It then returns the
    result of ``postprocess_mesh(mesh, 5e4)``. The permutation happens before
    postprocessing, as required for each volume of the dual-volume path.
    """
    vertices, faces = decoded["meshes"][0]
    mesh = trimesh.Trimesh(vertices, faces)
    mesh.vertices = mesh.vertices @ TRIMESH_GLB_EXPORT.T
    return postprocess_mesh(mesh, 5e4)


for path in paths:
    name = os.path.splitext(os.path.basename(path))[0]
    print(f"Processing {name}")
    image = preprocess_image(path)
    kiui.write_image(os.path.join(workspace, name + ".jpg"), image)
    # HWC numpy array -> 1 x C x H x W float tensor on the GPU
    image = torch.from_numpy(image).permute(2, 0, 1).contiguous().unsqueeze(0).float().cuda()
    # Image conditioning for the flow model. It is kept separate from the VAE
    # inputs below so that every repeat is conditioned on the image.
    cond_data = {"cond_images": image}
    for i in range(args.num_repeats):
        kiui.seed_everything(args.seed + i)
        with torch.inference_mode():
            results = model(cond_data, num_steps=args.num_steps, cfg_scale=args.cfg_scale)
        latent = results["latent"]
        out_prefix = os.path.join(workspace, f"{name}_{i}")

        if model.config.use_parts:
            # The first `latent_size` entries along dim 1 decode to volume 0
            # and the remainder to volume 1.
            split_at = model.config.latent_size
            with torch.inference_mode():
                decoded0 = model.vae({"latent": latent[:, :split_at, :]}, resolution=args.grid_res)
                decoded1 = model.vae({"latent": latent[:, split_at:, :]}, resolution=args.grid_res)
            # split each volume into connected components; each component is one part
            mesh_part0 = _decoded_to_glb_mesh(decoded0)
            parts = mesh_part0.split(only_watertight=False)
            mesh_part1 = _decoded_to_glb_mesh(decoded1)
            parts.extend(mesh_part1.split(only_watertight=False))
            for j, part in enumerate(parts):
                # each component uses a random color
                part.visual.vertex_colors = get_random_color(j, use_float=True)
            # export the whole mesh as one scene
            trimesh.Scene(parts).export(out_prefix + ".glb")
            # export each part
            for j, part in enumerate(parts):
                part.export(f"{out_prefix}_part{j}.glb")
            # export dual volumes
            mesh_part0.export(out_prefix + "_vol0.glb")
            mesh_part1.export(out_prefix + "_vol1.glb")
        else:
            with torch.inference_mode():
                decoded = model.vae({"latent": latent}, resolution=args.grid_res)
            vertices, faces = decoded["meshes"][0]
            mesh = postprocess_mesh(trimesh.Trimesh(vertices, faces), 5e4)
            # in this branch the axis permutation is applied after postprocessing
            mesh.vertices = mesh.vertices @ TRIMESH_GLB_EXPORT.T
            mesh.export(out_prefix + ".glb")