Ryukijano commited on
Commit
17ea36c
·
verified ·
1 Parent(s): b789e6e

Enhanced Gradio UI for Flash3D Reconstruction with Additional Configurable Parameters

Browse files

- Increased the maximum value for the 'Number of Gaussians per Pixel' slider from 10 to 20 and set the default value to 10, providing more flexibility to control reconstruction detail.
- Adjusted the 'Scale Factor for Model Size' slider range from [0.5, 5.0] with a default value of 1.5, allowing finer control over output scaling.
- Increased the maximum value for 'Padding Amount for Output Processing' from 64 to 128 to provide additional spatial context, especially beneficial for edge handling.
- Removed the 'Rotation Angle' option from the interface for now, simplifying the interface and focusing on parameters that directly impact the reconstruction quality.
- Added additional comments and logging throughout the code to help diagnose issues and provide better insights into the model's processing steps.
- Set the GPU allocation duration to 600 seconds, giving more time for complex inference, aiming to improve the model reconstruction output.

Files changed (1) hide show
  1. app.py +23 -231
app.py CHANGED
@@ -38,8 +38,12 @@ def main():
38
  # Initialize the GaussianPredictor model with the loaded configuration
39
  print("[INFO] Initializing GaussianPredictor model...")
40
  model = GaussianPredictor(cfg)
41
- device = torch.device(device)
42
- model.to(device) # Move the model to the specified device (CPU or GPU)
 
 
 
 
43
 
44
  # Load the pre-trained model weights
45
  print("[INFO] Loading model weights...")
@@ -58,94 +62,22 @@ def main():
58
  print("[INFO] Input image is valid.")
59
 
60
  # Function to preprocess the input image before passing it to the model
61
- def preprocess(image):
62
  print("[DEBUG] Preprocessing image...")
63
- # Resize the image to the desired height and width specified in the configuration
64
  image = TTF.resize(
65
- image, (cfg.dataset.height, cfg.dataset.width),
66
  interpolation=TT.InterpolationMode.BICUBIC
67
  )
68
  # Apply padding to the image
 
69
  image = pad_border_fn(image)
70
  print("[INFO] Image preprocessing complete.")
71
  return image
72
 
73
  # Function to reconstruct the 3D model from the input image and export it as a PLY file
74
- import sys
75
- import spaces
76
- sys.path.append("flash3d") # Add the flash3d directory to the system path for importing local modules
77
-
78
- from omegaconf import OmegaConf
79
- import gradio as gr
80
- import torch
81
- import torchvision.transforms as TT
82
- import torchvision.transforms.functional as TTF
83
- from huggingface_hub import hf_hub_download
84
- import numpy as np
85
-
86
- from networks.gaussian_predictor import GaussianPredictor
87
- from util.vis3d import save_ply
88
-
89
- def main():
90
- print("[INFO] Starting main function...")
91
- # Determine if CUDA (GPU) is available and set the device accordingly
92
- if torch.cuda.is_available():
93
- device = "cuda:0"
94
- print("[INFO] CUDA is available. Using GPU device.")
95
- else:
96
- device = "cpu"
97
- print("[INFO] CUDA is not available. Using CPU device.")
98
-
99
- # Download model configuration and weights from Hugging Face Hub
100
- print("[INFO] Downloading model configuration...")
101
- model_cfg_path = hf_hub_download(repo_id="einsafutdinov/flash3d",
102
- filename="config_re10k_v1.yaml")
103
- print("[INFO] Downloading model weights...")
104
- model_path = hf_hub_download(repo_id="einsafutdinov/flash3d",
105
- filename="model_re10k_v1.pth")
106
-
107
- # Load model configuration using OmegaConf
108
- print("[INFO] Loading model configuration...")
109
- cfg = OmegaConf.load(model_cfg_path)
110
-
111
- # Initialize the GaussianPredictor model with the loaded configuration
112
- print("[INFO] Initializing GaussianPredictor model...")
113
- model = GaussianPredictor(cfg)
114
- device = torch.device(device)
115
- model.to(device) # Move the model to the specified device (CPU or GPU)
116
-
117
- # Load the pre-trained model weights
118
- print("[INFO] Loading model weights...")
119
- model.load_model(model_path)
120
-
121
- # Define transformation functions for image preprocessing
122
- pad_border_fn = TT.Pad((cfg.dataset.pad_border_aug, cfg.dataset.pad_border_aug)) # Padding to augment the image borders
123
- to_tensor = TT.ToTensor() # Convert image to tensor
124
-
125
- # Function to check if an image is uploaded by the user
126
- def check_input_image(input_image):
127
- print("[DEBUG] Checking input image...")
128
- if input_image is None:
129
- print("[ERROR] No image uploaded!")
130
- raise gr.Error("No image uploaded!")
131
- print("[INFO] Input image is valid.")
132
-
133
- # Function to preprocess the input image before passing it to the model
134
- def preprocess(image):
135
- print("[DEBUG] Preprocessing image...")
136
- # Resize the image to the desired height and width specified in the configuration
137
- image = TTF.resize(
138
- image, (cfg.dataset.height, cfg.dataset.width),
139
- interpolation=TT.InterpolationMode.BICUBIC
140
- )
141
- # Apply padding to the image
142
- image = pad_border_fn(image)
143
- print("[INFO] Image preprocessing complete.")
144
- return image
145
-
146
- # Function to reconstruct the 3D model from the input image and export it as a PLY file
147
- @spaces.GPU(duration=120) # Decorator to allocate a GPU for this function during execution
148
- def reconstruct_and_export(image):
149
  """
150
  Passes image through model, outputs reconstruction in form of a dict of tensors.
151
  """
