import os
import numpy as np
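# Build the remaining native dependencies at startup, before they are imported below.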
os.system("bash install.sh")
from omegaconf import OmegaConf
import torch
import rembg
import gradio as gr
from gradio_litmodel3d import LitModel3D
from dva.io import load_from_config
from dva.ray_marcher import RayMarcher
from dva.visualize import visualize_primvolume, visualize_video_primvolume
from inference import remove_background, resize_foreground, extract_texmesh
from models.diffusion import create_diffusion
from huggingface_hub import hf_hub_download
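# Fetch the pretrained diffusion (DiT) and VAE checkpoints from the Hugging Face Hub.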
ckpt_path = hf_hub_download(repo_id="frozenburning/3DTopia-XL", filename="model_sview_dit_fp16.pt")
vae_ckpt_path = hf_hub_download(repo_id="frozenburning/3DTopia-XL", filename="model_vae_fp16.pt")
GRADIO_PRIM_VIDEO_PATH = 'prim.mp4'
GRADIO_RGB_VIDEO_PATH = 'rgb.mp4'
GRADIO_MAT_VIDEO_PATH = 'mat.mp4'
GRADIO_GLB_PATH = 'pbr_mesh.glb'
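# Load the inference config and point it at the downloaded checkpoints.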
CONFIG_PATH = "./configs/inference_dit.yml"
config = OmegaConf.load(CONFIG_PATH)
config.checkpoint_path = ckpt_path
config.model.vae_checkpoint_path = vae_ckpt_path
# Build the generator (DiT), VAE, and image conditioner; load EMA and VAE weights.
model = load_from_config(config.model.generator)
state_dict = torch.load(config.checkpoint_path, map_location='cpu')
model.load_state_dict(state_dict['ema'])
vae = load_from_config(config.model.vae)
vae_state_dict = torch.load(config.model.vae_checkpoint_path, map_location='cpu')
vae.load_state_dict(vae_state_dict['model_state_dict'])
conditioner = load_from_config(config.model.conditioner)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vae = vae.to(device)
conditioner = conditioner.to(device)
model = model.to(device)
model.eval()
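# Sample with fp16 autocast enabled.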
amp = True
precision_dtype = torch.float16
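# Ray marcher used to render 256x256 previews of the primitive volumes.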
rm = RayMarcher(
256,
256,
**config.rm,
).to(device)
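# Use per-channel latent normalization statistics when the config provides them.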
perchannel_norm = False
if "latent_mean" in config.model:
latent_mean = torch.Tensor(config.model.latent_mean)[None, None, :].to(device)
latent_std = torch.Tensor(config.model.latent_std)[None, None, :].to(device)
assert latent_mean.shape[-1] == config.model.generator.in_channels
perchannel_norm = True
latent_nf = config.model.latent_nf
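# The sub-configs consumed above are stripped so that what remains in config.model
# instantiates the primitive-volume model (model_primx) used later for mesh export.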
config.diffusion.pop("timestep_respacing")
config.model.pop("vae")
config.model.pop("vae_checkpoint_path")
config.model.pop("conditioner")
config.model.pop("generator")
config.model.pop("latent_nf")
config.model.pop("latent_mean")
config.model.pop("latent_std")
model_primx = load_from_config(config.model)
# load rembg
rembg_session = rembg.new_session()
# background removal function
def background_remove_process(input_image):
input_image = remove_background(input_image, rembg_session)
input_image = resize_foreground(input_image, 0.85)
input_cond_preview_pil = input_image
raw_image = np.array(input_image)
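    # Keep RGB only and zero out background pixels, using the alpha channel as a mask.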
mask = (raw_image[..., -1][..., None] > 0) * 1
raw_image = raw_image[..., :3] * mask
    input_cond = torch.from_numpy(raw_image[None, ...]).to(device)
return gr.update(interactive=True), input_cond, input_cond_preview_pil
# process function
def process(input_cond, input_num_steps, input_seed=42, input_cfg=6.0):
# seed
torch.manual_seed(input_seed)
os.makedirs(config.output_dir, exist_ok=True)
output_rgb_video_path = os.path.join(config.output_dir, GRADIO_RGB_VIDEO_PATH)
output_prim_video_path = os.path.join(config.output_dir, GRADIO_PRIM_VIDEO_PATH)
output_mat_video_path = os.path.join(config.output_dir, GRADIO_MAT_VIDEO_PATH)
respacing = "ddim{}".format(input_num_steps)
diffusion = create_diffusion(timestep_respacing=respacing, **config.diffusion)
sample_fn = diffusion.ddim_sample_loop_progressive
fwd_fn = model.forward_with_cfg
    # only image conditioning is implemented; text conditioning will come later
    if input_cond is None:
        raise NotImplementedError("No condition provided; text-conditioned generation is not supported yet.")
with torch.no_grad():
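        # Template latent, used only for its per-primitive shape when reshaping for VAE decoding.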
latent = torch.randn(1, config.model.num_prims, 1, 4, 4, 4)
batch = {}
inf_bs = 1
        inf_x = torch.randn(inf_bs, config.model.num_prims, 68).to(device)  # 68 = 4 SRT + 64 latent feature channels per primitive
y = conditioner.encoder(input_cond)
model_kwargs = dict(y=y[:inf_bs, ...], precision_dtype=precision_dtype, enable_amp=amp)
if input_cfg >= 0:
model_kwargs['cfg_scale'] = input_cfg
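        # Run the progressive DDIM loop; the last yielded dict holds the final denoised sample.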
for samples in sample_fn(fwd_fn, inf_x.shape, inf_x, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device):
final_samples = samples
recon_param = final_samples["sample"].reshape(inf_bs, config.model.num_prims, -1)
if perchannel_norm:
recon_param = recon_param / latent_nf * latent_std + latent_mean
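        # Split into per-primitive SRT parameters (channel 0: scale, channels 1-3: translation) and latent features.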
recon_srt_param = recon_param[:, :, 0:4]
        recon_feat_param = recon_param[:, :, 4:]  # [inf_bs, num_prims, 64]
        recon_feat_param_list = []
        # decode one sample at a time to avoid running out of GPU memory
for inf_bidx in range(inf_bs):
if not perchannel_norm:
decoded = vae.decode(recon_feat_param[inf_bidx, ...].reshape(1*config.model.num_prims, *latent.shape[-4:]) / latent_nf)
else:
decoded = vae.decode(recon_feat_param[inf_bidx, ...].reshape(1*config.model.num_prims, *latent.shape[-4:]))
recon_feat_param_list.append(decoded.detach())
recon_feat_param = torch.concat(recon_feat_param_list, dim=0)
        # invert the fixed normalization applied during training (only when per-channel stats are absent)
if not perchannel_norm:
recon_srt_param[:, :, 0:1] = (recon_srt_param[:, :, 0:1] / 10) + 0.05
recon_feat_param[:, 0:1, ...] /= 5.
recon_feat_param[:, 1:, ...] = (recon_feat_param[:, 1:, ...] + 1) / 2.
recon_feat_param = recon_feat_param.reshape(inf_bs, config.model.num_prims, -1)
recon_param = torch.concat([recon_srt_param, recon_feat_param], dim=-1)
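        # Render the RGB, primitive, and material preview videos (15 frames) with the ray marcher.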
visualize_video_primvolume(config.output_dir, batch, recon_param, 15, rm, device)
prim_params = {'srt_param': recon_srt_param[0].detach().cpu(), 'feat_param': recon_feat_param[0].detach().cpu()}
return output_rgb_video_path, output_prim_video_path, output_mat_video_path, prim_params
def export_mesh(prim_params, uv_unwrap="Faster", remesh="No", mc_resolution=256):
# exporting GLB mesh
output_glb_path = os.path.join(config.output_dir, GRADIO_GLB_PATH)
if remesh == "No":
config.inference.remesh = False
elif remesh == "Yes":
config.inference.remesh = True
if uv_unwrap == "Faster":
config.inference.fast_unwrap = True
elif uv_unwrap == "Better":
config.inference.fast_unwrap = False
config.inference.mc_resolution = mc_resolution
config.inference.batch_size = 8192
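    # Load the sampled primitive parameters into the primitive model and extract a textured GLB.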
model_primx.load_state_dict(prim_params)
model_primx.to(device)
model_primx.eval()
with torch.no_grad():
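        # Slightly shrink primitive translations toward the origin (presumably to match the 0.85 foreground resize above).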
model_primx.srt_param[:, 1:4] *= 0.85
extract_texmesh(config.inference, model_primx, config.output_dir, device)
return output_glb_path, gr.update(visible=True), gr.update(interactive=True), gr.update(value="assets/hdri/metro_noord_1k.hdr")
# gradio UI
_TITLE = '''3DTopia-XL: Scaling High-quality 3D Asset Generation via Primitive Diffusion'''
_DESCRIPTION = '''
<div>
<a style="display:inline-block" href="https://3dtopia.github.io/3DTopia-XL/"><img src='https://img.shields.io/badge/public_website-8A2BE2'></a>
<a style="display:inline-block; margin-left: .5em" href="https://github.com/3DTopia/3DTopia-XL"><img src='https://img.shields.io/github/stars/3DTopia/3DTopia-XL?style=social'/></a>
</div>
* For now we offer the 1) **single image** conditioned model; we will release the 2) **multiview images** conditioned model and 3) **pure text** conditioned model in the future!
* If you find the output unsatisfying, try using **different seeds** or **more DDIM steps**!
'''
_DEV_DES = '''
* Please refer to our repo for instructions on running the Gradio demo [locally](https://github.com/3DTopia/3DTopia-XL?tab=readme-ov-file#gradio-demo) or via the [CLI test](https://github.com/3DTopia/3DTopia-XL?tab=readme-ov-file#cli-test).
'''
block = gr.Blocks(title=_TITLE).queue()
with block:
current_fg_state = gr.State()
prim_param_state = gr.State()
with gr.Row():
with gr.Column(scale=1):
gr.Markdown('# ' + _TITLE)
gr.Markdown(_DESCRIPTION)
with gr.Accordion("For Developers", open=False):
gr.Markdown(_DEV_DES)
with gr.Row(variant='panel'):
with gr.Column(scale=1):
with gr.Row():
# input image
input_image = gr.Image(label="image", type='pil')
# background removal
removal_previewer = gr.Image(label="Background Removal Preview", type='pil', interactive=False)
with gr.Row():
# inference steps
input_num_steps = gr.Radio(choices=[25, 50, 100, 200], label="DDIM steps", value=25, info="Larger for robustness but slower.")
                # CFG scale
                input_cfg = gr.Slider(label="CFG scale", minimum=0, maximum=15, step=0.5, value=6, info="Typical CFG values lie in the range 4-7.")
# random seed
                input_seed = gr.Slider(label="random seed", minimum=0, maximum=10000, step=1, value=42, info="Try a different seed if the result is unsatisfying; this is a generative model!")
with gr.Row():
input_mc_resolution = gr.Radio(choices=[128, 256], label="MC Resolution", value=128, info="Cube resolution for mesh extraction. Larger for better quality but slower.")
                input_remesh = gr.Radio(choices=["No", "Yes"], label="Remesh", value="No", info="Whether to remesh the extracted geometry.")
input_unwrap = gr.Radio(choices=["Faster", "Better"], label="UV Unwrap", value="Better", info="UV unwrapping algorithm. Trade-off between quality and speed.")
# gen button
with gr.Row():
button_gen = gr.Button(value="Generate", interactive=False)
export_glb_btn = gr.Button(value="Export Current GLB", interactive=False)
with gr.Column(scale=1):
with gr.Row():
# final video results
output_rgb_video = gr.Video(label="RGB")
output_prim_video = gr.Video(label="Primitives")
output_mat_video = gr.Video(label="Material")
with gr.Row():
# glb file
output_glb = LitModel3D(
label="3D GLB Model",
visible=True,
clear_color=[0.0, 0.0, 0.0, 0.0],
camera_position=(90, None, None),
tonemapping="aces",
contrast=1.0,
scale=1.0,
)
with gr.Column(visible=False, scale=1.0) as hdr_row:
gr.Markdown("""## HDR Environment Map
Select / Upload an HDR environment map to relight the 3D model.
""")
with gr.Row():
example_hdris = [
os.path.join("assets/hdri", f)
for f in os.listdir("assets/hdri")
]
hdr_illumination_file = gr.File(
label="HDR Envmap", file_types=[".hdr"], file_count="single"
)
hdr_illumination_example = gr.Examples(
examples=example_hdris,
inputs=hdr_illumination_file,
)
hdr_illumination_file.change(
lambda x: gr.update(env_map=x.name if x is not None else None),
inputs=hdr_illumination_file,
outputs=[output_glb],
)
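    # Event wiring: a new image runs background removal and enables Generate;
    # generation stores the primitive params, whose change triggers an automatic
    # GLB export; the button re-exports with the current mesh settings.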
input_image.change(background_remove_process, inputs=[input_image], outputs=[button_gen, current_fg_state, removal_previewer])
button_gen.click(process, inputs=[current_fg_state, input_num_steps, input_seed, input_cfg], outputs=[output_rgb_video, output_prim_video, output_mat_video, prim_param_state])
prim_param_state.change(export_mesh, inputs=[prim_param_state, input_unwrap, input_remesh, input_mc_resolution], outputs=[output_glb, hdr_row, export_glb_btn, hdr_illumination_file])
export_glb_btn.click(export_mesh, inputs=[prim_param_state, input_unwrap, input_remesh, input_mc_resolution], outputs=[output_glb, hdr_row, export_glb_btn, hdr_illumination_file])
    # Selecting an example populates input_image (no caching or run-on-click),
    # which triggers the background-removal handler above and enables Generate.
    gr.Examples(
        examples=[
            os.path.join("assets/examples", f)
            for f in os.listdir("assets/examples")
        ],
        inputs=[input_image],
        cache_examples=False,
        label='Single Image to 3D PBR Asset'
    )
block.launch(server_name="0.0.0.0", share=True)