File size: 9,505 Bytes
21c4e64
 
 
 
 
 
 
0f432df
21c4e64
e618667
 
ea60d75
4fb8c01
c06c03d
21c4e64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b601d28
c54a4cd
21c4e64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c54a4cd
21c4e64
 
 
20167fb
21c4e64
 
c54a4cd
21c4e64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e618667
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import gradio as gr
import os
from PIL import Image
import subprocess
from gradio_model4dgs import Model4DGS
import numpy
import hashlib
import shlex

import spaces


# Install the bundled diff_gaussian_rasterization wheel (built for CPython 3.10,
# linux x86_64) every time this module is loaded. pip's exit status is not
# checked, so a failed install is silent at this point.
subprocess.run(shlex.split("pip install wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl"))
# subprocess.run(shlex.split("pip install xformers==0.0.23 --no-deps --index-url https://download.pytorch.org/whl/cu118"))

# Fetch the LGM checkpoint from the Hugging Face Hub at import time. `ckpt_path`
# is later passed as `--resume` to lgm/infer.py in optimize_stage_1.
from huggingface_hub import hf_hub_download
ckpt_path = hf_hub_download(repo_id="ashawkey/LGM", filename="model_fp16_fixrot.safetensors")

# JavaScript snippet handed to gr.Blocks via its `js=` argument. It reloads the
# page with the `__theme=light` query parameter unless that parameter is
# already set to 'light', so the reload happens at most once.
js_func = """
function refresh() {
    const url = new URL(window.location);

    if (url.searchParams.get('__theme') !== 'light') {
        url.searchParams.set('__theme', 'light');
        window.location.href = url.href;
    }
}
"""

# check if there is a picture uploaded or selected
def check_img_input(control_image):
    """Stop the event chain with a UI error when no input image is present.

    Any value other than ``None`` is accepted and the function returns
    ``None``.

    Raises:
        gr.Error: if ``control_image`` is ``None``.
    """
    if control_image is not None:
        return
    raise gr.Error("Please select or upload an input image")

# check if a video has already been generated for the selected image
def check_video_input(image_block: Image.Image):
    """Verify that the stage-1 video for the current image already exists.

    The video is looked up under ``tmp_data`` by the SHA-256 hash of the
    image's raw bytes, which is the same path ``optimize_stage_1`` returns.

    Args:
        image_block: The image currently held by the input component, or
            ``None`` when nothing has been selected or uploaded.

    Raises:
        gr.Error: if no image is present, or if no generated video exists
            for this image yet.
    """
    # Without this guard an empty image component would surface as an
    # AttributeError on `None.tobytes()` instead of a readable UI message.
    if image_block is None:
        raise gr.Error("Please select or upload an input image")

    img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
    video_path = os.path.join('tmp_data', f'{img_hash}_rgba_generated.mp4')
    if not os.path.exists(video_path):
        raise gr.Error("Please generate a video first")


@spaces.GPU()
def optimize_stage_1(image_block: Image.Image, preprocess_chk: bool, seed_slider: int):
    """Run stage 1: optional preprocessing, video generation and LGM inference.

    All intermediate files live in ``tmp_data`` and are keyed by the SHA-256
    hash of the input image's raw bytes.

    Args:
        image_block: Input image from the UI.
        preprocess_chk: When true, the image is saved as ``<hash>.png`` and
            ``scripts/process.py`` is run on it. When false, the image is
            saved directly as ``<hash>_rgba.png``.
        seed_slider: Seed forwarded to ``scripts/gen_vid.py`` via ``--seed``.

    Returns:
        Path of the generated video, ``tmp_data/<hash>_rgba_generated.mp4``.

    Raises:
        gr.Error: if the expected video file is missing after the pipeline
            has run.
    """
    # exist_ok avoids the check-then-create race between concurrent requests.
    os.makedirs('tmp_data', exist_ok=True)
    img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
    rgba_path = os.path.join('tmp_data', f'{img_hash}_rgba.png')

    if preprocess_chk:
        # save image to a designated path
        raw_path = os.path.join('tmp_data', f'{img_hash}.png')
        image_block.save(raw_path)

        # preprocess image
        # NOTE(review): scripts/process.py is expected to write `rgba_path`
        # (the later steps read it) -- confirm against that script.
        preprocess_cmd = ['python', 'scripts/process.py', raw_path]
        print(shlex.join(preprocess_cmd))
        subprocess.run(preprocess_cmd)
    else:
        image_block.save(rgba_path)

    # stage 1
    # Commands are passed as argument lists (no shell) so paths such as the
    # checkpoint location are never re-parsed by a shell. The MKL variables
    # are applied only to gen_vid.py, as before.
    gen_vid_env = {
        **os.environ,
        'MKL_THREADING_LAYER': 'GNU',
        'MKL_SERVICE_FORCE_INTEL': '1',
    }
    subprocess.run(
        ['python', 'scripts/gen_vid.py',
         '--path', rgba_path,
         '--seed', str(seed_slider),
         '--bg', 'white'],
        env=gen_vid_env,
    )
    subprocess.run(
        ['python', 'lgm/infer.py', 'big',
         '--resume', ckpt_path,
         '--test_path', rgba_path],
    )

    video_path = os.path.join('tmp_data', f'{img_hash}_rgba_generated.mp4')
    # The subprocess exit codes are not inspected; surface a failed run as a
    # readable UI error instead of returning a path that does not exist.
    if not os.path.exists(video_path):
        raise gr.Error("Video generation failed, please try again")
    return video_path

