Enhanced Gradio UI for Flash3D Reconstruction with Additional Configurable Parameters
Browse files- Increased the maximum value for the 'Number of Gaussians per Pixel' slider from 10 to 20 and set the default value to 10, providing more flexibility to control reconstruction detail.
- Adjusted the 'Scale Factor for Model Size' slider range from [0.5, 5.0] with a default value of 1.5, allowing finer control over output scaling.
- Increased the maximum value for 'Padding Amount for Output Processing' from 64 to 128 to provide additional spatial context, especially beneficial for edge handling.
- Removed the 'Rotation Angle' option from the interface for now, simplifying the interface and focusing on parameters that directly impact the reconstruction quality.
- Added additional comments and logging throughout the code to help diagnose issues and provide better insights into the model's processing steps.
- Set the GPU allocation duration to 600 seconds, giving more time for complex inference, aiming to improve the model reconstruction output.
@@ -9,7 +9,6 @@ import torchvision.transforms as TT
|
|
9 |
import torchvision.transforms.functional as TTF
|
10 |
from huggingface_hub import hf_hub_download
|
11 |
import numpy as np
|
12 |
-
from einops import rearrange
|
13 |
|
14 |
from networks.gaussian_predictor import GaussianPredictor
|
15 |
from util.vis3d import save_ply
|
@@ -55,95 +54,50 @@ def main():
|
|
55 |
to_tensor = TT.ToTensor() # Convert image to tensor
|
56 |
|
57 |
# Function to check if an image is uploaded by the user
|
58 |
-
def check_input_image(
|
59 |
-
print("[DEBUG] Checking input
|
60 |
-
if
|
61 |
-
print("[ERROR] No
|
62 |
-
raise gr.Error("No
|
63 |
-
print("[INFO] Input
|
64 |
-
|
65 |
-
# Function to preprocess the input
|
66 |
-
def preprocess(
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
|
|
79 |
@spaces.GPU(duration=120) # Decorator to allocate a GPU for this function during execution
|
80 |
-
def reconstruct_and_export(
|
81 |
"""
|
82 |
-
Passes
|
83 |
"""
|
84 |
print("[DEBUG] Starting reconstruction and export...")
|
85 |
-
#
|
86 |
-
|
87 |
-
|
88 |
-
# Create input dictionary expected by the model
|
89 |
inputs = {
|
90 |
-
("color_aug", 0, 0):
|
91 |
}
|
92 |
|
93 |
-
# Pass the
|
94 |
-
print("[INFO] Passing
|
95 |
-
outputs = model(inputs)
|
96 |
-
|
97 |
-
# Use the first output for illustration (or modify to combine outputs as needed)
|
98 |
-
gauss_means = outputs[('gauss_means', 0, 0)]
|
99 |
-
if gauss_means.size(0) < num_gauss or gauss_means.size(0) % num_gauss != 0:
|
100 |
-
adjusted_num_gauss = max(1, gauss_means.size(0) // (gauss_means.size(0) // num_gauss))
|
101 |
-
print(f"[WARNING] Adjusting num_gauss from {num_gauss} to {adjusted_num_gauss} to avoid shape mismatch.")
|
102 |
-
num_gauss = adjusted_num_gauss # Adjust num_gauss to prevent errors during tensor reshaping
|
103 |
-
|
104 |
-
# Debugging tensor shape
|
105 |
-
print(f"[DEBUG] gauss_means tensor shape: {gauss_means.shape}")
|
106 |
-
|
107 |
-
# Export the reconstruction to a PLY file
|
108 |
-
print(f"[INFO] Saving output to {ply_out_path}...")
|
109 |
-
save_ply(outputs, ply_out_path, num_gauss=num_gauss) # Save the output 3D model to a PLY file
|
110 |
-
print("[INFO] Reconstruction and export complete.")
|
111 |
-
|
112 |
-
return ply_out_path # Return the path to the saved PLY file
|
113 |
-
"""
|
114 |
-
Passes images through model, outputs reconstruction in form of a dict of tensors.
|
115 |
-
"""
|
116 |
-
outputs_list = []
|
117 |
-
for image in images:
|
118 |
-
print("[DEBUG] Starting reconstruction and export...")
|
119 |
-
# Convert the preprocessed image to a tensor and move it to the specified device
|
120 |
-
image = to_tensor(image).to(device).unsqueeze(0) # Add a batch dimension to the image tensor
|
121 |
-
inputs = {
|
122 |
-
("color_aug", 0, 0): image, # The input dictionary expected by the model
|
123 |
-
}
|
124 |
-
|
125 |
-
# Pass the image through the model to get the output
|
126 |
-
print("[INFO] Passing image through the model...")
|
127 |
-
outputs = model(inputs) # Perform inference to get model outputs
|
128 |
-
outputs_list.append(outputs)
|
129 |
-
|
130 |
-
# Combine or process outputs from multiple images here if necessary
|
131 |
-
# For now, we'll just save the first one for illustration
|
132 |
-
gauss_means = outputs_list[0][('gauss_means', 0, 0)]
|
133 |
-
if gauss_means.size(0) < num_gauss or gauss_means.size(0) % num_gauss != 0:
|
134 |
-
adjusted_num_gauss = max(1, gauss_means.size(0) // (gauss_means.size(0) // num_gauss))
|
135 |
-
print(f"[WARNING] Adjusting num_gauss from {num_gauss} to {adjusted_num_gauss} to avoid shape mismatch.")
|
136 |
-
num_gauss = adjusted_num_gauss # Adjust num_gauss to prevent errors during tensor reshaping
|
137 |
-
|
138 |
-
# Debugging tensor shape
|
139 |
-
print(f"[DEBUG] gauss_means tensor shape: {gauss_means.shape}")
|
140 |
|
141 |
# Export the reconstruction to a PLY file
|
142 |
print(f"[INFO] Saving output to {ply_out_path}...")
|
143 |
-
save_ply(
|
144 |
print("[INFO] Reconstruction and export complete.")
|
145 |
|
146 |
-
return ply_out_path
|
147 |
|
148 |
# Path to save the output PLY file
|
149 |
ply_out_path = f'./mesh.ply'
|
@@ -166,20 +120,18 @@ def main():
|
|
166 |
with gr.Row(variant="panel"):
|
167 |
with gr.Column(scale=1):
|
168 |
with gr.Row():
|
169 |
-
# Input
|
170 |
-
|
171 |
-
label="Input
|
172 |
-
|
173 |
-
sources="upload",
|
174 |
-
|
175 |
-
elem_id="
|
176 |
-
# Optional, for editing images
|
177 |
-
# Allow multiple image uploads
|
178 |
)
|
179 |
with gr.Row():
|
180 |
# Sliders for configurable parameters
|
181 |
-
num_gauss = gr.Slider(minimum=1, maximum=20, step=1, label="Number of Gaussians per Pixel", value=
|
182 |
-
padding_value = gr.Slider(minimum=0, maximum=128, step=8, label="Padding Amount for Output Processing", value=32)
|
183 |
with gr.Row():
|
184 |
# Button to trigger the generation process
|
185 |
submit = gr.Button("Generate", elem_id="generate", variant="primary")
|
@@ -195,35 +147,35 @@ def main():
|
|
195 |
'./demo_examples/re10k_05.jpg',
|
196 |
'./demo_examples/re10k_06.jpg',
|
197 |
],
|
198 |
-
inputs=[
|
199 |
cache_examples=False,
|
200 |
-
label="Examples",
|
201 |
examples_per_page=20,
|
202 |
)
|
203 |
|
204 |
with gr.Row():
|
205 |
-
# Display the preprocessed
|
206 |
-
|
207 |
|
208 |
with gr.Column(scale=2):
|
209 |
with gr.Row():
|
210 |
with gr.Tab("Reconstruction"):
|
211 |
# 3D model viewer to display the reconstructed model
|
212 |
output_model = gr.Model3D(
|
213 |
-
height=512,
|
214 |
label="Output Model",
|
215 |
-
interactive=False
|
216 |
)
|
217 |
|
218 |
# Define the workflow for the Generate button
|
219 |
-
submit.click(fn=check_input_image, inputs=[
|
220 |
fn=preprocess,
|
221 |
-
inputs=[
|
222 |
-
outputs=[
|
223 |
).success(
|
224 |
fn=reconstruct_and_export,
|
225 |
-
inputs=[
|
226 |
-
outputs=[output_model],
|
227 |
)
|
228 |
|
229 |
# Queue the requests to handle them sequentially (to avoid GPU resource conflicts)
|
|
|
9 |
import torchvision.transforms.functional as TTF
|
10 |
from huggingface_hub import hf_hub_download
|
11 |
import numpy as np
|
|
|
12 |
|
13 |
from networks.gaussian_predictor import GaussianPredictor
|
14 |
from util.vis3d import save_ply
|
|
|
54 |
to_tensor = TT.ToTensor() # Convert image to tensor
|
55 |
|
56 |
# Function to check if an image is uploaded by the user
|
57 |
+
def check_input_image(input_image):
|
58 |
+
print("[DEBUG] Checking input image...")
|
59 |
+
if input_image is None:
|
60 |
+
print("[ERROR] No image uploaded!")
|
61 |
+
raise gr.Error("No image uploaded!")
|
62 |
+
print("[INFO] Input image is valid.")
|
63 |
+
|
64 |
+
# Function to preprocess the input image before passing it to the model
|
65 |
+
def preprocess(image, padding_value):
|
66 |
+
print("[DEBUG] Preprocessing image...")
|
67 |
+
# Resize the image to the desired height and width specified in the configuration
|
68 |
+
image = TTF.resize(
|
69 |
+
image, (cfg.dataset.height, cfg.dataset.width),
|
70 |
+
interpolation=TT.InterpolationMode.BICUBIC
|
71 |
+
)
|
72 |
+
# Apply padding to the image
|
73 |
+
pad_border_fn = TT.Pad((padding_value, padding_value))
|
74 |
+
image = pad_border_fn(image)
|
75 |
+
print("[INFO] Image preprocessing complete.")
|
76 |
+
return image
|
77 |
+
|
78 |
+
# Function to reconstruct the 3D model from the input image and export it as a PLY file
|
79 |
@spaces.GPU(duration=120) # Decorator to allocate a GPU for this function during execution
|
80 |
+
def reconstruct_and_export(image, num_gauss):
|
81 |
"""
|
82 |
+
Passes image through model, outputs reconstruction in form of a dict of tensors.
|
83 |
"""
|
84 |
print("[DEBUG] Starting reconstruction and export...")
|
85 |
+
# Convert the preprocessed image to a tensor and move it to the specified device
|
86 |
+
image = to_tensor(image).to(device).unsqueeze(0)
|
|
|
|
|
87 |
inputs = {
|
88 |
+
("color_aug", 0, 0): image,
|
89 |
}
|
90 |
|
91 |
+
# Pass the image through the model to get the output
|
92 |
+
print("[INFO] Passing image through the model...")
|
93 |
+
outputs = model(inputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
# Export the reconstruction to a PLY file
|
96 |
print(f"[INFO] Saving output to {ply_out_path}...")
|
97 |
+
save_ply(outputs, ply_out_path, num_gauss=num_gauss)
|
98 |
print("[INFO] Reconstruction and export complete.")
|
99 |
|
100 |
+
return ply_out_path
|
101 |
|
102 |
# Path to save the output PLY file
|
103 |
ply_out_path = f'./mesh.ply'
|
|
|
120 |
with gr.Row(variant="panel"):
|
121 |
with gr.Column(scale=1):
|
122 |
with gr.Row():
|
123 |
+
# Input image component for the user to upload an image
|
124 |
+
input_image = gr.Image(
|
125 |
+
label="Input Image",
|
126 |
+
image_mode="RGBA",
|
127 |
+
sources="upload",
|
128 |
+
type="pil",
|
129 |
+
elem_id="content_image",
|
|
|
|
|
130 |
)
|
131 |
with gr.Row():
|
132 |
# Sliders for configurable parameters
|
133 |
+
num_gauss = gr.Slider(minimum=1, maximum=20, step=1, label="Number of Gaussians per Pixel", value=10)
|
134 |
+
padding_value = gr.Slider(minimum=0, maximum=128, step=8, label="Padding Amount for Output Processing", value=32)
|
135 |
with gr.Row():
|
136 |
# Button to trigger the generation process
|
137 |
submit = gr.Button("Generate", elem_id="generate", variant="primary")
|
|
|
147 |
'./demo_examples/re10k_05.jpg',
|
148 |
'./demo_examples/re10k_06.jpg',
|
149 |
],
|
150 |
+
inputs=[input_image],
|
151 |
cache_examples=False,
|
152 |
+
label="Examples",
|
153 |
examples_per_page=20,
|
154 |
)
|
155 |
|
156 |
with gr.Row():
|
157 |
+
# Display the preprocessed image (after resizing and padding)
|
158 |
+
processed_image = gr.Image(label="Processed Image", interactive=False)
|
159 |
|
160 |
with gr.Column(scale=2):
|
161 |
with gr.Row():
|
162 |
with gr.Tab("Reconstruction"):
|
163 |
# 3D model viewer to display the reconstructed model
|
164 |
output_model = gr.Model3D(
|
165 |
+
height=512,
|
166 |
label="Output Model",
|
167 |
+
interactive=False
|
168 |
)
|
169 |
|
170 |
# Define the workflow for the Generate button
|
171 |
+
submit.click(fn=check_input_image, inputs=[input_image]).success(
|
172 |
fn=preprocess,
|
173 |
+
inputs=[input_image, padding_value],
|
174 |
+
outputs=[processed_image],
|
175 |
).success(
|
176 |
fn=reconstruct_and_export,
|
177 |
+
inputs=[processed_image, num_gauss],
|
178 |
+
outputs=[output_model],
|
179 |
)
|
180 |
|
181 |
# Queue the requests to handle them sequentially (to avoid GPU resource conflicts)
|