File size: 5,811 Bytes
ffbcf9e
 
b789e6e
ffbcf9e
 
 
 
 
 
 
b789e6e
ffbcf9e
 
 
 
 
b789e6e
ffbcf9e
 
b789e6e
ffbcf9e
 
b789e6e
ffbcf9e
b789e6e
acebad3
b789e6e
acebad3
ffbcf9e
b789e6e
ffbcf9e
b789e6e
 
ffbcf9e
17ea36c
 
acebad3
17ea36c
 
 
b789e6e
 
ffbcf9e
 
acebad3
 
ffbcf9e
a321f01
 
 
 
 
 
 
 
 
acebad3
a321f01
 
 
 
 
acebad3
 
b782b56
a321f01
acebad3
b782b56
a321f01
 
b789e6e
564492e
 
 
 
6a66177
acebad3
b789e6e
 
a321f01
b789e6e
 
 
 
 
 
 
 
 
 
 
acebad3
b789e6e
 
 
acebad3
b789e6e
a321f01
 
acebad3
 
ffbcf9e
 
 
17ea36c
ffbcf9e
 
 
 
 
 
 
 
 
a321f01
ffbcf9e
a321f01
ffbcf9e
 
 
 
a321f01
ffbcf9e
 
 
 
acebad3
 
a321f01
ffbcf9e
a321f01
 
ffbcf9e
 
acebad3
a321f01
ffbcf9e
 
 
b789e6e
acebad3
ffbcf9e
 
b789e6e
ffbcf9e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import sys
import spaces
sys.path.append("flash3d")  # Add the flash3d directory to the system path for importing local modules

from omegaconf import OmegaConf
import gradio as gr
import torch
import torchvision.transforms as TT
import torchvision.transforms.functional as TTF
from huggingface_hub import hf_hub_download
import numpy as np

from networks.gaussian_predictor import GaussianPredictor
from util.vis3d import save_ply

