Spaces:
Runtime error
Runtime error
File size: 9,505 Bytes
21c4e64 0f432df 21c4e64 e618667 ea60d75 4fb8c01 c06c03d 21c4e64 b601d28 c54a4cd 21c4e64 c54a4cd 21c4e64 20167fb 21c4e64 c54a4cd 21c4e64 e618667 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import gradio as gr
import os
from PIL import Image
import subprocess
from gradio_model4dgs import Model4DGS
import numpy
import hashlib
import shlex
import spaces
import sys

# Install the prebuilt diff-gaussian-rasterization wheel (CPython 3.10, linux x86_64) at startup.
# Invoke pip through the running interpreter so the wheel is installed into the same environment
# this app imports from; a bare `pip` found on PATH may belong to a different Python.
subprocess.run([sys.executable, "-m", "pip", "install",
                "wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl"])
from huggingface_hub import hf_hub_download

# Fetch the LGM checkpoint from the Hugging Face Hub; the returned local path is passed to
# lgm/infer.py (via --resume) in optimize_stage_1.
ckpt_path = hf_hub_download(repo_id="ashawkey/LGM", filename="model_fp16_fixrot.safetensors")
# Client-side JavaScript handed to gr.Blocks(js=...). It forces the light theme: if the page URL's
# `__theme` query parameter is not 'light', it sets it to 'light' and navigates to the updated URL
# (i.e. reloads the page once with the parameter applied).
js_func = """
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'light') {
url.searchParams.set('__theme', 'light');
window.location.href = url.href;
}
}
"""
def check_img_input(control_image):
    """Abort the "Generate Video" chain when no input image is present.

    Used as a cheap pre-check before the GPU stage; returns None when an image
    is available.

    Raises:
        gr.Error: if ``control_image`` is None (nothing uploaded or selected).
    """
    if control_image is not None:
        return
    raise gr.Error("Please select or upload an input image")
# Check that stage 1 has already produced a video for this image before running stage 2.
def check_video_input(image_block: Image.Image):
    """Validate that "Generate 4D" can run for the current input image.

    Stage-1 outputs are keyed by the SHA-256 of the image's raw pixel bytes, so
    the same hash is recomputed here to locate the generated video.

    Raises:
        gr.Error: if no image is selected, or if
            ``tmp_data/<hash>_rgba_generated.mp4`` does not exist yet.
    """
    if image_block is None:
        # Without this guard, ``None.tobytes()`` raises AttributeError instead of a user-facing gr.Error.
        raise gr.Error("Please select or upload an input image")
    img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
    if not os.path.exists(os.path.join('tmp_data', f'{img_hash}_rgba_generated.mp4')):
        raise gr.Error("Please generate a video first")
@spaces.GPU()
def optimize_stage_1(image_block: Image.Image, preprocess_chk: bool, seed_slider: int):
    """Stage 1: generate a video (and run LGM inference) for the input image.

    All files are written under ``tmp_data/`` and named by the SHA-256 of the
    image's raw pixel bytes, so stage 2 and ``check_video_input`` can find them.

    Args:
        image_block: Input image from the UI (the gr.Image component requests RGBA).
        preprocess_chk: If True, save the image as ``tmp_data/<hash>.png`` and run
            ``scripts/process.py`` on it (per the UI label: background removal and
            recentering). It presumably writes ``tmp_data/<hash>_rgba.png``, which the
            later steps read. If False, the image is saved directly to that path.
        seed_slider: Random seed forwarded to ``scripts/gen_vid.py``.

    Returns:
        Path ``tmp_data/<hash>_rgba_generated.mp4``.

    Raises:
        gr.Error: if the expected video file does not exist after the scripts ran.
    """
    os.makedirs('tmp_data', exist_ok=True)
    img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
    rgba_path = os.path.join('tmp_data', f'{img_hash}_rgba.png')

    if preprocess_chk:
        raw_path = os.path.join('tmp_data', f'{img_hash}.png')
        image_block.save(raw_path)
        preprocess_cmd = ['python', 'scripts/process.py', raw_path]
        print(' '.join(preprocess_cmd))
        subprocess.run(preprocess_cmd)
    else:
        image_block.save(rgba_path)

    # Argument lists (no shell) keep paths such as ckpt_path intact even if they contain spaces
    # or shell metacharacters. The MKL variables were previously `export`ed only for this one command.
    mkl_env = {**os.environ, 'MKL_THREADING_LAYER': 'GNU', 'MKL_SERVICE_FORCE_INTEL': '1'}
    # int(): the slider value may arrive as a float; '5.0' would not be a plain integer seed on the CLI.
    subprocess.run(['python', 'scripts/gen_vid.py', '--path', rgba_path,
                    '--seed', str(int(seed_slider)), '--bg', 'white'], env=mkl_env)
    subprocess.run(['python', 'lgm/infer.py', 'big', '--resume', ckpt_path, '--test_path', rgba_path])

    video_path = os.path.join('tmp_data', f'{img_hash}_rgba_generated.mp4')
    if not os.path.exists(video_path):
        # Surface a clear message instead of handing the UI a path to a missing file.
        raise gr.Error("Video generation failed, please try again")
    return video_path
