Flash3d / app.py
Ryukijano's picture
Update app.py
53f1f2c verified
raw
history blame
4.58 kB
def _select_device():
    """Return the torch device to run on: ``cuda:0`` if CUDA is available, else CPU."""
    if torch.cuda.is_available():
        print("[INFO] CUDA is available. Using GPU device.")
        return torch.device("cuda:0")
    print("[INFO] CUDA is not available. Using CPU device.")
    return torch.device("cpu")


def _load_model(device):
    """Download the Flash3D config and weights from the Hub and build the predictor.

    Args:
        device: torch device the model is moved to before its weights are loaded.

    Returns:
        ``(model, cfg)``: the ``GaussianPredictor`` with weights loaded, and the
        OmegaConf config it was built from.
    """
    print("[INFO] Downloading model configuration...")
    model_cfg_path = hf_hub_download(repo_id="einsafutdinov/flash3d", filename="config_re10k_v1.yaml")
    print("[INFO] Downloading model weights...")
    model_path = hf_hub_download(repo_id="einsafutdinov/flash3d", filename="model_re10k_v1.pth")
    print("[INFO] Loading model configuration...")
    cfg = OmegaConf.load(model_cfg_path)
    print("[INFO] Initializing GaussianPredictor model...")
    model = GaussianPredictor(cfg)
    model.to(device)
    print("[INFO] Loading model weights...")
    model.load_model(model_path)
    # NOTE(review): the model is never explicitly switched to eval mode here;
    # confirm that GaussianPredictor / load_model handles that for inference.
    return model, cfg


def main():
    """Load the Flash3D model and launch the Gradio single-image reconstruction demo.

    Each "Generate" click runs: validate the upload (``check_input_image``) ->
    resize and pad it (``preprocess``, shown as "Processed Image") -> run the
    model and export the result to a .ply file (``reconstruct_and_export``),
    which is displayed in the Model3D viewer.
    """
    print("[INFO] Starting main function...")
    device = _select_device()
    model, cfg = _load_model(device)

    pad_border_fn = TT.Pad((cfg.dataset.pad_border_aug, cfg.dataset.pad_border_aug))
    to_tensor = TT.ToTensor()
    # Every reconstruction overwrites this single output file.
    ply_out_path = "./mesh.ply"

    def check_input_image(input_image):
        """Raise a user-facing ``gr.Error`` if no image was uploaded."""
        print("[DEBUG] Checking input image...")
        if input_image is None:
            print("[ERROR] No image uploaded!")
            raise gr.Error("No image uploaded!")
        print("[INFO] Input image is valid.")

    def preprocess(image, resolution):
        """Resize ``image`` to ``resolution`` x ``resolution`` (bicubic), then pad its border.

        ``resolution`` comes from a ``gr.Slider``, which may deliver a float, so it
        is cast to ``int`` because resizing needs integer pixel sizes.
        """
        print("[DEBUG] Preprocessing image...")
        size = int(resolution)
        # NOTE(review): the same value is used for height and width, so non-square
        # inputs are stretched; confirm the model tolerates sizes other than its
        # training configuration.
        image = TTF.resize(image, (size, size), interpolation=TT.InterpolationMode.BICUBIC)
        image = pad_border_fn(image)
        print("[INFO] Image preprocessing complete.")
        return image

    @spaces.GPU(duration=120)
    def reconstruct_and_export(image, num_gauss):
        """Run the model on ``image`` and export the prediction to ``ply_out_path``.

        Args:
            image: the preprocessed image from the "Processed Image" component.
            num_gauss: slider value forwarded to ``save_ply`` (cast to ``int``).

        Returns:
            The path of the written .ply file, for the Model3D viewer.
        """
        print("[DEBUG] Starting reconstruction and export...")
        image = to_tensor(image).to(device).unsqueeze(0)  # add a batch dimension
        inputs = {("color_aug", 0, 0): image}
        print("[INFO] Passing image through the model...")
        # Inference only: skip autograd bookkeeping to save (GPU) memory.
        with torch.no_grad():
            outputs = model(inputs)
        print(f"[INFO] Saving output to {ply_out_path}...")
        save_ply(outputs, ply_out_path, num_gauss=int(num_gauss))
        print("[INFO] Reconstruction and export complete.")
        return ply_out_path

    css = """
h1 {
    text-align: center;
    display:block;
}
"""

    with gr.Blocks(css=css) as demo:
        gr.Markdown("# Flash3D")
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                with gr.Row():
                    input_image = gr.Image(label="Input Image", image_mode="RGBA", sources="upload", type="pil", elem_id="content_image")
                with gr.Row():
                    submit = gr.Button("Generate", elem_id="generate", variant="primary")
                with gr.Row(variant="panel"):
                    gr.Examples(
                        examples=[
                            './demo_examples/bedroom_01.png',
                            './demo_examples/kitti_02.png',
                            './demo_examples/kitti_03.png',
                            './demo_examples/re10k_04.jpg',
                            './demo_examples/re10k_05.jpg',
                            './demo_examples/re10k_06.jpg',
                        ],
                        inputs=[input_image],
                        cache_examples=False,
                        label="Examples",
                        examples_per_page=20,
                    )
                with gr.Row():
                    processed_image = gr.Image(label="Processed Image", interactive=False)
            with gr.Column(scale=2):
                with gr.Row():
                    with gr.Tab("Reconstruction"):
                        output_model = gr.Model3D(height=512, label="Output Model", interactive=False)
                with gr.Row():
                    resolution = gr.Slider(minimum=256, maximum=1024, step=64, label="Image Resolution", value=cfg.dataset.height)
                    num_gauss = gr.Slider(minimum=1, maximum=10, step=1, label="Number of Gaussian Components", value=2)

        # Each stage runs only if the previous one succeeded (.success chaining).
        submit.click(fn=check_input_image, inputs=[input_image]).success(
            fn=preprocess,
            inputs=[input_image, resolution],
            outputs=[processed_image],
        ).success(
            fn=reconstruct_and_export,
            inputs=[processed_image, num_gauss],
            outputs=[output_model],
        )

    demo.queue(max_size=1)
    print("[INFO] Launching Gradio demo...")
    demo.launch(share=True)
if __name__ == "__main__":
    # Script entry point: log start-up, then build and launch the demo.
    print("[INFO] Running application...")
    main()