VistaDream / app.py
hpwang's picture
Update app.py
d47a7e5 verified
raw
history blame
3.99 kB
import os
import torch
import gradio as gr
from PIL import Image
from pipe.cfgs import load_cfg
from pipe.c2f_recons import Pipeline
from ops.gs.basic import Gaussian_Scene
from datetime import datetime
# Build one process-wide pipeline at start-up from the basic config; every
# request handled below reuses (and mutates) this single instance.
cfg = load_cfg('pipe/cfgs/basic.yaml')
vistadream = Pipeline(cfg)
# NOTE(review): the module-level ``checkor`` is not referenced in the visible
# code (render_video uses ``vistadream.checkor``); kept in case constructing
# Check() has side effects that are relied on elsewhere.
from ops.visual_check import Check
checkor = Check()
def get_temp_path():
    """Create and return a fresh, timestamped output directory.

    The directory lives under ``data/gradio_temp/`` and is named after the
    current local time formatted as ``%Y%m%d_%H%M%S``. The returned string
    ends with a slash so callers can append file names directly. Two calls
    within the same second return (and share) the same directory.

    Returns:
        str: A path of the form ``data/gradio_temp/<timestamp>/`` that exists
        on disk when this function returns.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_path = f"data/gradio_temp/{timestamp}/"
    # Create the timestamped directory itself (parents included) so callers
    # can write into it immediately; exist_ok avoids a check-then-create race.
    os.makedirs(output_path, exist_ok=True)
    return output_path
def scene_generate(rgb,num_coarse_views,num_mcs_views,mcs_rect_w,mcs_steps):
    """Generate and refine a Gaussian scene from a single input image.

    Resets the module-level ``vistadream`` pipeline with a fresh
    ``Gaussian_Scene``, configures it from the UI options, runs the coarse
    stage followed by MCS refinement, and saves the resulting scene with
    ``torch.save`` to ``<output_path>scene.pth``.

    Args:
        rgb: Input image (a PIL image, as delivered by ``gr.Image(type="pil")``).
        num_coarse_views: Stored as ``vistadream.n_sample`` for the spiral
            trajectory used by the coarse stage.
        num_mcs_views: Stored as ``vistadream.mcs_n_view`` for refinement.
        mcs_rect_w: Stored as ``vistadream.mcs_rect_w`` (rectification weight).
        mcs_steps: Stored as ``vistadream.mcs_iterations``.

    Returns:
        str: The output directory (ending in a slash) containing ``scene.pth``.
    """
    # Start each request from an empty scene; the pipeline object is shared.
    vistadream.scene = Gaussian_Scene(cfg)
    # For trajectory generation.
    vistadream.traj_type = 'spiral'
    vistadream.scene.traj_type = 'spiral'
    # NOTE(review): slider values may arrive as floats depending on the Gradio
    # version; count-like options are cast to int since they are used as counts.
    vistadream.n_sample = int(num_coarse_views)
    # For scene generation.
    vistadream.opt_iters_per_frame = 512
    vistadream.outpaint_extend_times = 0.45
    vistadream.outpaint_selections = ['Left','Right','Top','Bottom']
    # For scene refinement.
    vistadream.mcs_n_view = int(num_mcs_views)
    vistadream.mcs_rect_w = mcs_rect_w
    vistadream.mcs_iterations = int(mcs_steps)
    # Coarse scene.
    vistadream._coarse_scene(rgb)
    torch.cuda.empty_cache()
    # Refinement.
    vistadream._MCS_Refinement()
    output_path = get_temp_path()
    # torch.save does not create parent directories, so make sure the
    # timestamped directory exists before writing into it.
    os.makedirs(output_path, exist_ok=True)
    torch.cuda.empty_cache()
    torch.save(vistadream.scene,output_path+'scene.pth')
    return output_path
def render_video(output_path):
    """Render the pipeline's current scene into ``output_path``.

    Delegates the rendering to ``vistadream.checkor._render_video`` and
    returns the two video paths expected inside that directory.

    Args:
        output_path: Directory path ending in a slash (as returned by
            ``get_temp_path``).

    Returns:
        tuple[str, str]: ``(<output_path>video_rgb.mp4,
        <output_path>video_dpt.mp4)``.
    """
    current_scene = vistadream.scene
    # Uses the pipeline's own checkor, not the module-level ``checkor``.
    vistadream.checkor._render_video(current_scene, save_dir=output_path + '.')
    # NOTE(review): these file names are assumed to match what _render_video
    # writes into save_dir; confirm against ops.visual_check.
    rgb_file = output_path + 'video_rgb.mp4'
    dpt_file = output_path + 'video_dpt.mp4'
    return rgb_file, dpt_file
def process(rgb,num_coarse_views,num_mcs_views,mcs_rect_w,mcs_steps):
    """Gradio click handler: build the scene, keep the input, render videos.

    Args:
        rgb: Input PIL image from the ``gr.Image`` component, or None when the
            user has not uploaded anything.
        num_coarse_views, num_mcs_views, mcs_rect_w, mcs_steps: Values of the
            "Advanced options" sliders, forwarded to ``scene_generate``.

    Returns:
        tuple[str, str]: Paths of the RGB and DPT videos from ``render_video``.

    Raises:
        gr.Error: If no input image was provided.
    """
    # Fail fast with a user-visible message instead of crashing deep inside
    # the pipeline on a None image.
    if rgb is None:
        raise gr.Error("Please upload an input image.")
    path = scene_generate(rgb,num_coarse_views,num_mcs_views,mcs_rect_w,mcs_steps)
    # Store the input image next to the generated scene; ``path`` is the
    # output directory returned by scene_generate.
    rgb.save(path+'input.png')
    return render_video(path)
# Gradio UI: the input image and advanced options are on the left, and the
# rendered videos are on the right. "Run" calls process() with ``ips``.
with gr.Blocks(analytics_enabled=False) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("## VistaDream")
        gr.Markdown("### Sampling multiview consistent images for single-view scene reconstruction")
        gr.HTML("""
