|
import sys |
|
import spaces |
|
sys.path.append("flash3d") |
|
|
|
from omegaconf import OmegaConf |
|
import gradio as gr |
|
import torch |
|
import torchvision.transforms as TT |
|
import torchvision.transforms.functional as TTF |
|
from huggingface_hub import hf_hub_download |
|
import numpy as np |
|
|
|
from networks.gaussian_predictor import GaussianPredictor |
|
from util.vis3d import save_ply |
|
|
|
def _select_device():
    """Return ``cuda:0`` when CUDA is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        print("[INFO] CUDA is available. Using GPU device.")
        return torch.device("cuda:0")
    print("[INFO] CUDA is not available. Using CPU device.")
    return torch.device("cpu")


def _load_model(device):
    """Download the Flash3D config and weights from the Hub and build the predictor.

    Args:
        device: ``torch.device`` the model is moved to.

    Returns:
        ``(cfg, model)``: the OmegaConf config loaded from the Hub and a
        ``GaussianPredictor`` on ``device`` with its weights loaded, switched
        to evaluation mode.
    """
    print("[INFO] Downloading model configuration...")
    model_cfg_path = hf_hub_download(repo_id="einsafutdinov/flash3d", filename="config_re10k_v1.yaml")
    print("[INFO] Downloading model weights...")
    model_path = hf_hub_download(repo_id="einsafutdinov/flash3d", filename="model_re10k_v1.pth")

    print("[INFO] Loading model configuration...")
    cfg = OmegaConf.load(model_cfg_path)

    print("[INFO] Initializing GaussianPredictor model...")
    model = GaussianPredictor(cfg)
    model.to(device)

    print("[INFO] Loading model weights...")
    model.load_model(model_path)

    # This app only runs inference. Put any dropout / batch-norm layers into
    # evaluation mode so a single image is not normalised with its own batch
    # statistics. Assumes GaussianPredictor is a torch.nn.Module (it is used
    # like one via .to() and __call__).
    model.eval()
    return cfg, model


def main():
    """Load the Flash3D model, build the Gradio UI and launch it (blocking).

    Pipeline wired up below: validate the upload -> resize + pad it to the
    model's input size -> run the GaussianPredictor -> export the Gaussians
    to a .ply file shown in a Model3D viewer.
    """
    print("[INFO] Starting main function...")

    device = _select_device()
    cfg, model = _load_model(device)

    pad_border_fn = TT.Pad((cfg.dataset.pad_border_aug, cfg.dataset.pad_border_aug))
    to_tensor = TT.ToTensor()

    # Every request writes to this single file. The queue configured below
    # keeps requests from piling up, but the path is shared by all users.
    ply_out_path = "./mesh.ply"

    def check_input_image(input_image):
        """Raise a user-visible ``gr.Error`` when no image was uploaded."""
        print("[DEBUG] Checking input image...")
        if input_image is None:
            print("[ERROR] No image uploaded!")
            raise gr.Error("No image uploaded!")
        print("[INFO] Input image is valid.")

    def preprocess(image):
        """Resize ``image`` to the configured input size and pad its border.

        Resizing uses bicubic interpolation to ``(cfg.dataset.height,
        cfg.dataset.width)``. Padding uses ``cfg.dataset.pad_border_aug``
        pixels on each side.
        """
        print("[DEBUG] Preprocessing image...")
        image = TTF.resize(image, (cfg.dataset.height, cfg.dataset.width), interpolation=TT.InterpolationMode.BICUBIC)
        image = pad_border_fn(image)
        print("[INFO] Image preprocessing complete.")
        return image

    @spaces.GPU(duration=120)
    def reconstruct_and_export(image, num_gauss, batch_size, num_iterations):
        """Run the model on the preprocessed image and export a .ply file.

        Args:
            image: preprocessed image as delivered by the "Processed Image"
                component (anything ``ToTensor`` accepts).
            num_gauss: forwarded to ``save_ply`` as ``num_gauss``.
                NOTE(review): presumably must match the number of Gaussians
                per pixel the loaded model predicts; confirm.
            batch_size, num_iterations: written into ``model.cfg`` before
                the forward pass.

        Returns:
            Path of the written .ply file.
        """
        print("[DEBUG] Starting reconstruction and export...")
        image = to_tensor(image).to(device).unsqueeze(0)  # add batch dim
        inputs = {("color_aug", 0, 0): image}

        # Slider values can arrive as floats. Coerce them so the config and
        # save_ply receive integer counts.
        # NOTE(review): these assignments mutate the shared model config on
        # every request. Confirm whether the forward pass actually reads them.
        model.cfg.dataset.batch_size = int(batch_size)
        model.cfg.training.num_iterations = int(num_iterations)

        print("[INFO] Passing image through the model...")
        # Pure inference: skip building the autograd graph to save GPU memory.
        with torch.no_grad():
            outputs = model(inputs)

        print(f"[INFO] Saving output to {ply_out_path}...")
        save_ply(outputs, ply_out_path, num_gauss=int(num_gauss))
        print("[INFO] Reconstruction and export complete.")
        return ply_out_path

    css = """
h1 {
    text-align: center;
    display:block;
}
"""

    with gr.Blocks(css=css) as demo:
        gr.Markdown("# Flash3D")
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                with gr.Row():
                    input_image = gr.Image(label="Input Image", image_mode="RGBA", sources="upload", type="pil", elem_id="content_image")
                with gr.Row():
                    submit = gr.Button("Generate", elem_id="generate", variant="primary")
                with gr.Row(variant="panel"):
                    gr.Examples(
                        examples=[
                            './demo_examples/bedroom_01.png',
                            './demo_examples/kitti_02.png',
                            './demo_examples/kitti_03.png',
                            './demo_examples/re10k_04.jpg',
                            './demo_examples/re10k_05.jpg',
                            './demo_examples/re10k_06.jpg',
                        ],
                        inputs=[input_image],
                        cache_examples=False,
                        label="Examples",
                        examples_per_page=20,
                    )
                with gr.Row():
                    processed_image = gr.Image(label="Processed Image", interactive=False)
            with gr.Column(scale=2):
                with gr.Row():
                    with gr.Tab("Reconstruction"):
                        output_model = gr.Model3D(height=512, label="Output Model", interactive=False)
                with gr.Row():
                    num_gauss = gr.Slider(minimum=1, maximum=10, step=1, label="Number of Gaussian Components", value=2)
                    batch_size = gr.Slider(minimum=1, maximum=32, step=1, label="Batch Size", value=1)
                    num_iterations = gr.Slider(minimum=1, maximum=1000, step=10, label="Number of Iterations", value=100)

        # Each stage runs only if the previous one succeeded (no gr.Error).
        submit.click(fn=check_input_image, inputs=[input_image]).success(
            fn=preprocess,
            inputs=[input_image],
            outputs=[processed_image],
        ).success(
            fn=reconstruct_and_export,
            inputs=[processed_image, num_gauss, batch_size, num_iterations],
            outputs=[output_model],
        )

    demo.queue(max_size=1)
    print("[INFO] Launching Gradio demo...")
    demo.launch(share=True)
|
|
|
if __name__ == "__main__":
    # Only start the demo when this file is executed directly, never on import.
    _startup_banner = "[INFO] Running application..."
    print(_startup_banner)
    main()