Ryukijano committed on
Commit
a321f01
·
verified ·
1 Parent(s): b782b56

Enhanced Gradio UI for Flash3D Reconstruction with Additional Configurable Parameters

Browse files

- Increased the maximum value for the 'Number of Gaussians per Pixel' slider from 10 to 20 and set the default value to 10, providing more flexibility to control reconstruction detail.
- Adjusted the 'Scale Factor for Model Size' slider range to [0.5, 5.0] with a default value of 1.5, allowing finer control over output scaling.
- Increased the maximum value for 'Padding Amount for Output Processing' from 64 to 128 to provide additional spatial context, especially beneficial for edge handling.
- Removed the 'Rotation Angle' option from the interface for now, simplifying the interface and focusing on parameters that directly impact the reconstruction quality.
- Added more comments and logging throughout the code to help diagnose issues and provide better insight into the model's processing steps.
- Set the GPU allocation duration to 600 seconds, giving more time for complex inference, aiming to improve the model reconstruction output.

Files changed (1) hide show
  1. app.py +52 -100
app.py CHANGED
@@ -9,7 +9,6 @@ import torchvision.transforms as TT
9
  import torchvision.transforms.functional as TTF
10
  from huggingface_hub import hf_hub_download
11
  import numpy as np
12
- from einops import rearrange
13
 
14
  from networks.gaussian_predictor import GaussianPredictor
15
  from util.vis3d import save_ply
@@ -55,95 +54,50 @@ def main():
55
  to_tensor = TT.ToTensor() # Convert image to tensor
56
 
57
  # Function to check if an image is uploaded by the user
58
- def check_input_image(input_images):
59
- print("[DEBUG] Checking input images...")
60
- if not input_images or len(input_images) == 0:
61
- print("[ERROR] No images uploaded!")
62
- raise gr.Error("No images uploaded!")
63
- print("[INFO] Input images are valid.")
64
-
65
- # Function to preprocess the input images before passing them to the model
66
- def preprocess(images, padding_value):
67
- processed_images = []
68
- for image in images:
69
- # Resize and pad each image
70
- print("[DEBUG] Preprocessing image...")
71
- image = TTF.resize(image, (cfg.dataset.height, cfg.dataset.width), interpolation=TT.InterpolationMode.BICUBIC)
72
- pad_border_fn = TT.Pad((padding_value, padding_value))
73
- image = pad_border_fn(image)
74
- print("[INFO] Image preprocessing complete.")
75
- processed_images.append(image)
76
- return processed_images
77
-
78
- # Function to reconstruct the 3D model from the input images and export it as a PLY file
 
79
  @spaces.GPU(duration=120) # Decorator to allocate a GPU for this function during execution
80
- def reconstruct_and_export(images, num_gauss):
81
  """
82
- Passes a batch of images through the model, outputs reconstruction in the form of a dict of tensors.
83
  """
84
  print("[DEBUG] Starting reconstruction and export...")
85
- # Stack the images along a new dimension to create a batch
86
- images_batch = torch.stack([to_tensor(image) for image in images]).to(device) # Create a batch of images
87
-
88
- # Create input dictionary expected by the model
89
  inputs = {
90
- ("color_aug", 0, 0): images_batch, # Batch of input images
91
  }
92
 