<div style="display:flex;column-gap:4px;">
<a href="https://github.com/WHU-USI3DV/VistaDream">
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
</a>
<a href="https://vistadream-project-page.github.io/">
<img src='https://img.shields.io/badge/Project-Page-green'>
</a>
<a href="https://arxiv.org/abs/2410.16892">
<img src='https://img.shields.io/badge/ArXiv-Paper-red'>
</a>
</div>
""")
        with gr.Row():
            # Left column: input image, run button and tuning sliders.
            with gr.Column():
                input_image = gr.Image(type="pil")
                run_button = gr.Button("Run")
                with gr.Accordion("Advanced options", open=False):
                    num_coarse_views = gr.Slider(label="Coarse-Expand", minimum=5, maximum=25, value=10, step=1)
                    num_mcs_views = gr.Slider(label="MCS Optimization Views", minimum=4, maximum=10, value=8, step=1)
                    mcs_rect_w = gr.Slider(label="MCS Rectification Weight", minimum=0.3, maximum=0.8, value=0.7, step=0.1)
                    mcs_steps = gr.Slider(label="MCS Steps", minimum=8, maximum=15, value=10, step=1)
            # Right column: the two rendered videos side by side.
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        # gr.Video's first positional parameter is ``value`` (the
                        # initial video), so the caption must be passed as ``label``.
                        rgb_video = gr.Video(label="Output RGB renderings")
                    with gr.Column():
                        dpt_video = gr.Video(label="Output DPT renderings")
        # NOTE(review): the example list is empty, so presumably no examples
        # are shown; populate it or drop this component.
        examples = gr.Examples(
            examples = [
            ],
            inputs=[input_image,rgb_video,dpt_video]
        )
    # Component values are passed to process() in this order.
    ips = [input_image,num_coarse_views,num_mcs_views,mcs_rect_w,mcs_steps]
    run_button.click(fn=process, inputs=ips, outputs=[rgb_video,dpt_video])
demo.launch(server_name='0.0.0.0')