File size: 4,122 Bytes
17a538f
fd5e0f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17a538f
fd5e0f7
ff37255
fd5e0f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff37255
 
4122069
fd5e0f7
 
 
 
 
 
 
 
 
 
 
 
 
ff37255
 
d47a7e5
fd5e0f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import spaces
import os
import torch
import gradio as gr
from PIL import Image
from pipe.cfgs import load_cfg
from pipe.c2f_recons import Pipeline
from ops.gs.basic import Gaussian_Scene
from datetime import datetime

# Build the VistaDream pipeline once at import time; every Gradio callback
# below mutates and reuses this single module-level instance.
cfg = load_cfg('pipe/cfgs/basic.yaml')
vistadream = Pipeline(cfg)

from ops.visual_check import Check
# NOTE(review): `checkor` is not referenced by the functions in this file;
# render_video() uses `vistadream.checkor` instead. Kept in case other code
# imports it from this module.
checkor = Check()

def get_temp_path(root='data/gradio_temp'):
    """Create and return a fresh, timestamped output directory.

    The returned path always ends with '/' because callers build file paths
    by plain string concatenation (e.g. ``output_path + 'scene.pth'``).

    Args:
        root: Parent directory for all Gradio outputs. Defaults to
            ``'data/gradio_temp'``, the location used previously.

    Returns:
        ``f"{root}/{YYYYmmdd_HHMMSS}/"``; the directory is guaranteed to exist.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_path = f"{root}/{timestamp}/"
    # Create the timestamped leaf itself, not just `root`: torch.save() does
    # not create missing parent directories, so writing `scene.pth` into a
    # non-existent folder would fail.
    os.makedirs(output_path, exist_ok=True)
    return output_path

@spaces.GPU(duration=120)
def scene_generate(rgb,num_coarse_views,num_mcs_views,mcs_rect_w,mcs_steps):
    """Configure the shared pipeline and build the coarse scene from ``rgb``.

    Runs under ``@spaces.GPU(duration=120)`` and mutates the module-level
    ``vistadream`` pipeline in place; nothing is returned.

    Args:
        rgb: Input image passed to ``vistadream._coarse_scene`` (the UI
            supplies a PIL image via ``gr.Image(type="pil")``).
        num_coarse_views: Number of trajectory samples for the coarse stage.
        num_mcs_views: Number of views for MCS refinement.
        mcs_rect_w: MCS rectification weight.
        mcs_steps: Number of MCS iterations.
    """
    torch.cuda.init()
    try:
        # Start from a fresh Gaussian scene so state from a previous request
        # is not carried into this one.
        vistadream.scene = Gaussian_Scene(cfg)
        # Trajectory generation.
        vistadream.traj_type = 'spiral'
        vistadream.scene.traj_type = 'spiral'
        # Slider values may arrive as floats (e.g. from API clients); counts
        # must be ints to be usable as loop/range bounds.
        vistadream.n_sample = int(num_coarse_views)
        # Coarse scene generation (fixed settings, not exposed in the UI).
        vistadream.opt_iters_per_frame = 512
        vistadream.outpaint_extend_times = 0.45
        vistadream.outpaint_selections = ['Left','Right','Top','Bottom']
        # Refinement settings, presumably read later by
        # vistadream._MCS_Refinement() in scene_refinement() -- TODO confirm.
        vistadream.mcs_n_view = int(num_mcs_views)
        vistadream.mcs_rect_w = float(mcs_rect_w)
        vistadream.mcs_iterations = int(mcs_steps)
        vistadream._coarse_scene(rgb)
    finally:
        # Release cached GPU memory even if coarse generation fails.
        torch.cuda.empty_cache()

@spaces.GPU(duration=120)
def scene_refinement():
    """Run MCS refinement on the current scene and save it to disk.

    Returns:
        The output directory (ending in '/') that now contains ``scene.pth``.
    """
    vistadream._MCS_Refinement()
    output_path = get_temp_path()
    # torch.save() does not create missing directories, so make sure the
    # timestamped output folder exists before writing into it.
    os.makedirs(output_path, exist_ok=True)
    torch.cuda.empty_cache()
    torch.save(vistadream.scene,output_path+'scene.pth')
    return output_path

def render_video(output_path):
    """Render the current scene to video files inside ``output_path``.

    Args:
        output_path: Output directory path ending in '/', as returned by
            ``scene_refinement()``.

    Returns:
        ``(rgb_video_path, dpt_video_path)``, the paths where
        ``_render_video`` is expected to have written the RGB and DPT
        (presumably depth) renderings.
    """
    renderer = vistadream.checkor
    current_scene = vistadream.scene
    target_dir = output_path + '.'
    renderer._render_video(current_scene, save_dir=target_dir)
    rgb_file = output_path + 'video_rgb.mp4'
    dpt_file = output_path + 'video_dpt.mp4'
    return rgb_file, dpt_file

def process(rgb,num_coarse_views,num_mcs_views,mcs_rect_w,mcs_steps):
    """Gradio callback for the Run button: generate, refine, then render.

    Args:
        rgb: Input PIL image from ``gr.Image(type="pil")``; ``None`` when the
            user has not uploaded anything.
        num_coarse_views, num_mcs_views, mcs_rect_w, mcs_steps: Slider values
            forwarded unchanged to ``scene_generate``.

    Returns:
        ``(rgb_video_path, dpt_video_path)`` for the two output video widgets.

    Raises:
        gr.Error: If no input image was provided.
    """
    if rgb is None:
        # Surface a readable message in the UI instead of failing inside the pipeline.
        raise gr.Error("Please upload an input image before clicking Run.")
    scene_generate(rgb,num_coarse_views,num_mcs_views,mcs_rect_w,mcs_steps)
    path = scene_refinement()
    # Keep a copy of the input next to the reconstruction outputs. This must
    # use `path`; the old code referenced the undefined name `output_path`.
    rgb.save(path+'input.png')
    return render_video(path)

# Gradio UI: input image + advanced sliders on the left, RGB/DPT video
# renderings on the right. The Run button wires everything to process().
with gr.Blocks(analytics_enabled=False) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("## VistaDream")
        gr.Markdown("### Sampling multiview consistent images for single-view scene reconstruction")
        gr.HTML("""
        <div style="display:flex;column-gap:4px;">
            <a href="https://github.com/WHU-USI3DV/VistaDream">
                <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
            </a> 
            <a href="https://vistadream-project-page.github.io/">
                <img src='https://img.shields.io/badge/Project-Page-green'>
            </a>
			<a href="https://arxiv.org/abs/2410.16892">
                <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
            </a>
        </div>
        """)

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil")
                run_button = gr.Button("Run")
                with gr.Accordion("Advanced options", open=False):
                    num_coarse_views = gr.Slider(label="Coarse-Expand", minimum=5, maximum=25, value=10, step=1)
                    num_mcs_views = gr.Slider(label="MCS Optimization Views", minimum=4, maximum=10, value=8, step=1)
                    mcs_rect_w = gr.Slider(label="MCS Rectification Weight", minimum=0.3, maximum=0.8, value=0.7, step=0.1)
                    mcs_steps = gr.Slider(label="MCS Steps", minimum=8, maximum=15, value=10, step=1)
            with gr.Column():
                with gr.Row():
                    # The caption must be passed as `label=`: the first
                    # positional parameter of gr.Video is `value` (the video
                    # source), so passing the caption positionally made Gradio
                    # treat it as a video path.
                    with gr.Column():
                        rgb_video = gr.Video(label="Output RGB renderings")
                    with gr.Column():
                        dpt_video = gr.Video(label="Output DPT renderings")
                # NOTE(review): the example list is empty here; each entry
                # would need one value per component in `inputs`.
                examples = gr.Examples(
                    examples=[
                    ],
                    inputs=[input_image,rgb_video,dpt_video]
                )
    ips = [input_image,num_coarse_views,num_mcs_views,mcs_rect_w,mcs_steps]
    run_button.click(fn=process, inputs=ips, outputs=[rgb_video,dpt_video])

# Bind to all interfaces so the app is reachable from outside the container.
demo.launch(server_name='0.0.0.0')