@spaces.GPU(duration=120)
def optimize_stage_2(image_block: Image.Image, seed_slider: int):
    """Stage 2: run 4D optimisation on the RGBA image prepared by stage 1.

    Args:
        image_block: The same input image used for stage 1; it is hashed (SHA-256
            of the raw pixel bytes) to locate ``tmp_data/<hash>_rgba.png``.
        seed_slider: Currently unused; kept so the UI event wiring (which passes
            it) stays unchanged.

    Returns:
        Paths ``logs/<hash>_rgba_frames/000.ply`` ... ``027.ply``.

    Raises:
        gr.Error: if ``main_4d.py`` did not create the output frame directory.
    """
    img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
    rgba_path = os.path.join('tmp_data', f'{img_hash}_rgba.png')
    subprocess.run(['python', 'main_4d.py',
                    '--config', os.path.join('configs', '4d_demo.yaml'),
                    f'input={rgba_path}'])

    image_dir = os.path.join('logs', f'{img_hash}_rgba_frames')
    if not os.path.isdir(image_dir):
        raise gr.Error("4D generation failed, please try again")
    # NOTE(review): the frame count is hard-coded; presumably it must match the number of
    # frames produced with configs/4d_demo.yaml — confirm before changing either side.
    num_frames = 28
    return [os.path.join(image_dir, f'{t:03d}.ply') for t in range(num_frames)]
if __name__ == "__main__":
    _TITLE = '''DreamGaussian4D: Generative 4D Gaussian Splatting'''
    _DESCRIPTION = '''
<div>
<a style="display:inline-block" href="https://jiawei-ren.github.io/projects/dreamgaussian4d/"><img src='https://img.shields.io/badge/public_website-8A2BE2'></a>
<a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/2312.17142"><img src="https://img.shields.io/badge/2312.17142-f9f7f7?logo="></a>
<a style="display:inline-block; margin-left: .5em" href='https://github.com/jiawei-ren/dreamgaussian4d'><img src='https://img.shields.io/github/stars/jiawei-ren/dreamgaussian4d?style=social'/></a>
</div>
We present DreamGaussian4D, an efficient 4D generation framework that builds on Gaussian Splatting.
'''
    _IMG_USER_GUIDE = "Please upload an image in the block above (or choose an example above), select a random seed, and click **Generate Video**. After having the video generated, please click **Generate 4D**."

    # Every .png in the 'data' folder next to this file becomes a clickable example, sorted by name.
    example_folder = os.path.join(os.path.dirname(__file__), 'data')
    examples_full = [os.path.join(example_folder, fn)
                     for fn in sorted(os.listdir(example_folder)) if fn.endswith('.png')]

    # Compose demo layout & data flow
    with gr.Blocks(title=_TITLE, theme=gr.themes.Soft(), js=js_func) as demo:
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown('# ' + _TITLE)
        gr.Markdown(_DESCRIPTION)

        # Image -> video (stage 1) -> 4D Gaussians (stage 2)
        with gr.Row(variant='panel'):
            with gr.Column(scale=4):
                image_block = gr.Image(type='pil', image_mode='RGBA', height=290, label='Input image')
                seed_slider = gr.Slider(0, 100000, value=0, step=1, label='Random Seed')
                gr.Markdown(
                    "random seed for video generation.")
                preprocess_chk = gr.Checkbox(
                    True,
                    label='Preprocess image automatically (remove background and recenter object)')
                gr.Examples(
                    examples=examples_full,  # NOTE: elements must match inputs list!
                    inputs=[image_block],
                    outputs=[image_block],
                    cache_examples=False,
                    label='Examples (click one of the images below to start)',
                    examples_per_page=40,
                )
                img_run_btn = gr.Button("Generate Video")
                fourd_run_btn = gr.Button("Generate 4D")
                gr.Markdown(_IMG_USER_GUIDE, visible=True)
            with gr.Column(scale=5):
                obj3d = gr.Video(label="video", height=290)
                obj4d = Model4DGS(label="4D Model", height=500, fps=14)

        # Each button first runs a cheap validator outside the queue; the GPU stage is chained with
        # .success() so it only starts when the validator finished without raising.
        img_run_btn.click(check_img_input, inputs=[image_block], queue=False).success(
            optimize_stage_1,
            inputs=[image_block, preprocess_chk, seed_slider],
            outputs=[obj3d])
        fourd_run_btn.click(check_video_input, inputs=[image_block], queue=False).success(
            optimize_stage_2,
            inputs=[image_block, seed_slider],
            outputs=[obj4d])

    # Enable request queuing; max_size=10 caps how many events may wait in the queue at once.
    demo.queue(max_size=10)
    demo.launch()