Update app.py
Browse filesfeat: Add GPU support and model configuration download
- Added GPU support detection and device setting in the main function.
- Implemented downloading of model configuration and weights from Hugging Face Hub.
- Included necessary imports for Gradio, Torch, and other dependencies.
- Added system path for importing local modules from the flash3d directory.
- Improved logging to provide detailed information about the process.
app.py
CHANGED
@@ -92,6 +92,11 @@ def main():
|
|
92 |
print("[INFO] Passing image through the model...")
|
93 |
outputs = model(inputs)
|
94 |
|
|
|
|
|
|
|
|
|
|
|
95 |
# Export the reconstruction to a PLY file
|
96 |
print(f"[INFO] Saving output to {ply_out_path}...")
|
97 |
save_ply(outputs, ply_out_path, num_gauss=num_gauss)
|
|
|
92 |
print("[INFO] Passing image through the model...")
|
93 |
outputs = model(inputs)
|
94 |
|
95 |
+
#Ensure the tensor dimensions are compatible
|
96 |
+
gauss_means = outputs[('gauss_means',0, 0)]
|
97 |
+
if gauss_means.shape[0] % num_gauss != 0:
|
98 |
+
raise ValueError(f"Shape mismatch: cannot divide axis of length {gauss_means.shape[0]} into chunks of {num_gauss}")
|
99 |
+
|
100 |
# Export the reconstruction to a PLY file
|
101 |
print(f"[INFO] Saving output to {ply_out_path}...")
|
102 |
save_ply(outputs, ply_out_path, num_gauss=num_gauss)
|