File size: 4,583 Bytes
ffbcf9e b789e6e ffbcf9e b789e6e ffbcf9e b789e6e ffbcf9e b789e6e acebad3 b789e6e acebad3 ffbcf9e b789e6e ffbcf9e b789e6e ffbcf9e 53f1f2c b789e6e ffbcf9e acebad3 ffbcf9e a321f01 53f1f2c a321f01 53f1f2c a321f01 acebad3 53f1f2c b782b56 a321f01 acebad3 a321f01 6a66177 53f1f2c b789e6e a321f01 b789e6e acebad3 b789e6e acebad3 ffbcf9e 53f1f2c ffbcf9e a321f01 ffbcf9e a321f01 ffbcf9e a321f01 ffbcf9e acebad3 53f1f2c acebad3 a321f01 ffbcf9e 53f1f2c a321f01 ffbcf9e 53f1f2c a321f01 ffbcf9e b789e6e acebad3 ffbcf9e b789e6e ffbcf9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
def main():
    """Download Flash3D, build the Gradio UI around it, and launch the demo.

    Steps performed, in order:
      1. Pick ``cuda:0`` when CUDA is available, otherwise the CPU.
      2. Fetch the config (``config_re10k_v1.yaml``) and weights
         (``model_re10k_v1.pth``) from the ``einsafutdinov/flash3d`` Hub repo.
      3. Build a ``GaussianPredictor`` from the config, move it to the device
         and load the weights.
      4. Define the UI callbacks (input check, resize + pad, reconstruction)
         and wire them into a ``gr.Blocks`` layout.
      5. Enable a queue with ``max_size=1`` and launch with ``share=True``.

    Returns nothing; control stays inside ``demo.launch`` while the app runs.
    """
    print("[INFO] Starting main function...")
    if torch.cuda.is_available():
        device = "cuda:0"
        print("[INFO] CUDA is available. Using GPU device.")
    else:
        device = "cpu"
        print("[INFO] CUDA is not available. Using CPU device.")

    print("[INFO] Downloading model configuration...")
    model_cfg_path = hf_hub_download(repo_id="einsafutdinov/flash3d", filename="config_re10k_v1.yaml")
    print("[INFO] Downloading model weights...")
    model_path = hf_hub_download(repo_id="einsafutdinov/flash3d", filename="model_re10k_v1.pth")

    print("[INFO] Loading model configuration...")
    cfg = OmegaConf.load(model_cfg_path)

    print("[INFO] Initializing GaussianPredictor model...")
    model = GaussianPredictor(cfg)
    device = torch.device(device)
    model.to(device)
    print("[INFO] Loading model weights...")
    model.load_model(model_path)

    # Border padding applied after resizing; presumably torchvision's Pad, where a
    # 2-tuple pads left/right and top/bottom by the same amount (pad_border_aug).
    pad_border_fn = TT.Pad((cfg.dataset.pad_border_aug, cfg.dataset.pad_border_aug))
    to_tensor = TT.ToTensor()

    # Every reconstruction overwrites this single file; the callback returns its
    # path so the Model3D component can display it.
    ply_out_path = "./mesh.ply"

    def check_input_image(input_image):
        """Raise ``gr.Error`` when no image was uploaded, so the ``.success`` chain stops."""
        print("[DEBUG] Checking input image...")
        if input_image is None:
            print("[ERROR] No image uploaded!")
            raise gr.Error("No image uploaded!")
        print("[INFO] Input image is valid.")

    def preprocess(image, resolution):
        """Resize *image* to a ``resolution`` x ``resolution`` square, then pad its border.

        The square resize does not preserve the input's aspect ratio.
        NOTE(review): the model's config defines its own dataset height/width;
        confirm the model handles arbitrary square inputs.
        """
        print("[DEBUG] Preprocessing image...")
        # Slider values may arrive as floats; image resizing needs integer sizes.
        size = int(resolution)
        image = TTF.resize(image, (size, size), interpolation=TT.InterpolationMode.BICUBIC)
        image = pad_border_fn(image)
        print("[INFO] Image preprocessing complete.")
        return image

    @spaces.GPU(duration=120)
    def reconstruct_and_export(image, num_gauss):
        """Run the model on *image* and export the prediction as a PLY file.

        Args:
            image: The preprocessed image as delivered by the ``processed_image``
                component; converted with ``TT.ToTensor`` and given a batch axis.
            num_gauss: Value of the "Number of Gaussian Components" slider,
                forwarded to ``save_ply``. NOTE(review): the model presumably
                predicts a fixed number of Gaussians set by its config; values
                above that may fail inside ``save_ply`` — confirm.

        Returns:
            The path of the written ``.ply`` file.
        """
        print("[DEBUG] Starting reconstruction and export...")
        image = to_tensor(image).to(device).unsqueeze(0)  # prepend a batch dimension
        inputs = {("color_aug", 0, 0): image}
        print("[INFO] Passing image through the model...")
        # Pure inference: don't build/retain an autograd graph for every request.
        with torch.no_grad():
            outputs = model(inputs)
        print(f"[INFO] Saving output to {ply_out_path}...")
        save_ply(outputs, ply_out_path, num_gauss=int(num_gauss))
        print("[INFO] Reconstruction and export complete.")
        return ply_out_path

    # Center the page's <h1> title.
    css = """
h1 {
text-align: center;
display:block;
}
"""

    with gr.Blocks(css=css) as demo:
        gr.Markdown("# Flash3D")
        with gr.Row(variant="panel"):
            # Left column: input, examples, preprocessed preview.
            with gr.Column(scale=1):
                with gr.Row():
                    input_image = gr.Image(label="Input Image", image_mode="RGBA", sources="upload", type="pil", elem_id="content_image")
                with gr.Row():
                    submit = gr.Button("Generate", elem_id="generate", variant="primary")
                with gr.Row(variant="panel"):
                    gr.Examples(
                        examples=[
                            './demo_examples/bedroom_01.png',
                            './demo_examples/kitti_02.png',
                            './demo_examples/kitti_03.png',
                            './demo_examples/re10k_04.jpg',
                            './demo_examples/re10k_05.jpg',
                            './demo_examples/re10k_06.jpg',
                        ],
                        inputs=[input_image],
                        cache_examples=False,
                        label="Examples",
                        examples_per_page=20,
                    )
                with gr.Row():
                    processed_image = gr.Image(label="Processed Image", interactive=False)
            # Right column: 3D viewer and generation controls.
            with gr.Column(scale=2):
                with gr.Row():
                    with gr.Tab("Reconstruction"):
                        output_model = gr.Model3D(height=512, label="Output Model", interactive=False)
                with gr.Row():
                    resolution = gr.Slider(minimum=256, maximum=1024, step=64, label="Image Resolution", value=cfg.dataset.height)
                    num_gauss = gr.Slider(minimum=1, maximum=10, step=1, label="Number of Gaussian Components", value=2)

        # validate -> preprocess -> reconstruct; each step runs only if the previous succeeded.
        submit.click(fn=check_input_image, inputs=[input_image]).success(
            fn=preprocess,
            inputs=[input_image, resolution],
            outputs=[processed_image],
        ).success(
            fn=reconstruct_and_export,
            inputs=[processed_image, num_gauss],
            outputs=[output_model],
        )

    demo.queue(max_size=1)
    print("[INFO] Launching Gradio demo...")
    demo.launch(share=True)
if __name__ == "__main__":
    # Script entry point: build and launch the demo when run directly.
    print("[INFO] Running application...")
    main()