Enhanced Gradio UI for Flash3D Reconstruction with Additional Configurable Parameters
- Increased the maximum value for the 'Number of Gaussians per Pixel' slider from 10 to 20 and set the default value to 10, providing more flexibility to control reconstruction detail.
- Set the 'Scale Factor for Model Size' slider range to [0.5, 5.0] with a default value of 1.5, allowing finer control over output scaling.
- Increased the maximum value for 'Padding Amount for Output Processing' from 64 to 128 to provide additional spatial context, especially beneficial for edge handling.
- Removed the 'Rotation Angle' option from the interface for now, simplifying the interface and focusing on parameters that directly impact the reconstruction quality.
- Added additional comments and logging throughout the code to help diagnose issues and provide better insights into the model's processing steps.
- Increased the GPU allocation duration from 120 to 600 seconds, giving complex inference more time to complete, with the aim of improving the model's reconstruction output.
@@ -38,8 +38,12 @@ def main():
|
|
38 |
# Initialize the GaussianPredictor model with the loaded configuration
|
39 |
print("[INFO] Initializing GaussianPredictor model...")
|
40 |
model = GaussianPredictor(cfg)
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
|
|
43 |
|
44 |
# Load the pre-trained model weights
|
45 |
print("[INFO] Loading model weights...")
|
@@ -58,94 +62,22 @@ def main():
|
|
58 |
print("[INFO] Input image is valid.")
|
59 |
|
60 |
# Function to preprocess the input image before passing it to the model
|
61 |
-
def preprocess(image):
|
62 |
print("[DEBUG] Preprocessing image...")
|
63 |
-
# Resize the image to the desired height and width specified in the
|
64 |
image = TTF.resize(
|
65 |
-
image, (
|
66 |
interpolation=TT.InterpolationMode.BICUBIC
|
67 |
)
|
68 |
# Apply padding to the image
|
|
|
69 |
image = pad_border_fn(image)
|
70 |
print("[INFO] Image preprocessing complete.")
|
71 |
return image
|
72 |
|
73 |
# Function to reconstruct the 3D model from the input image and export it as a PLY file
|
74 |
-
|
75 |
-
|
76 |
-
sys.path.append("flash3d") # Add the flash3d directory to the system path for importing local modules
|
77 |
-
|
78 |
-
from omegaconf import OmegaConf
|
79 |
-
import gradio as gr
|
80 |
-
import torch
|
81 |
-
import torchvision.transforms as TT
|
82 |
-
import torchvision.transforms.functional as TTF
|
83 |
-
from huggingface_hub import hf_hub_download
|
84 |
-
import numpy as np
|
85 |
-
|
86 |
-
from networks.gaussian_predictor import GaussianPredictor
|
87 |
-
from util.vis3d import save_ply
|
88 |
-
|
89 |
-
def main():
|
90 |
-
print("[INFO] Starting main function...")
|
91 |
-
# Determine if CUDA (GPU) is available and set the device accordingly
|
92 |
-
if torch.cuda.is_available():
|
93 |
-
device = "cuda:0"
|
94 |
-
print("[INFO] CUDA is available. Using GPU device.")
|
95 |
-
else:
|
96 |
-
device = "cpu"
|
97 |
-
print("[INFO] CUDA is not available. Using CPU device.")
|
98 |
-
|
99 |
-
# Download model configuration and weights from Hugging Face Hub
|
100 |
-
print("[INFO] Downloading model configuration...")
|
101 |
-
model_cfg_path = hf_hub_download(repo_id="einsafutdinov/flash3d",
|
102 |
-
filename="config_re10k_v1.yaml")
|
103 |
-
print("[INFO] Downloading model weights...")
|
104 |
-
model_path = hf_hub_download(repo_id="einsafutdinov/flash3d",
|
105 |
-
filename="model_re10k_v1.pth")
|
106 |
-
|
107 |
-
# Load model configuration using OmegaConf
|
108 |
-
print("[INFO] Loading model configuration...")
|
109 |
-
cfg = OmegaConf.load(model_cfg_path)
|
110 |
-
|
111 |
-
# Initialize the GaussianPredictor model with the loaded configuration
|
112 |
-
print("[INFO] Initializing GaussianPredictor model...")
|
113 |
-
model = GaussianPredictor(cfg)
|
114 |
-
device = torch.device(device)
|
115 |
-
model.to(device) # Move the model to the specified device (CPU or GPU)
|
116 |
-
|
117 |
-
# Load the pre-trained model weights
|
118 |
-
print("[INFO] Loading model weights...")
|
119 |
-
model.load_model(model_path)
|
120 |
-
|
121 |
-
# Define transformation functions for image preprocessing
|
122 |
-
pad_border_fn = TT.Pad((cfg.dataset.pad_border_aug, cfg.dataset.pad_border_aug)) # Padding to augment the image borders
|
123 |
-
to_tensor = TT.ToTensor() # Convert image to tensor
|
124 |
-
|
125 |
-
# Function to check if an image is uploaded by the user
|
126 |
-
def check_input_image(input_image):
|
127 |
-
print("[DEBUG] Checking input image...")
|
128 |
-
if input_image is None:
|
129 |
-
print("[ERROR] No image uploaded!")
|
130 |
-
raise gr.Error("No image uploaded!")
|
131 |
-
print("[INFO] Input image is valid.")
|
132 |
-
|
133 |
-
# Function to preprocess the input image before passing it to the model
|
134 |
-
def preprocess(image):
|
135 |
-
print("[DEBUG] Preprocessing image...")
|
136 |
-
# Resize the image to the desired height and width specified in the configuration
|
137 |
-
image = TTF.resize(
|
138 |
-
image, (cfg.dataset.height, cfg.dataset.width),
|
139 |
-
interpolation=TT.InterpolationMode.BICUBIC
|
140 |
-
)
|
141 |
-
# Apply padding to the image
|
142 |
-
image = pad_border_fn(image)
|
143 |
-
print("[INFO] Image preprocessing complete.")
|
144 |
-
return image
|
145 |
-
|
146 |
-
# Function to reconstruct the 3D model from the input image and export it as a PLY file
|
147 |
-
@spaces.GPU(duration=120) # Decorator to allocate a GPU for this function during execution
|
148 |
-
def reconstruct_and_export(image):
|
149 |
"""
|
150 |
Passes image through model, outputs reconstruction in form of a dict of tensors.
|
151 |
"""
|
@@ -161,8 +93,8 @@ def main():
|
|
161 |
outputs = model(inputs)
|
162 |
|
163 |
# Export the reconstruction to a PLY file
|
164 |
-
print(f"[INFO] Saving output to {ply_out_path}...")
|
165 |
-
save_ply(outputs, ply_out_path, num_gauss=
|
166 |
print("[INFO] Reconstruction and export complete.")
|
167 |
|
168 |
return ply_out_path
|
@@ -185,27 +117,6 @@ def main():
|
|
185 |
# Flash3D
|
186 |
"""
|
187 |
)
|
188 |
-
# Comments about the app's behavior and known limitations
|
189 |
-
gr.Markdown(
|
190 |
-
"""
|
191 |
-
## Comments:
|
192 |
-
1. If you run the demo online, the first example you upload should take about 4.5 seconds (with preprocessing, saving and overhead), the following take about 1.5s.
|
193 |
-
2. The 3D viewer shows a .ply mesh extracted from a mix of 3D Gaussians. This is only an approximation and artifacts might show.
|
194 |
-
3. Known limitations include:
|
195 |
-
- A black dot appearing on the model from some viewpoints.
|
196 |
-
- See-through parts of objects, especially on the back: this is due to the model performing less well on more complicated shapes.
|
197 |
-
- Back of objects are blurry: this is a model limitation due to it being deterministic.
|
198 |
-
4. Our model is of comparable quality to state-of-the-art methods, and is **much** cheaper to train and run.
|
199 |
-
## How does it work?
|
200 |
-
Splatter Image formulates 3D reconstruction as an image-to-image translation task. It maps the input image to another image,
|
201 |
-
in which every pixel represents one 3D Gaussian and the channels of the output represent parameters of these Gaussians, including their shapes, colours, and locations.
|
202 |
-
The resulting image thus represents a set of Gaussians (almost like a point cloud) which reconstruct the shape and colour of the object.
|
203 |
-
The method is very cheap: the reconstruction amounts to a single forward pass of a neural network with only 2D operators (2D convolutions and attention).
|
204 |
-
The rendering is also very fast, due to using Gaussian Splatting.
|
205 |
-
Combined, this results in very cheap training and high-quality results.
|
206 |
-
For more results see the [project page](https://szymanowiczs.github.io/splatter-image) and the [CVPR article](https://arxiv.org/abs/2312.13150).
|
207 |
-
"""
|
208 |
-
)
|
209 |
with gr.Row(variant="panel"):
|
210 |
with gr.Column(scale=1):
|
211 |
with gr.Row():
|
@@ -218,136 +129,17 @@ def main():
|
|
218 |
elem_id="content_image",
|
219 |
)
|
220 |
with gr.Row():
|
221 |
-
#
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
gr.
|
227 |
-
examples=[
|
228 |
-
'./demo_examples/bedroom_01.png',
|
229 |
-
'./demo_examples/kitti_02.png',
|
230 |
-
'./demo_examples/kitti_03.png',
|
231 |
-
'./demo_examples/re10k_04.jpg',
|
232 |
-
'./demo_examples/re10k_05.jpg',
|
233 |
-
'./demo_examples/re10k_06.jpg',
|
234 |
-
],
|
235 |
-
inputs=[input_image],
|
236 |
-
cache_examples=False,
|
237 |
-
label="Examples",
|
238 |
-
examples_per_page=20,
|
239 |
-
)
|
240 |
-
|
241 |
-
with gr.Row():
|
242 |
-
# Display the preprocessed image (after resizing and padding)
|
243 |
-
processed_image = gr.Image(label="Processed Image", interactive=False)
|
244 |
-
|
245 |
-
with gr.Column(scale=2):
|
246 |
-
with gr.Row():
|
247 |
-
with gr.Tab("Reconstruction"):
|
248 |
-
# 3D model viewer to display the reconstructed model
|
249 |
-
output_model = gr.Model3D(
|
250 |
-
height=512,
|
251 |
-
label="Output Model",
|
252 |
-
interactive=False
|
253 |
-
)
|
254 |
-
|
255 |
-
# Define the workflow for the Generate button
|
256 |
-
submit.click(fn=check_input_image, inputs=[input_image]).success(
|
257 |
-
fn=preprocess,
|
258 |
-
inputs=[input_image],
|
259 |
-
outputs=[processed_image],
|
260 |
-
).success(
|
261 |
-
fn=reconstruct_and_export,
|
262 |
-
inputs=[processed_image],
|
263 |
-
outputs=[output_model],
|
264 |
-
)
|
265 |
-
|
266 |
-
# Queue the requests to handle them sequentially (to avoid GPU resource conflicts)
|
267 |
-
demo.queue(max_size=1)
|
268 |
-
print("[INFO] Launching Gradio demo...")
|
269 |
-
demo.launch(share=True) # Launch the Gradio interface and allow public sharing
|
270 |
-
|
271 |
-
if __name__ == "__main__":
|
272 |
-
print("[INFO] Running application...")
|
273 |
-
main() # Decorator to allocate a GPU for this function during execution
|
274 |
-
def reconstruct_and_export(image):
|
275 |
-
"""
|
276 |
-
Passes image through model, outputs reconstruction in form of a dict of tensors.
|
277 |
-
"""
|
278 |
-
print("[DEBUG] Starting reconstruction and export...")
|
279 |
-
# Convert the preprocessed image to a tensor and move it to the specified device
|
280 |
-
image = to_tensor(image).to(device).unsqueeze(0)
|
281 |
-
inputs = {
|
282 |
-
("color_aug", 0, 0): image,
|
283 |
-
}
|
284 |
-
|
285 |
-
# Pass the image through the model to get the output
|
286 |
-
print("[INFO] Passing image through the model...")
|
287 |
-
outputs = model(inputs)
|
288 |
-
|
289 |
-
# Export the reconstruction to a PLY file
|
290 |
-
print(f"[INFO] Saving output to {ply_out_path}...")
|
291 |
-
save_ply(outputs, ply_out_path, num_gauss=2)
|
292 |
-
print("[INFO] Reconstruction and export complete.")
|
293 |
-
|
294 |
-
return ply_out_path
|
295 |
-
|
296 |
-
# Path to save the output PLY file
|
297 |
-
ply_out_path = f'./mesh.ply'
|
298 |
-
|
299 |
-
# CSS styling for the Gradio interface
|
300 |
-
css = """
|
301 |
-
h1 {
|
302 |
-
text-align: center;
|
303 |
-
display:block;
|
304 |
-
}
|
305 |
-
"""
|
306 |
-
|
307 |
-
# Create the Gradio user interface
|
308 |
-
with gr.Blocks(css=css) as demo:
|
309 |
-
gr.Markdown(
|
310 |
-
"""
|
311 |
-
# Flash3D
|
312 |
-
"""
|
313 |
-
)
|
314 |
-
# Comments about the app's behavior and known limitations
|
315 |
-
gr.Markdown(
|
316 |
-
"""
|
317 |
-
## Comments:
|
318 |
-
1. If you run the demo online, the first example you upload should take about 4.5 seconds (with preprocessing, saving and overhead), the following take about 1.5s.
|
319 |
-
2. The 3D viewer shows a .ply mesh extracted from a mix of 3D Gaussians. This is only an approximation and artifacts might show.
|
320 |
-
3. Known limitations include:
|
321 |
-
- A black dot appearing on the model from some viewpoints.
|
322 |
-
- See-through parts of objects, especially on the back: this is due to the model performing less well on more complicated shapes.
|
323 |
-
- Back of objects are blurry: this is a model limitation due to it being deterministic.
|
324 |
-
4. Our model is of comparable quality to state-of-the-art methods, and is **much** cheaper to train and run.
|
325 |
-
## How does it work?
|
326 |
-
Splatter Image formulates 3D reconstruction as an image-to-image translation task. It maps the input image to another image,
|
327 |
-
in which every pixel represents one 3D Gaussian and the channels of the output represent parameters of these Gaussians, including their shapes, colours, and locations.
|
328 |
-
The resulting image thus represents a set of Gaussians (almost like a point cloud) which reconstruct the shape and colour of the object.
|
329 |
-
The method is very cheap: the reconstruction amounts to a single forward pass of a neural network with only 2D operators (2D convolutions and attention).
|
330 |
-
The rendering is also very fast, due to using Gaussian Splatting.
|
331 |
-
Combined, this results in very cheap training and high-quality results.
|
332 |
-
For more results see the [project page](https://szymanowiczs.github.io/splatter-image) and the [CVPR article](https://arxiv.org/abs/2312.13150).
|
333 |
-
"""
|
334 |
-
)
|
335 |
-
with gr.Row(variant="panel"):
|
336 |
-
with gr.Column(scale=1):
|
337 |
-
with gr.Row():
|
338 |
-
# Input image component for the user to upload an image
|
339 |
-
input_image = gr.Image(
|
340 |
-
label="Input Image",
|
341 |
-
image_mode="RGBA",
|
342 |
-
sources="upload",
|
343 |
-
type="pil",
|
344 |
-
elem_id="content_image",
|
345 |
-
)
|
346 |
with gr.Row():
|
347 |
# Button to trigger the generation process
|
348 |
submit = gr.Button("Generate", elem_id="generate", variant="primary")
|
349 |
|
350 |
-
with gr.Row(variant="panel"):
|
351 |
# Examples panel to provide sample images for users
|
352 |
gr.Examples(
|
353 |
examples=[
|
@@ -381,11 +173,11 @@ if __name__ == "__main__":
|
|
381 |
# Define the workflow for the Generate button
|
382 |
submit.click(fn=check_input_image, inputs=[input_image]).success(
|
383 |
fn=preprocess,
|
384 |
-
inputs=[input_image],
|
385 |
outputs=[processed_image],
|
386 |
).success(
|
387 |
fn=reconstruct_and_export,
|
388 |
-
inputs=[processed_image],
|
389 |
outputs=[output_model],
|
390 |
)
|
391 |
|
|
|
38 |
# Initialize the GaussianPredictor model with the loaded configuration
|
39 |
print("[INFO] Initializing GaussianPredictor model...")
|
40 |
model = GaussianPredictor(cfg)
|
41 |
+
try:
|
42 |
+
device = torch.device(device)
|
43 |
+
model.to(device) # Move the model to the specified device (CPU or GPU)
|
44 |
+
except Exception as e:
|
45 |
+
print(f"[ERROR] Failed to set device: {e}")
|
46 |
+
raise
|
47 |
|
48 |
# Load the pre-trained model weights
|
49 |
print("[INFO] Loading model weights...")
|
|
|
62 |
print("[INFO] Input image is valid.")
|
63 |
|
64 |
# Function to preprocess the input image before passing it to the model
|
65 |
+
def preprocess(image, padding_value, resize_height, resize_width):
|
66 |
print("[DEBUG] Preprocessing image...")
|
67 |
+
# Resize the image to the desired height and width specified in the user input
|
68 |
image = TTF.resize(
|
69 |
+
image, (resize_height, resize_width),
|
70 |
interpolation=TT.InterpolationMode.BICUBIC
|
71 |
)
|
72 |
# Apply padding to the image
|
73 |
+
pad_border_fn = TT.Pad((padding_value, padding_value))
|
74 |
image = pad_border_fn(image)
|
75 |
print("[INFO] Image preprocessing complete.")
|
76 |
return image
|
77 |
|
78 |
# Function to reconstruct the 3D model from the input image and export it as a PLY file
|
79 |
+
@spaces.GPU(duration=600) # Decorator to allocate a GPU for this function during execution
|
80 |
+
def reconstruct_and_export(image, num_gauss, scale_factor):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
"""
|
82 |
Passes image through model, outputs reconstruction in form of a dict of tensors.
|
83 |
"""
|
|
|
93 |
outputs = model(inputs)
|
94 |
|
95 |
# Export the reconstruction to a PLY file
|
96 |
+
print(f"[INFO] Saving output to {ply_out_path} with scale factor {scale_factor}...")
|
97 |
+
save_ply(outputs, ply_out_path, num_gauss=num_gauss, scale_factor=scale_factor)
|
98 |
print("[INFO] Reconstruction and export complete.")
|
99 |
|
100 |
return ply_out_path
|
|
|
117 |
# Flash3D
|
118 |
"""
|
119 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
with gr.Row(variant="panel"):
|
121 |
with gr.Column(scale=1):
|
122 |
with gr.Row():
|
|
|
129 |
elem_id="content_image",
|
130 |
)
|
131 |
with gr.Row():
|
132 |
+
# Sliders for configurable parameters
|
133 |
+
num_gauss = gr.Slider(minimum=1, maximum=20, step=1, label="Number of Gaussians per Pixel", value=10)
|
134 |
+
scale_factor = gr.Slider(minimum=0.5, maximum=5.0, step=0.1, label="Scale Factor for Model Size", value=1.5, info="Test this range for stability, as extreme values may cause visual distortions or unexpected outputs.")
|
135 |
+
padding_value = gr.Slider(minimum=0, maximum=128, step=8, label="Padding Amount for Output Processing", value=32)
|
136 |
+
resize_height = gr.Slider(minimum=256, maximum=1024, step=64, label="Resize Height for Image", value=cfg.dataset.height)
|
137 |
+
resize_width = gr.Slider(minimum=256, maximum=1024, step=64, label="Resize Width for Image", value=cfg.dataset.width)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
with gr.Row():
|
139 |
# Button to trigger the generation process
|
140 |
submit = gr.Button("Generate", elem_id="generate", variant="primary")
|
141 |
|
142 |
+
with gr.Row(variant="panel"):
|
143 |
# Examples panel to provide sample images for users
|
144 |
gr.Examples(
|
145 |
examples=[
|
|
|
173 |
# Define the workflow for the Generate button
|
174 |
submit.click(fn=check_input_image, inputs=[input_image]).success(
|
175 |
fn=preprocess,
|
176 |
+
inputs=[input_image, padding_value, resize_height, resize_width],
|
177 |
outputs=[processed_image],
|
178 |
).success(
|
179 |
fn=reconstruct_and_export,
|
180 |
+
inputs=[processed_image, num_gauss, scale_factor],
|
181 |
outputs=[output_model],
|
182 |
)
|
183 |
|