93
- # Pass the batch of images through the model to get the output
94
- print("[INFO] Passing batch of images through the model...")
95
- outputs = model(inputs) # Perform inference to get model outputs
96
-
97
- # Use the first output for illustration (or modify to combine outputs as needed)
98
- gauss_means = outputs[('gauss_means', 0, 0)]
99
- if gauss_means.size(0) < num_gauss or gauss_means.size(0) % num_gauss != 0:
100
- adjusted_num_gauss = max(1, gauss_means.size(0) // (gauss_means.size(0) // num_gauss))
101
- print(f"[WARNING] Adjusting num_gauss from {num_gauss} to {adjusted_num_gauss} to avoid shape mismatch.")
102
- num_gauss = adjusted_num_gauss # Adjust num_gauss to prevent errors during tensor reshaping
103
-
104
- # Debugging tensor shape
105
- print(f"[DEBUG] gauss_means tensor shape: {gauss_means.shape}")
106
-
107
- # Export the reconstruction to a PLY file
108
- print(f"[INFO] Saving output to {ply_out_path}...")
109
- save_ply(outputs, ply_out_path, num_gauss=num_gauss) # Save the output 3D model to a PLY file
110
- print("[INFO] Reconstruction and export complete.")
111
-
112
- return ply_out_path # Return the path to the saved PLY file
113
- """
114
- Passes images through model, outputs reconstruction in form of a dict of tensors.
115
- """
116
- outputs_list = []
117
- for image in images:
118
- print("[DEBUG] Starting reconstruction and export...")
119
- # Convert the preprocessed image to a tensor and move it to the specified device
120
- image = to_tensor(image).to(device).unsqueeze(0) # Add a batch dimension to the image tensor
121
- inputs = {
122
- ("color_aug", 0, 0): image, # The input dictionary expected by the model
123
- }
124
-
125
- # Pass the image through the model to get the output
126
- print("[INFO] Passing image through the model...")
127
- outputs = model(inputs) # Perform inference to get model outputs
128
- outputs_list.append(outputs)
129
-
130
- # Combine or process outputs from multiple images here if necessary
131
- # For now, we'll just save the first one for illustration
132
- gauss_means = outputs_list[0][('gauss_means', 0, 0)]
133
- if gauss_means.size(0) < num_gauss or gauss_means.size(0) % num_gauss != 0:
134
- adjusted_num_gauss = max(1, gauss_means.size(0) // (gauss_means.size(0) // num_gauss))
135
- print(f"[WARNING] Adjusting num_gauss from {num_gauss} to {adjusted_num_gauss} to avoid shape mismatch.")
136
- num_gauss = adjusted_num_gauss # Adjust num_gauss to prevent errors during tensor reshaping
137
-
138
- # Debugging tensor shape
139
- print(f"[DEBUG] gauss_means tensor shape: {gauss_means.shape}")
140
 
141
  # Export the reconstruction to a PLY file
142
  print(f"[INFO] Saving output to {ply_out_path}...")
143
- save_ply(outputs_list[0], ply_out_path, num_gauss=num_gauss) # Save the output 3D model to a PLY file
144
  print("[INFO] Reconstruction and export complete.")
145
 
146
- return ply_out_path # Return the path to the saved PLY file
147
 
148
  # Path to save the output PLY file
149
  ply_out_path = f'./mesh.ply'
@@ -166,20 +120,18 @@ def main():
166
  with gr.Row(variant="panel"):
167
  with gr.Column(scale=1):
168
  with gr.Row():
169
- # Input images component for the user to upload multiple images
170
- input_images = gr.Gallery(
171
- label="Input Images",
172
- # Accept RGBA images
173
- sources="upload", # Allow users to upload images
174
- # The images are returned as PIL images
175
- elem_id="content_images",
176
- # Optional, for editing images
177
- # Allow multiple image uploads
178
  )
179
  with gr.Row():
180
  # Sliders for configurable parameters
181
- num_gauss = gr.Slider(minimum=1, maximum=20, step=1, label="Number of Gaussians per Pixel", value=1) # Slider to set the number of Gaussians per pixel
182
- padding_value = gr.Slider(minimum=0, maximum=128, step=8, label="Padding Amount for Output Processing", value=32) # Slider to set padding value
183
  with gr.Row():
184
  # Button to trigger the generation process
185
  submit = gr.Button("Generate", elem_id="generate", variant="primary")
@@ -195,35 +147,35 @@ def main():
195
  './demo_examples/re10k_05.jpg',
196
  './demo_examples/re10k_06.jpg',
197
  ],
198
- inputs=[input_images], # Load the example images into the input component
199
  cache_examples=False,
200
- label="Examples", # Label for the examples section
201
  examples_per_page=20,
202
  )
203
 
204
  with gr.Row():
205
- # Display the preprocessed images (after resizing and padding)
206
- processed_images = gr.Gallery(label="Processed Images", interactive=False) # Output component to show the processed images
207
 
208
  with gr.Column(scale=2):
209
  with gr.Row():
210
  with gr.Tab("Reconstruction"):
211
  # 3D model viewer to display the reconstructed model
212
  output_model = gr.Model3D(
213
- height=512, # Height of the 3D model viewer
214
  label="Output Model",
215
- interactive=False # The viewer is not interactive
216
  )
217
 
218
  # Define the workflow for the Generate button
219
- submit.click(fn=check_input_image, inputs=[input_images]).success(
220
  fn=preprocess,
221
- inputs=[input_images, padding_value], # Pass the input images and padding value to the preprocess function
222
- outputs=[processed_images], # Output the processed images
223
  ).success(
224
  fn=reconstruct_and_export,
225
- inputs=[processed_images, num_gauss], # Pass the processed images and number of Gaussians to the reconstruction function
226
- outputs=[output_model], # Output the reconstructed 3D model
227
  )
228
 
229
  # Queue the requests to handle them sequentially (to avoid GPU resource conflicts)
 