@@ -161,8 +93,8 @@ def main():
161
  outputs = model(inputs)
162
 
163
  # Export the reconstruction to a PLY file
164
- print(f"[INFO] Saving output to {ply_out_path}...")
165
- save_ply(outputs, ply_out_path, num_gauss=2)
166
  print("[INFO] Reconstruction and export complete.")
167
 
168
  return ply_out_path
@@ -185,27 +117,6 @@ def main():
185
  # Flash3D
186
  """
187
  )
188
- # Comments about the app's behavior and known limitations
189
- gr.Markdown(
190
- """
191
- ## Comments:
192
- 1. If you run the demo online, the first example you upload should take about 4.5 seconds (with preprocessing, saving and overhead), the following take about 1.5s.
193
- 2. The 3D viewer shows a .ply mesh extracted from a mix of 3D Gaussians. This is only an approximation and artifacts might show.
194
- 3. Known limitations include:
195
- - A black dot appearing on the model from some viewpoints.
196
- - See-through parts of objects, especially on the back: this is due to the model performing less well on more complicated shapes.
197
- - Back of objects are blurry: this is a model limitation due to it being deterministic.
198
- 4. Our model is of comparable quality to state-of-the-art methods, and is **much** cheaper to train and run.
199
- ## How does it work?
200
- Splatter Image formulates 3D reconstruction as an image-to-image translation task. It maps the input image to another image,
201
- in which every pixel represents one 3D Gaussian and the channels of the output represent parameters of these Gaussians, including their shapes, colours, and locations.
202
- The resulting image thus represents a set of Gaussians (almost like a point cloud) which reconstruct the shape and colour of the object.
203
- The method is very cheap: the reconstruction amounts to a single forward pass of a neural network with only 2D operators (2D convolutions and attention).
204
- The rendering is also very fast, due to using Gaussian Splatting.
205
- Combined, this results in very cheap training and high-quality results.
206
- For more results see the [project page](https://szymanowiczs.github.io/splatter-image) and the [CVPR article](https://arxiv.org/abs/2312.13150).
207
- """
208
- )
209
  with gr.Row(variant="panel"):
210
  with gr.Column(scale=1):
211
  with gr.Row():
@@ -218,136 +129,17 @@ def main():
218
  elem_id="content_image",
219
  )
220
  with gr.Row():
221
- # Button to trigger the generation process
222
- submit = gr.Button("Generate", elem_id="generate", variant="primary")
223
-
224
- with gr.Row(variant="panel"):
225
- # Examples panel to provide sample images for users
226
- gr.Examples(
227
- examples=[
228
- './demo_examples/bedroom_01.png',
229
- './demo_examples/kitti_02.png',
230
- './demo_examples/kitti_03.png',
231
- './demo_examples/re10k_04.jpg',
232
- './demo_examples/re10k_05.jpg',
233
- './demo_examples/re10k_06.jpg',
234
- ],
235
- inputs=[input_image],
236
- cache_examples=False,
237
- label="Examples",
238
- examples_per_page=20,
239
- )
240
-
241
- with gr.Row():
242
- # Display the preprocessed image (after resizing and padding)
243
- processed_image = gr.Image(label="Processed Image", interactive=False)
244
-
245
- with gr.Column(scale=2):
246
- with gr.Row():
247
- with gr.Tab("Reconstruction"):
248
- # 3D model viewer to display the reconstructed model
249
- output_model = gr.Model3D(
250
- height=512,
251
- label="Output Model",
252
- interactive=False
253
- )
254
-
255
- # Define the workflow for the Generate button
256
- submit.click(fn=check_input_image, inputs=[input_image]).success(
257
- fn=preprocess,
258
- inputs=[input_image],
259
- outputs=[processed_image],
260
- ).success(
261
- fn=reconstruct_and_export,
262
- inputs=[processed_image],
263
- outputs=[output_model],
264
- )
265
-
266
- # Queue the requests to handle them sequentially (to avoid GPU resource conflicts)
267
- demo.queue(max_size=1)
268
- print("[INFO] Launching Gradio demo...")
269
- demo.launch(share=True) # Launch the Gradio interface and allow public sharing
270
-
271
- if __name__ == "__main__":
272
- print("[INFO] Running application...")
273
- main() # Decorator to allocate a GPU for this function during execution
274
- def reconstruct_and_export(image):
275
- """
276
- Passes image through model, outputs reconstruction in form of a dict of tensors.
277
- """
278
- print("[DEBUG] Starting reconstruction and export...")
279
- # Convert the preprocessed image to a tensor and move it to the specified device
280
- image = to_tensor(image).to(device).unsqueeze(0)
281
- inputs = {
282
- ("color_aug", 0, 0): image,
283
- }
284
-
285
- # Pass the image through the model to get the output
286
- print("[INFO] Passing image through the model...")
287
- outputs = model(inputs)
288
-
289
- # Export the reconstruction to a PLY file
290
- print(f"[INFO] Saving output to {ply_out_path}...")
291
- save_ply(outputs, ply_out_path, num_gauss=2)
292
- print("[INFO] Reconstruction and export complete.")
293
-
294
- return ply_out_path
295
-
296
- # Path to save the output PLY file
297
- ply_out_path = f'./mesh.ply'
298
-
299
- # CSS styling for the Gradio interface
300
- css = """
301
- h1 {
302
- text-align: center;
303
- display:block;
304
- }
305
- """
306
-
307
- # Create the Gradio user interface
308
- with gr.Blocks(css=css) as demo:
309
- gr.Markdown(
310
- """
311
- # Flash3D
312
- """
313
- )
314
- # Comments about the app's behavior and known limitations
315
- gr.Markdown(
316
- """
317
- ## Comments:
318
- 1. If you run the demo online, the first example you upload should take about 4.5 seconds (with preprocessing, saving and overhead), the following take about 1.5s.
319
- 2. The 3D viewer shows a .ply mesh extracted from a mix of 3D Gaussians. This is only an approximation and artifacts might show.
320
- 3. Known limitations include:
321
- - A black dot appearing on the model from some viewpoints.
322
- - See-through parts of objects, especially on the back: this is due to the model performing less well on more complicated shapes.
323
- - Back of objects are blurry: this is a model limitation due to it being deterministic.
324
- 4. Our model is of comparable quality to state-of-the-art methods, and is **much** cheaper to train and run.
325
- ## How does it work?
326
- Splatter Image formulates 3D reconstruction as an image-to-image translation task. It maps the input image to another image,
327
- in which every pixel represents one 3D Gaussian and the channels of the output represent parameters of these Gaussians, including their shapes, colours, and locations.
328
- The resulting image thus represents a set of Gaussians (almost like a point cloud) which reconstruct the shape and colour of the object.
329
- The method is very cheap: the reconstruction amounts to a single forward pass of a neural network with only 2D operators (2D convolutions and attention).
330
- The rendering is also very fast, due to using Gaussian Splatting.
331
- Combined, this results in very cheap training and high-quality results.
332
- For more results see the [project page](https://szymanowiczs.github.io/splatter-image) and the [CVPR article](https://arxiv.org/abs/2312.13150).
333
- """
334
- )
335
- with gr.Row(variant="panel"):
336
- with gr.Column(scale=1):
337
- with gr.Row():
338
- # Input image component for the user to upload an image
339
- input_image = gr.Image(
340
- label="Input Image",
341
- image_mode="RGBA",
342
- sources="upload",
343
- type="pil",
344
- elem_id="content_image",
345
- )
346
  with gr.Row():
347
  # Button to trigger the generation process
348
  submit = gr.Button("Generate", elem_id="generate", variant="primary")
349
 
350
- with gr.Row(variant="panel"):
351
  # Examples panel to provide sample images for users
352
  gr.Examples(
353
  examples=[
@@ -381,11 +173,11 @@ if __name__ == "__main__":
381
  # Define the workflow for the Generate button
382
  submit.click(fn=check_input_image, inputs=[input_image]).success(
383
  fn=preprocess,
384
- inputs=[input_image],
385
  outputs=[processed_image],
386
  ).success(
387
  fn=reconstruct_and_export,
388
- inputs=[processed_image],
389
  outputs=[output_model],
390
  )
391
 
 
38
  # Initialize the GaussianPredictor model with the loaded configuration
39
  print("[INFO] Initializing GaussianPredictor model...")
40
  model = GaussianPredictor(cfg)
41
+ try:
42
+ device = torch.device(device)
43
+ model.to(device) # Move the model to the specified device (CPU or GPU)
44
+ except Exception as e:
45
+ print(f"[ERROR] Failed to set device: {e}")
46
+ raise
47
 
48
  # Load the pre-trained model weights
49
  print("[INFO] Loading model weights...")
 
62
  print("[INFO] Input image is valid.")
63
 
64
  # Function to preprocess the input image before passing it to the model
65
+ def preprocess(image, padding_value, resize_height, resize_width):
66
  print("[DEBUG] Preprocessing image...")
67
+ # Resize the image to the desired height and width specified in the user input
68
  image = TTF.resize(
69
+ image, (resize_height, resize_width),
70
  interpolation=TT.InterpolationMode.BICUBIC
71
  )
72
  # Apply padding to the image
73
+ pad_border_fn = TT.Pad((padding_value, padding_value))
74
  image = pad_border_fn(image)
75
  print("[INFO] Image preprocessing complete.")
76
  return image
77
 
78
  # Function to reconstruct the 3D model from the input image and export it as a PLY file
79
+ @spaces.GPU(duration=600) # Decorator to allocate a GPU for this function during execution
80
+ def reconstruct_and_export(image, num_gauss, scale_factor):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  """
82
  Passes image through model, outputs reconstruction in form of a dict of tensors.
83
  """
 
93
  outputs = model(inputs)
94
 
95
  # Export the reconstruction to a PLY file
96
+ print(f"[INFO] Saving output to {ply_out_path} with scale factor {scale_factor}...")
97
+ save_ply(outputs, ply_out_path, num_gauss=num_gauss, scale_factor=scale_factor)
98
  print("[INFO] Reconstruction and export complete.")
99
 
100
  return ply_out_path
 
117
  # Flash3D
118
  """
119
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  with gr.Row(variant="panel"):
121
  with gr.Column(scale=1):
122
  with gr.Row():
 
129
  elem_id="content_image",
130
  )
131
  with gr.Row():
132
+ # Sliders for configurable parameters
133
+ num_gauss = gr.Slider(minimum=1, maximum=20, step=1, label="Number of Gaussians per Pixel", value=10)
134
+ scale_factor = gr.Slider(minimum=0.5, maximum=5.0, step=0.1, label="Scale Factor for Model Size", value=1.5, info="Test this range for stability, as extreme values may cause visual distortions or unexpected outputs.")
135
+ padding_value = gr.Slider(minimum=0, maximum=128, step=8, label="Padding Amount for Output Processing", value=32)
136
+ resize_height = gr.Slider(minimum=256, maximum=1024, step=64, label="Resize Height for Image", value=cfg.dataset.height)
137
+ resize_width = gr.Slider(minimum=256, maximum=1024, step=64, label="Resize Width for Image", value=cfg.dataset.width)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  with gr.Row():
139
  # Button to trigger the generation process
140
  submit = gr.Button("Generate", elem_id="generate", variant="primary")
141
 
142
+ with gr.Row(variant="panel"):
143
  # Examples panel to provide sample images for users
144
  gr.Examples(
145
  examples=[
 
173
  # Define the workflow for the Generate button
174
  submit.click(fn=check_input_image, inputs=[input_image]).success(
175
  fn=preprocess,
176
+ inputs=[input_image, padding_value, resize_height, resize_width],
177
  outputs=[processed_image],
178
  ).success(
179
  fn=reconstruct_and_export,
180
+ inputs=[processed_image, num_gauss, scale_factor],
181
  outputs=[output_model],
182
  )
183