File size: 5,811 Bytes
ffbcf9e b789e6e ffbcf9e b789e6e ffbcf9e b789e6e ffbcf9e b789e6e ffbcf9e b789e6e ffbcf9e b789e6e acebad3 b789e6e acebad3 ffbcf9e b789e6e ffbcf9e b789e6e ffbcf9e 17ea36c acebad3 17ea36c b789e6e ffbcf9e acebad3 ffbcf9e a321f01 acebad3 a321f01 acebad3 b782b56 a321f01 acebad3 b782b56 a321f01 b789e6e 564492e 6a66177 acebad3 b789e6e a321f01 b789e6e acebad3 b789e6e acebad3 b789e6e a321f01 acebad3 ffbcf9e 17ea36c ffbcf9e a321f01 ffbcf9e a321f01 ffbcf9e a321f01 ffbcf9e acebad3 a321f01 ffbcf9e a321f01 ffbcf9e acebad3 a321f01 ffbcf9e b789e6e acebad3 ffbcf9e b789e6e ffbcf9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import sys
import spaces
sys.path.append("flash3d") # Add the flash3d directory to the system path for importing local modules
from omegaconf import OmegaConf
import gradio as gr
import torch
import torchvision.transforms as TT
import torchvision.transforms.functional as TTF
from huggingface_hub import hf_hub_download
import numpy as np
from networks.gaussian_predictor import GaussianPredictor
from util.vis3d import save_ply
def main():
    """Build the Flash3D Gradio demo and launch it.

    Downloads the ``einsafutdinov/flash3d`` config and weights from the
    Hugging Face Hub and loads them into a ``GaussianPredictor``. The model
    runs on ``cuda:0`` when CUDA is available and on the CPU otherwise.
    The function then wires up a three-stage UI pipeline
    (validate input -> preprocess -> reconstruct and export ``.ply``) and
    launches the demo.
    """
    print("[INFO] Starting main function...")
    if torch.cuda.is_available():
        device = "cuda:0"
        print("[INFO] CUDA is available. Using GPU device.")
    else:
        device = "cpu"
        print("[INFO] CUDA is not available. Using CPU device.")

    print("[INFO] Downloading model configuration...")
    model_cfg_path = hf_hub_download(
        repo_id="einsafutdinov/flash3d", filename="config_re10k_v1.yaml"
    )
    print("[INFO] Downloading model weights...")
    model_path = hf_hub_download(
        repo_id="einsafutdinov/flash3d", filename="model_re10k_v1.pth"
    )

    print("[INFO] Loading model configuration...")
    cfg = OmegaConf.load(model_cfg_path)

    print("[INFO] Initializing GaussianPredictor model...")
    model = GaussianPredictor(cfg)
    try:
        device = torch.device(device)
        model.to(device)
    except Exception as e:
        print(f"[ERROR] Failed to set device: {e}")
        raise
    print("[INFO] Loading model weights...")
    model.load_model(model_path)

    # Border padding from the checkpoint's config. The padding slider uses it
    # as its default, so an untouched UI pads the way the config specifies.
    default_padding = int(cfg.dataset.pad_border_aug)
    to_tensor = TT.ToTensor()
    # Single output file, overwritten on every reconstruction.
    ply_out_path = "./mesh.ply"

    def check_input_image(input_image):
        """Raise a user-visible ``gr.Error`` when no image was uploaded."""
        print("[DEBUG] Checking input image...")
        if input_image is None:
            print("[ERROR] No image uploaded!")
            raise gr.Error("No image uploaded!")
        print("[INFO] Input image is valid.")

    def preprocess(image, padding_value):
        """Resize *image* to the configured resolution and pad its border.

        Args:
            image: PIL image from the upload widget (``type="pil"``).
            padding_value: Border width in pixels, taken from the padding
                slider and applied on every side.

        Returns:
            The resized and padded PIL image, shown in "Processed Image".
        """
        print("[DEBUG] Preprocessing image...")
        image = TTF.resize(
            image,
            (cfg.dataset.height, cfg.dataset.width),
            interpolation=TT.InterpolationMode.BICUBIC,
        )
        # Slider values may arrive as floats; padding is a pixel count.
        padding = int(padding_value)
        # NOTE(review): the checkpoint config specifies pad_border_aug; other
        # values may not match what the model/save_ply expect -- confirm.
        image = TT.Pad((padding, padding))(image)
        print("[INFO] Image preprocessing complete.")
        return image

    @spaces.GPU(duration=120)
    def reconstruct_and_export(image, num_gauss, max_sh_degree, scaling_modifier):
        """Run the model on *image* and export the result to ``ply_out_path``.

        Args:
            image: The processed image as delivered by the "Processed Image"
                component; converted with ``ToTensor`` and given a batch dim.
            num_gauss: Gaussians per pixel assumed when exporting.
            max_sh_degree: Max spherical-harmonics degree passed to ``save_ply``.
            scaling_modifier: Scale factor passed to ``save_ply``.

        Returns:
            Path of the written ``.ply`` file, displayed by the 3D viewer.

        Raises:
            gr.Error: If axis 0 of the predicted Gaussian means is not a
                multiple of ``num_gauss``.
        """
        print("[DEBUG] Starting reconstruction and export...")
        # Slider values may arrive as floats; these are used as counts.
        num_gauss = int(num_gauss)
        max_sh_degree = int(max_sh_degree)
        image = to_tensor(image).to(device).unsqueeze(0)  # add batch dimension
        inputs = {("color_aug", 0, 0): image}
        print("[INFO] Passing image through the model...")
        # Pure inference: skip building an autograd graph.
        with torch.no_grad():
            outputs = model(inputs)
        gauss_means = outputs[("gauss_means", 0, 0)]
        # NOTE(review): axis 0 presumably holds batch x Gaussians-per-pixel of
        # the checkpoint, so num_gauss must be compatible with it -- confirm.
        if gauss_means.shape[0] % num_gauss != 0:
            # gr.Error (unlike ValueError) surfaces this message in the UI.
            raise gr.Error(
                f"Shape mismatch: cannot divide axis of length {gauss_means.shape[0]} into chunks of {num_gauss}"
            )
        print(f"[INFO] Saving output to {ply_out_path}...")
        save_ply(
            outputs,
            ply_out_path,
            num_gauss=num_gauss,
            max_sh_degree=max_sh_degree,
            scaling_modifier=scaling_modifier,
        )
        print("[INFO] Reconstruction and export complete.")
        return ply_out_path

    css = """
h1 {
text-align: center;
display:block;
}
"""

    with gr.Blocks(css=css) as demo:
        gr.Markdown("# Flash3D")
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                with gr.Row():
                    input_image = gr.Image(
                        label="Input Image",
                        image_mode="RGBA",
                        sources="upload",
                        type="pil",
                        elem_id="content_image",
                    )
                with gr.Row():
                    num_gauss = gr.Slider(minimum=1, maximum=20, step=1, label="Number of Gaussians per Pixel", value=10)
                    padding_value = gr.Slider(minimum=0, maximum=128, step=8, label="Padding Amount for Output Processing", value=default_padding)
                    max_sh_degree = gr.Slider(minimum=1, maximum=10, step=1, label="Max SH Degree", value=1)
                    scaling_modifier = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, label="Scaling Modifier", value=1.0)
                with gr.Row():
                    submit = gr.Button("Generate", elem_id="generate", variant="primary")
                with gr.Row(variant="panel"):
                    gr.Examples(
                        examples=[
                            './demo_examples/bedroom_01.png',
                            './demo_examples/kitti_02.png',
                            './demo_examples/kitti_03.png',
                            './demo_examples/re10k_04.jpg',
                            './demo_examples/re10k_05.jpg',
                            './demo_examples/re10k_06.jpg',
                        ],
                        inputs=[input_image],
                        cache_examples=False,
                        label="Examples",
                        examples_per_page=20,
                    )
                with gr.Row():
                    processed_image = gr.Image(label="Processed Image", interactive=False)
            with gr.Column(scale=2):
                with gr.Row():
                    with gr.Tab("Reconstruction"):
                        output_model = gr.Model3D(height=512, label="Output Model", interactive=False)

        # validate -> preprocess -> reconstruct; each stage runs only if the
        # previous one succeeded.
        submit.click(fn=check_input_image, inputs=[input_image]).success(
            fn=preprocess,
            inputs=[input_image, padding_value],
            outputs=[processed_image],
        ).success(
            fn=reconstruct_and_export,
            inputs=[processed_image, num_gauss, max_sh_degree, scaling_modifier],
            outputs=[output_model],
        )

    demo.queue(max_size=1)
    print("[INFO] Launching Gradio demo...")
    demo.launch(share=True)
if __name__ == "__main__":
    # Start the app only when this file is executed directly; importing the
    # module just defines main() without side effects.
    print("[INFO] Running application...")
    main()