def main():
    """Load the Flash3D model and serve it through a Gradio web demo.

    Downloads ``config_re10k_v1.yaml`` and ``model_re10k_v1.pth`` from the
    ``einsafutdinov/flash3d`` Hugging Face repo, builds a ``GaussianPredictor``
    on ``cuda:0`` when CUDA is available (CPU otherwise), loads the weights,
    and launches a Gradio app that turns one uploaded image into a ``.ply``
    file written by ``save_ply``.

    Raises:
        Exception: whatever ``torch.device``/``model.to`` raised, re-raised after
            being logged, if the model cannot be placed on the selected device.
    """
    print("[INFO] Starting main function...")
    if torch.cuda.is_available():
        device = "cuda:0"
        print("[INFO] CUDA is available. Using GPU device.")
    else:
        device = "cpu"
        print("[INFO] CUDA is not available. Using CPU device.")

    print("[INFO] Downloading model configuration...")
    model_cfg_path = hf_hub_download(repo_id="einsafutdinov/flash3d", filename="config_re10k_v1.yaml")
    print("[INFO] Downloading model weights...")
    model_path = hf_hub_download(repo_id="einsafutdinov/flash3d", filename="model_re10k_v1.pth")

    print("[INFO] Loading model configuration...")
    cfg = OmegaConf.load(model_cfg_path)

    print("[INFO] Initializing GaussianPredictor model...")
    model = GaussianPredictor(cfg)
    try:
        device = torch.device(device)
        model.to(device)
    except Exception as e:
        # Log to the console, then propagate: the demo is useless without a model.
        print(f"[ERROR] Failed to set device: {e}")
        raise

    print("[INFO] Loading model weights...")
    model.load_model(model_path)

    to_tensor = TT.ToTensor()
    # Every request writes to (and overwrites) this single file.
    ply_out_path = './mesh.ply'

    def check_input_image(input_image):
        """Abort the event chain with a user-visible error if no image was uploaded."""
        print("[DEBUG] Checking input image...")
        if input_image is None:
            print("[ERROR] No image uploaded!")
            raise gr.Error("No image uploaded!")
        print("[INFO] Input image is valid.")

    def preprocess(image, padding_value):
        """Resize *image* to the configured model resolution and pad its border.

        Args:
            image: PIL image from the input component (``type="pil"``).
            padding_value: border width in pixels, taken from the padding slider.

        Returns:
            The resized and padded PIL image.
        """
        print("[DEBUG] Preprocessing image...")
        image = TTF.resize(image, (cfg.dataset.height, cfg.dataset.width), interpolation=TT.InterpolationMode.BICUBIC)
        # Gradio documents Slider values as floats; PIL padding needs an integer border.
        # NOTE(review): the checkpoint was presumably trained with
        # cfg.dataset.pad_border_aug padding; other values change the network input
        # size — confirm the model tolerates arbitrary padding.
        pad = int(padding_value)
        pad_border_fn = TT.Pad((pad, pad))
        image = pad_border_fn(image)
        print("[INFO] Image preprocessing complete.")
        return image

    @spaces.GPU(duration=120)
    def reconstruct_and_export(image, num_gauss, max_sh_degree, scaling_modifier):
        """Run Flash3D on the processed image and export the Gaussians to a PLY file.

        Args:
            image: processed image as delivered by the ``processed_image`` component.
            num_gauss: Gaussians-per-pixel value from the slider, passed to ``save_ply``.
            max_sh_degree: spherical-harmonics degree passed to ``save_ply``.
            scaling_modifier: scale factor passed to ``save_ply``.

        Returns:
            Path of the written ``.ply`` file.

        Raises:
            gr.Error: if the leading dimension of the predicted Gaussian means is
                not divisible by ``num_gauss``.
        """
        print("[DEBUG] Starting reconstruction and export...")
        # Slider values may arrive as floats; these two are integer counts/degrees.
        num_gauss = int(num_gauss)
        max_sh_degree = int(max_sh_degree)

        image = to_tensor(image).to(device).unsqueeze(0)  # prepend a batch dimension
        inputs = {("color_aug", 0, 0): image}

        print("[INFO] Passing image through the model...")
        # Inference only: don't build an autograd graph (saves memory, and the
        # resulting tensors never require grad).
        with torch.no_grad():
            outputs = model(inputs)

        gauss_means = outputs[('gauss_means', 0, 0)]
        if gauss_means.shape[0] % num_gauss != 0:
            # gr.Error surfaces the message in the UI instead of a generic error.
            # NOTE(review): the Gaussians-per-pixel count is presumably fixed by the
            # checkpoint config, so most slider values may be rejected here — confirm.
            raise gr.Error(f"Shape mismatch: cannot divide axis of length {gauss_means.shape[0]} into chunks of {num_gauss}")

        print(f"[INFO] Saving output to {ply_out_path}...")
        save_ply(outputs, ply_out_path, num_gauss=num_gauss, max_sh_degree=max_sh_degree, scaling_modifier=scaling_modifier)
        print("[INFO] Reconstruction and export complete.")

        return ply_out_path

    css = """
        h1 {
            text-align: center;
            display:block;
        }
        """

    with gr.Blocks(css=css) as demo:
        gr.Markdown("# Flash3D")
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                with gr.Row():
                    input_image = gr.Image(label="Input Image", image_mode="RGBA", sources="upload", type="pil", elem_id="content_image")
                with gr.Row():
                    num_gauss = gr.Slider(minimum=1, maximum=20, step=1, label="Number of Gaussians per Pixel", value=10)
                    padding_value = gr.Slider(minimum=0, maximum=128, step=8, label="Padding Amount for Output Processing", value=32)
                    max_sh_degree = gr.Slider(minimum=1, maximum=10, step=1, label="Max SH Degree", value=1)
                    scaling_modifier = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, label="Scaling Modifier", value=1.0)
                with gr.Row():
                    submit = gr.Button("Generate", elem_id="generate", variant="primary")

                with gr.Row(variant="panel"):
                    gr.Examples(
                        examples=[
                            './demo_examples/bedroom_01.png',
                            './demo_examples/kitti_02.png',
                            './demo_examples/kitti_03.png',
                            './demo_examples/re10k_04.jpg',
                            './demo_examples/re10k_05.jpg',
                            './demo_examples/re10k_06.jpg',
                        ],
                        inputs=[input_image],
                        cache_examples=False,
                        label="Examples",
                        examples_per_page=20,
                    )

                with gr.Row():
                    processed_image = gr.Image(label="Processed Image", interactive=False)

            with gr.Column(scale=2):
                with gr.Row():
                    with gr.Tab("Reconstruction"):
                        output_model = gr.Model3D(height=512, label="Output Model", interactive=False)

        # Validate -> preprocess -> reconstruct; each step runs only if the previous succeeded.
        submit.click(fn=check_input_image, inputs=[input_image]).success(
            fn=preprocess,
            inputs=[input_image, padding_value],
            outputs=[processed_image],
        ).success(
            fn=reconstruct_and_export,
            inputs=[processed_image, num_gauss, max_sh_degree, scaling_modifier],
            outputs=[output_model],
        )

    demo.queue(max_size=1)
    print("[INFO] Launching Gradio demo...")
    demo.launch(share=True)

if __name__ == "__main__":
    # Runs only when this file is executed directly (not on import):
    # print a startup banner, then build and launch the Gradio demo.
    print("[INFO] Running application...")
    main()