LiDAR-Diffusion / app.py
Hancy's picture
Update app.py
062d491 verified
raw
history blame
2.6 kB
import gradio as gr
import spaces
import tempfile
import os
import torch
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
from app_config import CSS, HEADER, FOOTER, DEVICE
import sample_cond
# Load the sampling model once at import time; `on_submit` below reuses this
# module-level `model` for every request instead of reloading it per call.
# NOTE(review): load_model() lives in sample_cond (not visible here); device
# placement and checkpoint source are decided there — confirm before changing.
model = sample_cond.load_model()
def create_custom_colormap():
    """Build the 256-entry colormap used to render range images.

    The gradient runs green -> cyan -> blue -> magenta -> yellow, with the
    anchor colors pinned at 0, 0.38, 0.6, 0.7 and 1 on the [0, 1] scale.

    Returns:
        A matplotlib ``LinearSegmentedColormap`` named 'custom_colormap'
        with 256 discrete entries.
    """
    anchors = [
        (0, (0, 1, 0)),      # green
        (0.38, (0, 1, 1)),   # cyan
        (0.6, (0, 0, 1)),    # blue
        (0.7, (1, 0, 1)),    # magenta
        (1, (1, 1, 0)),      # yellow
    ]
    return LinearSegmentedColormap.from_list('custom_colormap', anchors, N=256)
def colorize_depth(depth, log_scale):
    """Map a single-channel range image to an RGB image for display.

    Args:
        depth: array-like range image (2-D in the app). With
            ``log_scale=True`` values are expected in [0, 255]; anything
            outside that interval is clipped so the uint8 quantization cannot
            wrap around or hit log2 of a negative number. With
            ``log_scale=False`` the values go straight to the colormap, so
            they must already follow matplotlib's convention: integer indices
            in [0, 255] or floats in [0, 1].
        log_scale: if True, compress the range logarithmically and quantize
            it to uint8 colormap indices before colorizing.

    Returns:
        float array of shape ``depth.shape + (3,)`` with RGB values in
        [0, 1]. Pixels whose *input* value is exactly 0 are painted black.
    """
    depth = np.asarray(depth)
    # Identify empty pixels on the input, before quantization can collapse
    # small non-zero values to 0 and blank them by mistake.
    mask = depth == 0
    if log_scale:
        # For x in [0, 1], log2(56 * x + 1) lies in [0, log2(57) ~= 5.833], so
        # dividing by 5.84 keeps the scaled result below 255 (max index 254).
        # NOTE(review): 56 presumably relates to the sensor's max range — confirm.
        normalized = np.clip(depth, 0, 255) / 255.
        depth = ((np.log2(normalized * 56. + 1) / 5.84) * 255.).astype(np.uint8)
    colormap = create_custom_colormap()
    # Integer input indexes the 256-entry lookup table directly; float input
    # is interpreted on [0, 1] by matplotlib. Drop the alpha channel.
    rgb = colormap(depth)[..., :3]
    rgb[mask] = 0.
    return rgb
@spaces.GPU
@torch.no_grad()
def generate_lidar(model, cond):
    """Sample a LiDAR range image and point cloud conditioned on ``cond``.

    Gradient tracking is disabled for the whole call via ``torch.no_grad``.
    NOTE(review): ``@spaces.GPU`` presumably requests a GPU on Hugging Face
    Spaces for the duration of the call — confirm against the spaces package.

    Args:
        model: the model object passed through to ``sample_cond.sample``.
        cond: camera conditioning tensor (as built by ``load_camera``).

    Returns:
        ``(img, pcd)`` — the two outputs of ``sample_cond.sample``, repacked
        as a tuple.
    """
    outputs = sample_cond.sample(model, cond)
    range_image, point_cloud = outputs
    return (range_image, point_cloud)
def load_camera(image, *, num_views=4, device=None):
    """Convert a horizontally tiled multi-view camera image into a model condition.

    The image is cut along its width into ``num_views`` equal-width chunks,
    each treated as a separate camera view.

    Args:
        image: H x W x C image (the numpy array produced by ``gr.Image``).
            Values are divided by 255, so presumably in the uint8 range.
        num_views: number of equal-width views the image is split into.
        device: torch device for the returned tensor; defaults to the
            module-level ``DEVICE``.

    Returns:
        float32 tensor of shape (1, num_views, C, H, W // num_views).

    Raises:
        ValueError: if ``image`` is None, is not 3-D, if ``num_views`` < 1,
            or if the width is not divisible by ``num_views``.
    """
    if image is None:
        raise ValueError("No input image was provided.")
    if num_views < 1:
        raise ValueError(f"num_views must be >= 1, got {num_views}")
    camera = np.asarray(image).astype(np.float32) / 255.
    if camera.ndim != 3:
        raise ValueError(f"Expected an H x W x C image, got shape {camera.shape}")
    width = camera.shape[1]
    if width % num_views:
        raise ValueError(
            f"Image width {width} cannot be split into {num_views} equal views")
    camera = camera.transpose(2, 0, 1)  # HWC -> CHW
    # Split along the width axis: num_views chunks of shape (C, H, W // num_views).
    views = np.split(camera, num_views, axis=2)
    camera_cond = torch.from_numpy(np.stack(views, axis=0)).unsqueeze(0)
    return camera_cond.to(DEVICE if device is None else device)
with gr.Blocks(css=CSS) as demo:
    gr.Markdown(HEADER)
    # Layout: input image on the left, range map + point cloud stacked on the right.
    with gr.Row():
        input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
        with gr.Column():
            output_image = gr.Image(label="Output Range Map", elem_id='img-display-output')
            output_pcd = gr.Model3D(label="Output Point Cloud", elem_id='pcd-display-output', interactive=False)
    submit = gr.Button("Generate")

    def on_submit(image):
        """Run the pipeline for one uploaded image.

        Returns ``[rgb_range_map, pcd]``: the log-scaled, colorized range
        image and the point-cloud output of ``generate_lidar`` as-is.
        Invalid input is reported to the user through ``gr.Error`` instead
        of surfacing as a server-side traceback.
        """
        if image is None:
            raise gr.Error("Please upload an input camera image first.")
        try:
            cond = load_camera(image)
        except ValueError as err:
            # e.g. an image whose width cannot be split into equal views.
            raise gr.Error(f"Invalid input image: {err}") from err
        img, pcd = generate_lidar(model, cond)
        rgb_img = colorize_depth(img, log_scale=True)
        return [rgb_img, pcd]

    submit.click(on_submit, inputs=[input_image], outputs=[output_image, output_pcd])

    # Skip hidden files (e.g. .gitkeep, .DS_Store) so only real images become examples.
    example_dir = 'cam_examples'
    example_files = [os.path.join(example_dir, filename)
                     for filename in sorted(os.listdir(example_dir))
                     if not filename.startswith('.')]
    examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[output_image, output_pcd],
                           fn=on_submit, cache_examples=False)
    gr.Markdown(FOOTER)
if __name__ == '__main__':
    # Enable Gradio's request queue on the app, then start serving it.
    queued_demo = demo.queue()
    queued_demo.launch()