Update app.py
Browse files
app.py
CHANGED
@@ -78,6 +78,38 @@ def main():
|
|
78 |
# Function to reconstruct the 3D model from the input images and export it as a PLY file
|
79 |
@spaces.GPU(duration=120) # Decorator to allocate a GPU for this function during execution
|
80 |
def reconstruct_and_export(images, num_gauss):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
"""
|
82 |
Passes images through model, outputs reconstruction in form of a dict of tensors.
|
83 |
"""
|
@@ -137,11 +169,11 @@ def main():
|
|
137 |
# Input images component for the user to upload multiple images
|
138 |
input_images = gr.Gallery(
|
139 |
label="Input Images",
|
140 |
-
|
141 |
sources="upload", # Allow users to upload images
|
142 |
-
|
143 |
elem_id="content_images",
|
144 |
-
|
145 |
# Allow multiple image uploads
|
146 |
)
|
147 |
with gr.Row():
|
|
|
78 |
# Function to reconstruct the 3D model from the input images and export it as a PLY file
|
79 |
@spaces.GPU(duration=120) # Decorator to allocate a GPU for this function during execution
|
80 |
def reconstruct_and_export(images, num_gauss):
|
81 |
+
"""
|
82 |
+
Passes a batch of images through the model, outputs reconstruction in the form of a dict of tensors.
|
83 |
+
"""
|
84 |
+
print("[DEBUG] Starting reconstruction and export...")
|
85 |
+
# Stack the images along a new dimension to create a batch
|
86 |
+
images_batch = torch.stack([to_tensor(image) for image in images]).to(device) # Create a batch of images
|
87 |
+
|
88 |
+
# Create input dictionary expected by the model
|
89 |
+
inputs = {
|
90 |
+
("color_aug", 0, 0): images_batch, # Batch of input images
|
91 |
+
}
|
92 |
+
|
93 |
+
# Pass the batch of images through the model to get the output
|
94 |
+
print("[INFO] Passing batch of images through the model...")
|
95 |
+
outputs = model(inputs) # Perform inference to get model outputs
|
96 |
+
|
97 |
+
# Use the first output for illustration (or modify to combine outputs as needed)
|
98 |
+
gauss_means = outputs[('gauss_means', 0, 0)]
|
99 |
+
if gauss_means.size(0) < num_gauss or gauss_means.size(0) % num_gauss != 0:
|
100 |
+
adjusted_num_gauss = max(1, gauss_means.size(0) // (gauss_means.size(0) // num_gauss))
|
101 |
+
print(f"[WARNING] Adjusting num_gauss from {num_gauss} to {adjusted_num_gauss} to avoid shape mismatch.")
|
102 |
+
num_gauss = adjusted_num_gauss # Adjust num_gauss to prevent errors during tensor reshaping
|
103 |
+
|
104 |
+
# Debugging tensor shape
|
105 |
+
print(f"[DEBUG] gauss_means tensor shape: {gauss_means.shape}")
|
106 |
+
|
107 |
+
# Export the reconstruction to a PLY file
|
108 |
+
print(f"[INFO] Saving output to {ply_out_path}...")
|
109 |
+
save_ply(outputs, ply_out_path, num_gauss=num_gauss) # Save the output 3D model to a PLY file
|
110 |
+
print("[INFO] Reconstruction and export complete.")
|
111 |
+
|
112 |
+
return ply_out_path # Return the path to the saved PLY file
|
113 |
"""
|
114 |
Passes images through model, outputs reconstruction in form of a dict of tensors.
|
115 |
"""
|
|
|
169 |
# Input images component for the user to upload multiple images
|
170 |
input_images = gr.Gallery(
|
171 |
label="Input Images",
|
172 |
+
# Accept RGBA images
|
173 |
sources="upload", # Allow users to upload images
|
174 |
+
# The images are returned as PIL images
|
175 |
elem_id="content_images",
|
176 |
+
# Optional, for editing images
|
177 |
# Allow multiple image uploads
|
178 |
)
|
179 |
with gr.Row():
|