# gs-dynamics / app.py
# Hugging Face Space file view: last commit 9d79799 "Add application file" by kaifz (14.9 kB).
import argparse
import numpy as np
import gradio as gr
import torch
import torch.nn.functional as F
import open3d as o3d
import time
import cv2
import math
import os
import yaml
import shutil
import sys
import glob
from functools import partial
import copy
from PIL import Image
class DynamicsVisualizer:
    """Gradio demo that replays precomputed GS-Dynamics results.

    Nothing is computed at request time; every asset is read from disk:

    * ``data/<example_id>/`` holds ``img_0.png`` ... ``img_3.png`` (one frame
      per camera) and ``gs_orig.splat``.
    * ``data/<example_id>/action-<action_id>/`` holds the action frames
      ``img_<i>.png``, the per-camera ``video_<i>.mp4`` and ``gs_pred.splat``.

    NOTE(review): all selection state lives on this single instance, so every
    browser session served by the process shares (and can overwrite) the
    current object/action/view. Isolating users would need per-session
    ``gr.State``.
    """

    # Number of cameras with precomputed frames / videos per example and action.
    _NUM_CAMERAS = 4

    _IMAGE_LABEL = 'Initial state and actions'
    _VIDEO_LABEL = 'Predicted video'
    _ORIG_LABEL = 'Original Gaussian Splats'
    _PRED_LABEL = 'Predicted Gaussian Splats'

    def __init__(self):
        self.device = torch.device("cpu")  # not used elsewhere in this class
        self.width = 640  # display size passed to the image / video panels
        self.height = 480
        self._reset_state()

    # ------------------------------------------------------------------ state

    def _reset_state(self):
        """Return every selection / display field to its start-up value."""
        self.imgs = None  # per-camera PIL images for the image panel
        self.gs_orig = None  # path of the selected example's original .splat
        self.actions = None  # only ever assigned None in this class
        self.vis_cam_id = 1  # index of the camera currently displayed
        self.bg_id = 0  # 0: black, 1: white (no UI control changes it)
        self.example_name = None
        self.form_image_is_set = False
        self.form_3dgs_orig_is_set = False
        self._clear_action_state()

    def _clear_action_state(self):
        """Forget the selected action and any prediction derived from it."""
        self.action_name = None
        self.videos = None  # per-camera predicted video paths
        self.gs_pred = None  # path of the predicted .splat
        self.form_video_is_set = False
        self.form_3dgs_pred_is_set = False

    # ------------------------------------------------------------- components

    def _clear_color(self):
        """RGBA background for the splat viewers, derived from ``bg_id`` (alpha 0)."""
        return [self.bg_id, self.bg_id, self.bg_id, 0]

    def _image(self, value=None):
        """Build the image panel showing *value* (empty when None)."""
        return gr.Image(label=self._IMAGE_LABEL, value=value, width=self.width, height=self.height)

    def _video(self, value=None):
        """Build the predicted-video panel showing *value* (empty when None)."""
        return gr.Video(label=self._VIDEO_LABEL, value=value, width=self.width, height=self.height)

    def _model3d(self, label, value=None):
        """Build a splat viewer; the background colour is only set when a model is shown."""
        if value is None:
            return gr.Model3D(label=label, value=None)
        return gr.Model3D(label=label, value=value, clear_color=self._clear_color())

    # ---------------------------------------------------------------- loading

    def _open_images(self, directory):
        """Read ``img_<i>.png`` for every camera in *directory* fully into memory."""
        images = []
        for cam_id in range(self._NUM_CAMERAS):
            path = os.path.join(directory, f'img_{cam_id}.png')
            # Image.open is lazy and keeps the file open until the pixels are
            # read; copy() reads them now so the handle is closed on exit.
            with Image.open(path) as img:
                images.append(img.copy())
        return images

    def load_example(self):
        """Load the frames and original-splat path of ``self.example_name``."""
        example_path = os.path.join('data', self.example_name)
        self.imgs = self._open_images(example_path)
        self.gs_orig = os.path.join(example_path, 'gs_orig.splat')

    def load_action(self):
        """Load the frames, video paths and predicted-splat path of ``self.action_name``."""
        action_path = os.path.join('data', self.action_name)
        self.imgs = self._open_images(action_path)
        self.videos = [os.path.join(action_path, f'video_{i}.mp4') for i in range(self._NUM_CAMERAS)]
        self.gs_pred = os.path.join(action_path, 'gs_pred.splat')

    # --------------------------------------------------------------- handlers

    def reset(self):
        """Handler for **Clear All**: forget every selection and empty all panels."""
        self._reset_state()
        return (
            self._image(),
            self._video(),
            self._model3d(self._ORIG_LABEL),
            self._model3d(self._PRED_LABEL),
        )

    def on_click_set_example(self, state):
        """Handler for the object buttons; ``state['example_id']`` selects ``data/<id>``."""
        self.example_name = f"{int(state['example_id'])}"
        # A new object invalidates any action chosen for the previous one;
        # otherwise Run would replay the previous object's prediction.
        self._clear_action_state()
        self.load_example()
        self.form_image_is_set = True
        self.form_3dgs_orig_is_set = True
        return (
            self._image(self.imgs[self.vis_cam_id]),
            self._video(),
            self._model3d(self._ORIG_LABEL, self.gs_orig),
            self._model3d(self._PRED_LABEL),
        )

    def on_click_set_action(self, state):
        """Handler for the action buttons; shows the chosen action on the current view.

        Raises:
            gr.Error: if no object has been selected yet.
        """
        if self.example_name is None:
            raise gr.Error('Please select an object (Step 1) before choosing an action.')
        self.action_name = f"{self.example_name}/action-{int(state['action_id'])}"
        self.load_action()
        self.form_image_is_set = True
        return self._image(self.imgs[self.vis_cam_id])

    def on_click_run(self):
        """Handler for **Run**: show the precomputed video and predicted splats.

        Raises:
            gr.Error: if no object/action pair has been selected yet.
        """
        if self.videos is None or self.gs_pred is None:
            raise gr.Error('Please select an object (Step 1) and an action (Step 2) before clicking Run.')
        self.form_video_is_set = True
        self.form_3dgs_pred_is_set = True
        return self._video(self.videos[self.vis_cam_id]), self._model3d(self._PRED_LABEL, self.gs_pred)

    def on_click_change_view(self, state):
        """Handler for the view buttons; ``state['view_id']`` is the camera index."""
        self.vis_cam_id = int(state['view_id'])
        if self.imgs is None:
            # Nothing loaded yet (start-up or after Clear All): only remember the view.
            return self._image(), self._video()
        video = self.videos[self.vis_cam_id] if self.form_video_is_set else None
        return self._image(self.imgs[self.vis_cam_id]), self._video(video)

    # ----------------------------------------------------------------- layout

    def launch(self, share=False):
        """Build the Blocks layout, wire the callbacks and start the Gradio server."""
        with gr.Blocks() as app:
            with gr.Row():
                gr.Markdown("# Dynamic 3D Gaussian Tracking for Graph-Based Neural Dynamics Modeling")
            with gr.Row():
                gr.Markdown('Project page: [https://gs-dynamics.github.io/](https://gs-dynamics.github.io/)')
            for _ in range(2):  # vertical spacing
                with gr.Row():
                    gr.Markdown()

            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("**Step 0**: click **Clear All** to clear all window and reset the visualizer.")
                with gr.Column(scale=1):
                    run_reset = gr.Button('Clear All')

            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("**Step 1**: select the object.")
                example_buttons = []
                for label in ('Rope', 'Rope - Long', 'Toy Animal'):
                    with gr.Column(scale=1):
                        example_buttons.append(gr.Button(label))

            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("**Step 2**: select the action.")
                action_buttons = []
                for label in ('Action 1', 'Action 2', 'Action 3'):
                    with gr.Column(scale=1):
                        action_buttons.append(gr.Button(label))

            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("**Step 3**: click **Run** to visualize the predicted video and Splats.")
                with gr.Column(scale=1):
                    run_run = gr.Button('Run')

            with gr.Row():
                with gr.Column(scale=1, min_width=20):
                    for _ in range(4):  # spacing above the view controls
                        with gr.Row():
                            gr.Markdown()
                    with gr.Row():
                        gr.Markdown("Our model uses only 4 cameras for reconstructing the Gaussian Splats. Click the buttons below to change the view.")
                    view_buttons = []
                    for view_idx in range(self._NUM_CAMERAS):
                        with gr.Row():
                            view_buttons.append(gr.Button(f'View {view_idx}'))
                with gr.Column(scale=4):
                    with gr.Row():
                        with gr.Column(scale=2):
                            form_image = self._image()
                        with gr.Column(scale=2):
                            form_video = self._video()
                    with gr.Row():
                        with gr.Column(scale=2):
                            form_3dgs_orig = self._model3d(self._ORIG_LABEL)
                        with gr.Column(scale=2):
                            form_3dgs_pred = self._model3d(self._PRED_LABEL)

            with gr.Row():
                gr.Markdown("## Notes:")
            with gr.Row():
                gr.Markdown("- Due to the computation constraints of Hugging Face Space, all results are precomputed. ")
            with gr.Row():
                gr.Markdown("- Training a GS for an object takes around 30 seconds. Prediction typically takes only 1-2 seconds for each push!")
            with gr.Row():
                gr.Markdown("- More examples may be added in the future. Stay tuned!")

            # Set up callbacks
            all_forms = [form_image, form_video, form_3dgs_orig, form_3dgs_pred]
            run_reset.click(self.reset, inputs=[], outputs=all_forms)
            for example_id, button in enumerate(example_buttons):
                button.click(self.on_click_set_example,
                             inputs=[gr.State({'example_id': example_id})],
                             outputs=all_forms)
            for action_id, button in enumerate(action_buttons):
                button.click(self.on_click_set_action,
                             inputs=[gr.State({'action_id': action_id})],
                             outputs=[form_image])
            run_run.click(self.on_click_run, inputs=[], outputs=[form_video, form_3dgs_pred])
            for view_idx, button in enumerate(view_buttons):
                # Button "View k" shows camera (k + 1) % 4, which makes "View 0"
                # the start-up camera (vis_cam_id = 1).
                button.click(self.on_click_change_view,
                             inputs=[gr.State({'view_id': (view_idx + 1) % self._NUM_CAMERAS})],
                             outputs=[form_image, form_video])
        app.launch(share=share)
if __name__ == '__main__':
    # share=True (the previously hard-coded behaviour) stays the default;
    # --no-share serves the app without requesting a public Gradio link.
    parser = argparse.ArgumentParser(description='GS-Dynamics precomputed-results viewer')
    parser.add_argument('--no-share', dest='share', action='store_false',
                        help='do not request a public Gradio share link')
    # parse_known_args ignores unrecognised arguments instead of exiting,
    # so existing launch commands keep working.
    args, _ = parser.parse_known_args()
    visualizer = DynamicsVisualizer()
    visualizer.launch(share=args.share)