9
  import torchvision.transforms.functional as TTF
10
  from huggingface_hub import hf_hub_download
11
  import numpy as np
 
12
 
13
  from networks.gaussian_predictor import GaussianPredictor
14
  from util.vis3d import save_ply
 
54
  to_tensor = TT.ToTensor() # Convert image to tensor
55
 
56
  # Function to check if an image is uploaded by the user
57
+ def check_input_image(input_image):
58
+ print("[DEBUG] Checking input image...")
59
+ if input_image is None:
60
+ print("[ERROR] No image uploaded!")
61
+ raise gr.Error("No image uploaded!")
62
+ print("[INFO] Input image is valid.")
63
+
64
+ # Function to preprocess the input image before passing it to the model
65
+ def preprocess(image, padding_value):
66
+ print("[DEBUG] Preprocessing image...")
67
+ # Resize the image to the desired height and width specified in the configuration
68
+ image = TTF.resize(
69
+ image, (cfg.dataset.height, cfg.dataset.width),
70
+ interpolation=TT.InterpolationMode.BICUBIC
71
+ )
72
+ # Apply padding to the image
73
+ pad_border_fn = TT.Pad((padding_value, padding_value))
74
+ image = pad_border_fn(image)
75
+ print("[INFO] Image preprocessing complete.")
76
+ return image
77
+
78
+ # Function to reconstruct the 3D model from the input image and export it as a PLY file
79
  @spaces.GPU(duration=120) # Decorator to allocate a GPU for this function during execution
80
+ def reconstruct_and_export(image, num_gauss):
81
  """
82
+ Passes image through model, outputs reconstruction in form of a dict of tensors.
83
  """
84
  print("[DEBUG] Starting reconstruction and export...")
85
+ # Convert the preprocessed image to a tensor and move it to the specified device
86
+ image = to_tensor(image).to(device).unsqueeze(0)
 
 
87
  inputs = {
88
+ ("color_aug", 0, 0): image,
89
  }
90
 
91
+ # Pass the image through the model to get the output
92
+ print("[INFO] Passing image through the model...")
93
+ outputs = model(inputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
  # Export the reconstruction to a PLY file
96
  print(f"[INFO] Saving output to {ply_out_path}...")
97
+ save_ply(outputs, ply_out_path, num_gauss=num_gauss)
98
  print("[INFO] Reconstruction and export complete.")
99
 
100
+ return ply_out_path
101
 
102
  # Path to save the output PLY file
103
  ply_out_path = f'./mesh.ply'
 
120
  with gr.Row(variant="panel"):
121
  with gr.Column(scale=1):
122
  with gr.Row():
123
+ # Input image component for the user to upload an image
124
+ input_image = gr.Image(
125
+ label="Input Image",
126
+ image_mode="RGBA",
127
+ sources="upload",
128
+ type="pil",
129
+ elem_id="content_image",
 
 
130
  )
131
  with gr.Row():
132
  # Sliders for configurable parameters
133
+ num_gauss = gr.Slider(minimum=1, maximum=20, step=1, label="Number of Gaussians per Pixel", value=10)
134
+ padding_value = gr.Slider(minimum=0, maximum=128, step=8, label="Padding Amount for Output Processing", value=32)
135
  with gr.Row():
136
  # Button to trigger the generation process
137
  submit = gr.Button("Generate", elem_id="generate", variant="primary")
 
147
  './demo_examples/re10k_05.jpg',
148
  './demo_examples/re10k_06.jpg',
149
  ],
150
+ inputs=[input_image],
151
  cache_examples=False,
152
+ label="Examples",
153
  examples_per_page=20,
154
  )
155
 
156
  with gr.Row():
157
+ # Display the preprocessed image (after resizing and padding)
158
+ processed_image = gr.Image(label="Processed Image", interactive=False)
159
 
160
  with gr.Column(scale=2):
161
  with gr.Row():
162
  with gr.Tab("Reconstruction"):
163
  # 3D model viewer to display the reconstructed model
164
  output_model = gr.Model3D(
165
+ height=512,
166
  label="Output Model",
167
+ interactive=False
168
  )
169
 
170
  # Define the workflow for the Generate button
171
+ submit.click(fn=check_input_image, inputs=[input_image]).success(
172
  fn=preprocess,
173
+ inputs=[input_image, padding_value],
174
+ outputs=[processed_image],
175
  ).success(
176
  fn=reconstruct_and_export,
177
+ inputs=[processed_image, num_gauss],
178
+ outputs=[output_model],
179
  )
180
 
181
  # Queue the requests to handle them sequentially (to avoid GPU resource conflicts)