# `spaces` is imported first so Hugging Face ZeroGPU can set up GPU management
# before any CUDA-dependent libraries load.
import spaces
import subprocess
import sys

# Ensure the package is installed from the Git repository
package_name = "vqasynth"  # Replace with the actual package name if different
git_repo_url = "git+https://github.com/remyxai/VQASynth.git"

try:
    __import__(package_name)
except ImportError:
    print(f"{package_name} not found. Installing from {git_repo_url}...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", git_repo_url])

import os
import uuid
import tempfile

import open3d as o3d
from PIL import Image

from vqasynth.depth import DepthEstimator
from vqasynth.localize import Localizer
from vqasynth.scene_fusion import SpatialSceneConstructor
from vqasynth.prompts import PromptGenerator

import gradio as gr

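# Preload spaCy's small English model; downloading it here once makes it
# available to any VQASynth components that may rely on it downstream.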
import spacy

try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # Download the model if it's not already available
    from spacy.cli import download

    download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")

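# Instantiate the pipeline stages once at startup so model weights are loaded
# a single time per process.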
depth = DepthEstimator(from_onnx=False)
localizer = Localizer()
spatial_scene_constructor = SpatialSceneConstructor()
prompt_generator = PromptGenerator()


def combine_segmented_pointclouds(
    pointcloud_ply_files: list, captions: list, prompts: list, cache_dir: str
):
    """
    Process a list of segmented point clouds to combine two based on captions and return the resulting 3D point cloud and the identified prompt.

    Args:
        pointcloud_ply_files (list): List of file paths to `.pcd` files representing segmented point clouds.
        captions (list): List of captions corresponding to the segmented point clouds.
        prompts (list): List of prompts containing questions and answers about the captions.
        cache_dir (str): Directory to save the final `.ply` and `.obj` files.

    Returns:
        tuple: The path to the generated `.obj` file and the identified prompt text.
    """
    selected_prompt = None
    selected_indices = None
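    # Scan ordered caption pairs (i, j) and pick the first prompt that mentions
    # both captions; those indices select which two segments to merge.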
    for i, caption1 in enumerate(captions):
        for j, caption2 in enumerate(captions):
            if i != j:
                for prompt in prompts:
                    if caption1 in prompt and caption2 in prompt:
                        selected_prompt = prompt
                        selected_indices = (i, j)
                        break
                if selected_prompt:
                    break
        if selected_prompt:
            break

    if not selected_prompt or not selected_indices:
        raise ValueError("No prompt found containing two captions.")

    idx1, idx2 = selected_indices
    pointcloud_files = [pointcloud_ply_files[idx1], pointcloud_ply_files[idx2]]

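    # Merge the two selected segments into a single point cloud, skipping any
    # that are empty.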
    combined_point_cloud = o3d.geometry.PointCloud()
    for pointcloud_file in pointcloud_files:
        pcd = o3d.io.read_point_cloud(pointcloud_file)
        if pcd.is_empty():
            continue
        combined_point_cloud += pcd

    if combined_point_cloud.is_empty():
        raise ValueError(
            "Combined point cloud is empty after loading the selected segments."
        )

    uuid_out = str(uuid.uuid4())
    ply_file = os.path.join(cache_dir, f"combined_output_{uuid_out}.ply")
    obj_file = os.path.join(cache_dir, f"combined_output_{uuid_out}.obj")

    o3d.io.write_point_cloud(ply_file, combined_point_cloud)

    # Round-trip the .ply through read_triangle_mesh to re-export it as a
    # vertex-only .obj, which Gradio's Model3D component can render.
    mesh = o3d.io.read_triangle_mesh(ply_file)
    o3d.io.write_triangle_mesh(obj_file, mesh)

    return obj_file, selected_prompt

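# A minimal usage sketch (the file paths, captions, and prompt below are
# hypothetical; in this app the inputs come from the pipeline further down):
#
#   obj_path, prompt = combine_segmented_pointclouds(
#       pointcloud_ply_files=["/tmp/cache/pcd_0.ply", "/tmp/cache/pcd_1.ply"],
#       captions=["a red mug", "a laptop"],
#       prompts=["How far is the red mug from the laptop? About 40 cm apart."],
#       cache_dir="/tmp/cache",
#   )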

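# @spaces.GPU allocates a GPU for the duration of this call on Hugging Face
# ZeroGPU Spaces.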
@spaces.GPU
def run_vqasynth_pipeline(image: Image.Image, cache_dir: str):
    # Full pipeline: monocular depth estimation, object localization and
    # captioning, 3D scene fusion, then spatial prompt generation.
    depth_map, focal_length = depth.run(image)
    masks, bounding_boxes, captions = localizer.run(image)
    pointcloud_data, canonicalized = spatial_scene_constructor.run(
        str(0), image, depth_map, focal_length, masks, cache_dir
    )
    prompts = prompt_generator.run(captions, pointcloud_data, canonicalized)
    obj_file, selected_prompt = combine_segmented_pointclouds(
        pointcloud_data, captions, prompts, cache_dir
    )
    return obj_file, selected_prompt


def process_image(image_path: str):
    # Use a persistent temporary directory to keep the .obj file accessible by Gradio
    temp_dir = tempfile.mkdtemp()
    image = Image.open(image_path).convert("RGB")
    obj_file, prompt = run_vqasynth_pipeline(image, temp_dir)
    return obj_file, prompt


def build_demo():
    with gr.Blocks() as demo:
        gr.Markdown(
            """
        # Synthesizing SpatialVQA Samples with VQASynth
        This Space runs the full [VQASynth](https://github.com/remyxai/VQASynth) scene reconstruction pipeline on a single image, with visualizations.
        ### [GitHub](https://github.com/remyxai/VQASynth) | [Collection](https://huggingface.co/collections/remyxai/spacevlms-66a3dbb924756d98e7aec678)
        """
        )

        gr.Markdown(
            """
        ## Instructions
        Upload an image and the tool will generate a 3D point cloud of the objects it finds, along with an example prompt and response describing a spatial relationship between them.
        """
        )

        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="filepath", label="Upload an Image")
                generate_button = gr.Button("Generate")

            with gr.Column():
                model_output = gr.Model3D(label="3D Point Cloud")  # Only used as output
                caption_output = gr.Text(label="Caption")

        generate_button.click(
            process_image, inputs=image_input, outputs=[model_output, caption_output]
        )

        gr.Examples(
            examples=[
                ["./examples/warehouse_rgb.jpg"],
                ["./examples/spooky_doggy.png"],
                ["./examples/bee_and_flower.jpg"],
                ["./examples/gears.png"],
                ["./examples/road-through-dense-forest.jpg"],
            ],
            inputs=image_input,
            label="Example Images",
            examples_per_page=5,
        )

        gr.Markdown(
            """
        ## Citation
        ```
        @article{chen2024spatialvlm,
          title = {SpatialVLM: Endowing Vision-Language Models with Spatial Reasoning Capabilities},
          author = {Chen, Boyuan and Xu, Zhuo and Kirmani, Sean and Ichter, Brian and Driess, Danny and Florence, Pete and Sadigh, Dorsa and Guibas, Leonidas and Xia, Fei},
          journal = {arXiv preprint arXiv:2401.12168},
          year = {2024},
          url = {https://arxiv.org/abs/2401.12168},
        }
        ```
        """
        )

    return demo


if __name__ == "__main__":
    demo = build_demo()
    demo.launch(share=True)