"""Gradio demo for Flash3D: single-image 3D reconstruction with Gaussians."""
import sys

import spaces  # Hugging Face Spaces GPU helper (provides the @spaces.GPU decorator)

sys.path.append("flash3d")  # make the flash3d packages (networks, util) importable

import gradio as gr
import numpy as np
import torch
import torchvision.transforms as TT
import torchvision.transforms.functional as TTF
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf

from networks.gaussian_predictor import GaussianPredictor
from util.vis3d import save_ply


def main():
    print("[INFO] Starting main function...")

    # Pick the compute device: prefer the first CUDA GPU, fall back to CPU.
    if torch.cuda.is_available():
        device = "cuda:0"
        print("[INFO] CUDA is available. Using GPU device.")
    else:
        device = "cpu"
        print("[INFO] CUDA is not available. Using CPU device.")

    # Fetch the Flash3D config and weights from the Hugging Face Hub
    # (hf_hub_download caches files locally, so repeat runs skip the download).
    print("[INFO] Downloading model configuration...")
    model_cfg_path = hf_hub_download(repo_id="einsafutdinov/flash3d",
                                     filename="config_re10k_v1.yaml")
    print("[INFO] Downloading model weights...")
    model_path = hf_hub_download(repo_id="einsafutdinov/flash3d",
                                 filename="model_re10k_v1.pth")

    print("[INFO] Loading model configuration...")
    cfg = OmegaConf.load(model_cfg_path)

    print("[INFO] Initializing GaussianPredictor model...")
    model = GaussianPredictor(cfg)
    device = torch.device(device)
    model.to(device)

    print("[INFO] Loading model weights...")
    model.load_model(model_path)

    # Preprocessing transforms: pad every border by the config's augmentation
    # amount, and convert PIL images to tensors.
    pad_border_fn = TT.Pad((cfg.dataset.pad_border_aug, cfg.dataset.pad_border_aug))
    to_tensor = TT.ToTensor()

    def check_input_image(input_image):
        print("[DEBUG] Checking input image...")
        if input_image is None:
            print("[ERROR] No image uploaded!")
            raise gr.Error("No image uploaded!")
        print("[INFO] Input image is valid.")

    def preprocess(image):
        print("[DEBUG] Preprocessing image...")
        # Resize to the training resolution, then pad the border as the model expects.
        image = TTF.resize(
            image, (cfg.dataset.height, cfg.dataset.width),
            interpolation=TT.InterpolationMode.BICUBIC
        )
        image = pad_border_fn(image)
        print("[INFO] Image preprocessing complete.")
        return image
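
    # A minimal usage sketch (hypothetical file path; assumes the re10k_v1
    # config, where cfg.dataset.height/width set the training resolution and
    # cfg.dataset.pad_border_aug is the per-side border padding):
    #
    #   from PIL import Image
    #   img = Image.open("room.png")          # any PIL image
    #   img = preprocess(img)                 # resized + border-padded PIL image
    #   tensor = to_tensor(img).unsqueeze(0)  # (1, C, H + 2*pad, W + 2*pad) tensor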
|

    @spaces.GPU(duration=120)
    def reconstruct_and_export(image):
        """
        Passes the image through the model and saves the reconstruction
        (a dict of Gaussian-parameter tensors) as a .ply file.
        """
        print("[DEBUG] Starting reconstruction and export...")
        image = to_tensor(image).to(device).unsqueeze(0)
        inputs = {
            ("color_aug", 0, 0): image,
        }

        print("[INFO] Passing image through the model...")
        outputs = model(inputs)

        # ply_out_path is a closure variable resolved at call time; it is
        # assigned below, before Gradio ever invokes this callback.
        print(f"[INFO] Saving output to {ply_out_path}...")
        save_ply(outputs, ply_out_path, num_gauss=2)
        print("[INFO] Reconstruction and export complete.")

        return ply_out_path

    ply_out_path = './mesh.ply'
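
    # Note on the input dictionary above: the ("color_aug", 0, 0) key appears
    # to follow the (name, frame_id, scale) convention of monodepth-style
    # codebases, i.e. the (augmented) colour image for frame 0 at full
    # resolution. A sketch of driving the callback outside Gradio
    # (hypothetical path):
    #
    #   from PIL import Image
    #   img = preprocess(Image.open("room.png"))
    #   ply_path = reconstruct_and_export(img)  # writes ./mesh.ply, returns its path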

    # Center the page title.
    css = """
    h1 {
        text-align: center;
        display: block;
    }
    """

    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # Flash3D
            """
        )
        gr.Markdown(
            """
            ## Comments:
            1. If you run the demo online, the first example you upload should take about 4.5 seconds (including preprocessing, saving, and overhead); subsequent ones take about 1.5 s.
            2. The 3D viewer shows a .ply mesh extracted from a mix of 3D Gaussians. This is only an approximation, and artifacts may show.
            3. Known limitations include:
                - A black dot appearing on the model from some viewpoints.
                - See-through parts of objects, especially on the back: this is due to the model performing less well on more complicated shapes.
                - Backs of objects are blurry: this is a model limitation due to it being deterministic.
            4. Our model is of comparable quality to state-of-the-art methods, and is **much** cheaper to train and run.
            ## How does it work?
            Splatter Image formulates 3D reconstruction as an image-to-image translation task. It maps the input image to another image
            in which every pixel represents one 3D Gaussian, and the channels of the output represent the parameters of these Gaussians, including their shapes, colours, and locations.
            The resulting image thus represents a set of Gaussians (almost like a point cloud) which reconstruct the shape and colour of the object.
            The method is very cheap: the reconstruction amounts to a single forward pass of a neural network with only 2D operators (2D convolutions and attention).
            Rendering is also very fast, thanks to Gaussian Splatting.
            Combined, this results in very cheap training and high-quality results.
            For more results see the [project page](https://szymanowiczs.github.io/splatter-image) and the [CVPR article](https://arxiv.org/abs/2312.13150).
            """
        )
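
        # A sketch of the idea described above (hypothetical channel layout,
        # not Flash3D's actual output format): the network maps an image of
        # shape (B, 3, H, W) to a parameter map (B, C, H, W), where the C
        # channels at each pixel encode one Gaussian, e.g.
        #
        #   offset xyz (3) + opacity (1) + scale (3) + rotation quaternion (4)
        #   + colour rgb (3)  ->  C = 14 channels per pixel,
        #
        # so an H x W input yields H * W Gaussians in a single forward pass.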

        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                with gr.Row():
                    input_image = gr.Image(
                        label="Input Image",
                        image_mode="RGBA",
                        sources="upload",
                        type="pil",
                        elem_id="content_image",
                    )
                with gr.Row():
                    submit = gr.Button("Generate", elem_id="generate", variant="primary")
                with gr.Row(variant="panel"):
                    gr.Examples(
                        examples=[
                            './demo_examples/bedroom_01.png',
                            './demo_examples/kitti_02.png',
                            './demo_examples/kitti_03.png',
                            './demo_examples/re10k_04.jpg',
                            './demo_examples/re10k_05.jpg',
                            './demo_examples/re10k_06.jpg',
                        ],
                        inputs=[input_image],
                        cache_examples=False,
                        label="Examples",
                        examples_per_page=20,
                    )
                with gr.Row():
                    processed_image = gr.Image(label="Processed Image", interactive=False)

            with gr.Column(scale=2):
                with gr.Row():
                    with gr.Tab("Reconstruction"):
                        output_model = gr.Model3D(
                            height=512,
                            label="Output Model",
                            interactive=False,
                        )

        # Chain the callbacks: validate the upload, then preprocess, then
        # reconstruct; each later step runs only if the previous one succeeds.
        submit.click(fn=check_input_image, inputs=[input_image]).success(
            fn=preprocess,
            inputs=[input_image],
            outputs=[processed_image],
        ).success(
            fn=reconstruct_and_export,
            inputs=[processed_image],
            outputs=[output_model],
        )

    demo.queue(max_size=1)
    print("[INFO] Launching Gradio demo...")
    demo.launch(share=True)


if __name__ == "__main__":
    print("[INFO] Running application...")
    main()