Ryukijano commited on
Commit
b782b56
·
verified ·
1 Parent(s): bfe9b95

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -3
app.py CHANGED
@@ -78,6 +78,38 @@ def main():
78
  # Function to reconstruct the 3D model from the input images and export it as a PLY file
79
  @spaces.GPU(duration=120) # Decorator to allocate a GPU for this function during execution
80
  def reconstruct_and_export(images, num_gauss):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  """
82
  Passes images through model, outputs reconstruction in form of a dict of tensors.
83
  """
@@ -137,11 +169,11 @@ def main():
137
  # Input images component for the user to upload multiple images
138
  input_images = gr.Gallery(
139
  label="Input Images",
140
- image_mode="RGBA", # Accept RGBA images
141
  sources="upload", # Allow users to upload images
142
- type="pil", # The images are returned as PIL images
143
  elem_id="content_images",
144
- tool="editor", # Optional, for editing images
145
  # Allow multiple image uploads
146
  )
147
  with gr.Row():
 
78
  # Function to reconstruct the 3D model from the input images and export it as a PLY file
79
  @spaces.GPU(duration=120) # Decorator to allocate a GPU for this function during execution
80
  def reconstruct_and_export(images, num_gauss):
81
+ """
82
+ Passes a batch of images through the model, outputs reconstruction in the form of a dict of tensors.
83
+ """
84
+ print("[DEBUG] Starting reconstruction and export...")
85
+ # Stack the images along a new dimension to create a batch
86
+ images_batch = torch.stack([to_tensor(image) for image in images]).to(device) # Create a batch of images
87
+
88
+ # Create input dictionary expected by the model
89
+ inputs = {
90
+ ("color_aug", 0, 0): images_batch, # Batch of input images
91
+ }
92
+
93
+ # Pass the batch of images through the model to get the output
94
+ print("[INFO] Passing batch of images through the model...")
95
+ outputs = model(inputs) # Perform inference to get model outputs
96
+
97
+ # Use the first output for illustration (or modify to combine outputs as needed)
98
+ gauss_means = outputs[('gauss_means', 0, 0)]
99
+ if gauss_means.size(0) < num_gauss or gauss_means.size(0) % num_gauss != 0:
100
+ adjusted_num_gauss = max(1, gauss_means.size(0) // (gauss_means.size(0) // num_gauss))
101
+ print(f"[WARNING] Adjusting num_gauss from {num_gauss} to {adjusted_num_gauss} to avoid shape mismatch.")
102
+ num_gauss = adjusted_num_gauss # Adjust num_gauss to prevent errors during tensor reshaping
103
+
104
+ # Debugging tensor shape
105
+ print(f"[DEBUG] gauss_means tensor shape: {gauss_means.shape}")
106
+
107
+ # Export the reconstruction to a PLY file
108
+ print(f"[INFO] Saving output to {ply_out_path}...")
109
+ save_ply(outputs, ply_out_path, num_gauss=num_gauss) # Save the output 3D model to a PLY file
110
+ print("[INFO] Reconstruction and export complete.")
111
+
112
+ return ply_out_path # Return the path to the saved PLY file
113
  """
114
  Passes images through model, outputs reconstruction in form of a dict of tensors.
115
  """
 
169
  # Input images component for the user to upload multiple images
170
  input_images = gr.Gallery(
171
  label="Input Images",
172
+ # Accept RGBA images
173
  sources="upload", # Allow users to upload images
174
+ # The images are returned as PIL images
175
  elem_id="content_images",
176
+ # Optional, for editing images
177
  # Allow multiple image uploads
178
  )
179
  with gr.Row():