@spaces.GPU(duration=120)
def optimize_stage_2(image_block: Image.Image, seed_slider: int):
    """Run stage 2 (``main_4d.py``) and return the per-frame ``.ply`` paths.

    Args:
        image_block: Input image from the UI; its SHA-256 hash selects the
            ``tmp_data/<hash>_rgba.png`` file written during stage 1.
        seed_slider: Accepted to match the UI wiring; not used by this stage.

    Returns:
        A list of 28 paths ``logs/<hash>_rgba_frames/000.ply`` ...
        ``027.ply``.

    Raises:
        gr.Error: if any of the expected frame files is missing after the
            run.
    """
    num_frames = 28  # number of frame files handed to the 4D viewer
    img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
    rgba_path = os.path.join('tmp_data', f'{img_hash}_rgba.png')

    # stage 2
    # Passed as an argument list (no shell) so no value is re-parsed by a shell.
    subprocess.run(
        ['python', 'main_4d.py',
         '--config', os.path.join('configs', '4d_demo.yaml'),
         f'input={rgba_path}'],
    )

    image_dir = os.path.join('logs', f'{img_hash}_rgba_frames')
    ply_paths = [os.path.join(image_dir, f'{t:03d}.ply') for t in range(num_frames)]
    # The subprocess exit code is not inspected; report a failed run clearly
    # instead of handing missing files to the viewer component.
    if not all(os.path.exists(path) for path in ply_paths):
        raise gr.Error("4D generation failed, please try again")
    return ply_paths


if __name__ == "__main__":
    # Static page text.
    _TITLE = '''DreamGaussian4D: Generative 4D Gaussian Splatting'''

    _DESCRIPTION = '''
    <div>
    <a style="display:inline-block" href="https://jiawei-ren.github.io/projects/dreamgaussian4d/"><img src='https://img.shields.io/badge/public_website-8A2BE2'></a>
    <a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/2312.17142"><img src="https://img.shields.io/badge/2312.17142-f9f7f7?logo="></a>
    <a style="display:inline-block; margin-left: .5em" href='https://github.com/jiawei-ren/dreamgaussian4d'><img src='https://img.shields.io/github/stars/jiawei-ren/dreamgaussian4d?style=social'/></a>
    </div>
    We present DreamGaussian4D, an efficient 4D generation framework that builds on Gaussian Splatting. 
    '''
    _IMG_USER_GUIDE = "Please upload an image in the block above (or choose an example above), select a random seed, and click **Generate Video**. After having the video generated, please click **Generate 4D**."

    # load images in 'data' folder as examples (sorted by file name)
    example_folder = os.path.join(os.path.dirname(__file__), 'data')
    examples_full = [
        os.path.join(example_folder, fn)
        for fn in sorted(os.listdir(example_folder))
        if fn.endswith('.png')
    ]

    # Compose demo layout & data flow
    with gr.Blocks(title=_TITLE, theme=gr.themes.Soft(), js=js_func) as demo:
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown('# ' + _TITLE)
        gr.Markdown(_DESCRIPTION)

        # Image-to-3D
        with gr.Row(variant='panel'):
            # Left column: inputs and action buttons.
            with gr.Column(scale=4):
                image_block = gr.Image(type='pil', image_mode='RGBA', height=290, label='Input image')

                # elevation_slider = gr.Slider(-90, 90, value=0, step=1, label='Estimated elevation angle')
                seed_slider = gr.Slider(0, 100000, value=0, step=1, label='Random Seed')
                gr.Markdown(
                    "random seed for video generation.")

                preprocess_chk = gr.Checkbox(True,
                                             label='Preprocess image automatically (remove background and recenter object)')

                gr.Examples(
                    examples=examples_full,  # NOTE: elements must match inputs list!
                    inputs=[image_block],
                    outputs=[image_block],
                    cache_examples=False,
                    label='Examples (click one of the images below to start)',
                    examples_per_page=40
                )
                img_run_btn = gr.Button("Generate Video")
                fourd_run_btn = gr.Button("Generate 4D")
                img_guide_text = gr.Markdown(_IMG_USER_GUIDE, visible=True)

            # Right column: stage-1 video and stage-2 4D model outputs.
            with gr.Column(scale=5):
                obj3d = gr.Video(label="video",height=290)
                obj4d = Model4DGS(label="4D Model", height=500, fps=14)

            # Each button first runs an unqueued validation step; the heavy
            # stage only starts if that validation succeeds.
            img_run_btn.click(
                check_img_input, inputs=[image_block], queue=False,
            ).success(
                optimize_stage_1,
                inputs=[image_block, preprocess_chk, seed_slider],
                outputs=[obj3d],
            )
            fourd_run_btn.click(
                check_video_input, inputs=[image_block], queue=False,
            ).success(
                optimize_stage_2,
                inputs=[image_block, seed_slider],
                outputs=[obj4d],
            )

    # demo.queue().launch(share=True)
    demo.queue(max_size=10)  # enable the request queue, capped at max_size=10
    demo.launch()