Unique3D / gradio_app /gradio_3dgen.py
Fabrice-TIERCELIN's picture
This PR orders the components more naturally
ae7a0f2 verified
raw
history blame
4.15 kB
import spaces
import os
import gradio as gr
from PIL import Image
from pytorch3d.structures import Meshes
from gradio_app.utils import clean_up
from gradio_app.custom_models.mvimg_prediction import run_mvprediction
from gradio_app.custom_models.normal_prediction import predict_normals
from scripts.refine_lr_to_sr import run_sr_fast
from scripts.utils import save_obj_and_video
from scripts.multiview_inference import geo_reconstruct_part1, geo_reconstruct_part2, geo_reconstruct_part3
@spaces.GPU(duration=200)
def run_mv(preview_img, input_processing, seed):
    """Produce the multi-view images for one front-view picture.

    When the first entry of ``preview_img.size`` is at most 512, the image is
    first passed through ``run_sr_fast``; the (possibly replaced) image is
    then handed to ``run_mvprediction``.

    Args:
        preview_img: Input image exposing a ``size`` attribute
            (presumably a PIL image -- confirm against callers).
        input_processing: Forwarded to ``run_mvprediction`` as ``remove_bg``.
        seed: Seed value; converted with ``int`` before being forwarded.

    Returns:
        The ``(rgb_pils, front_pil)`` pair unpacked from ``run_mvprediction``.
    """
    needs_upscale = preview_img.size[0] <= 512
    source_img = run_sr_fast([preview_img])[0] if needs_upscale else preview_img

    # The original inline note on this call reads "6s".
    multiview_pils, front_view = run_mvprediction(
        source_img,
        remove_bg=input_processing,
        seed=int(seed),
    )
    return multiview_pils, front_view
# Splitting this pipeline into multiple parts seemed to lead to a
# `RuntimeError`; until that is fixed, the GPU is still initialized here.
@spaces.GPU(duration=200)
def generate3dv2(preview_img, input_processing, seed, render_video=True, do_refine=True, expansion_weight=0.1, init_type="std"):
    """Run the image-to-3D pipeline and export the resulting mesh.

    The steps are: multi-view prediction via ``run_mv``, the three
    ``geo_reconstruct_part*`` stages, a rescale/flip of the packed vertices,
    and finally ``save_obj_and_video``.

    Args:
        preview_img: Input image, or a string that is opened with
            ``Image.open``.
        input_processing: Forwarded to ``run_mv``.
        seed: Forwarded to ``run_mv``.
        render_video: Forwarded to ``save_obj_and_video`` as ``export_video``.
        do_refine: Forwarded to ``geo_reconstruct_part1``.
        expansion_weight: Forwarded to ``geo_reconstruct_part1``.
        init_type: Forwarded to ``geo_reconstruct_part1``.

    Returns:
        The ``(ret_mesh, video)`` pair unpacked from ``save_obj_and_video``.

    Raises:
        gr.Error: If ``preview_img`` is ``None``.
    """
    if preview_img is None:
        raise gr.Error("The input image is none!")

    # A string input is treated as a path and opened as an image.
    front_input = Image.open(preview_img) if isinstance(preview_img, str) else preview_img

    multiview_pils, front_view = run_mv(front_input, input_processing, seed)

    coarse_verts, coarse_faces, stage_imgs = geo_reconstruct_part1(
        multiview_pils,
        None,
        front_view,
        do_refine=do_refine,
        predict_normal=True,
        expansion_weight=expansion_weight,
        init_type=init_type,
    )
    stage2_meshes = geo_reconstruct_part2(coarse_verts, coarse_faces)
    final_meshes = geo_reconstruct_part3(stage2_meshes, stage_imgs)

    # Rescale the packed vertices, then negate components 0 and 2 of the last
    # axis (presumably x and z -- confirm) before rebuilding the mesh.
    export_verts = final_meshes.verts_packed() / 2 * 1.35
    export_verts[..., [0, 2]] = -export_verts[..., [0, 2]]
    export_meshes = Meshes(
        verts=[export_verts],
        faces=final_meshes.faces_list(),
        textures=final_meshes.textures,
    )

    ret_mesh, video = save_obj_and_video(
        "/tmp/gradio/generated",
        export_meshes,
        with_timestamp=True,
        dist=3.5,
        fov_in_degrees=2 / 1.35,
        cam_type="ortho",
        export_video=render_video,
    )
    return ret_mesh, video
#######################################
# Build the 3D-generation UI: an input image with its option widgets and a
# "Generate 3D" button, the mesh/video outputs, an examples gallery, and the
# click handler that calls `generate3dv2`.
# `concurrency_id` is forwarded unchanged to the button's click event.
# Returns the input `gr.Image` component.
# NOTE(review): leading indentation is missing in this copy of the file, so the
# nesting of the `with` blocks (in particular where the examples section
# belongs) cannot be determined from here -- confirm against the repository
# before re-indenting.
def create_ui(concurrency_id="wkl"):
with gr.Row():
# Inputs: the front-view image plus the generation options.
with gr.Column(scale=1):
input_image = gr.Image(type='pil', image_mode='RGBA', label='Frontview')
input_processing = gr.Checkbox(
value=True,
label='Remove Background',
visible=True,
)
# The next three controls are created with visible=False; they are still
# listed as click inputs below, so their values are passed to `generate3dv2`.
do_refine = gr.Checkbox(value=True, label="Refine Multiview Details", visible=False)
expansion_weight = gr.Slider(minimum=-1., maximum=1.0, value=0.1, step=0.1, label="Expansion Weight", visible=False)
init_type = gr.Dropdown(choices=["std", "thin"], label="Mesh Initialization", value="std", visible=False)
# Seed slider: the positional arguments are -1, 1000000000, -1
# (presumably minimum, maximum and initial value -- confirm against Gradio).
setable_seed = gr.Slider(-1, 1000000000, -1, step=1, visible=True, label="Seed")
fullrunv2_btn = gr.Button('Generate 3D', interactive=True, variant="primary")
# Outputs: the generated mesh and the (hidden) preview video.
with gr.Column(scale=1):
# Exported mesh display.
output_mesh = gr.Model3D(value=None, label="Mesh Model", show_label=True, height=320)
output_video = gr.Video(label="Preview", show_label=True, show_share_button=True, height=320, visible=False)
# Hidden checkbox defaulting to False; its value is passed as `render_video`.
render_video = gr.Checkbox(value=False, visible=False, label="generate video")
# Examples: every entry of the "examples" directory next to this file, sorted
# by path, offered as values for `input_image`.
example_folder = os.path.join(os.path.dirname(__file__), "./examples")
example_fns = sorted([os.path.join(example_folder, example) for example in os.listdir(example_folder)])
gr.Examples(
examples=example_fns,
inputs=[input_image],
cache_examples=False,
label='Examples',
examples_per_page=12
)
# Wire the button: the inputs list follows the parameter order of
# `generate3dv2`; `clean_up` is chained through `.success(...)`.
fullrunv2_btn.click(
fn = generate3dv2,
inputs=[input_image, input_processing, setable_seed, render_video, do_refine, expansion_weight, init_type],
outputs=[output_mesh, output_video],
concurrency_id=concurrency_id,
api_name="generate3dv2",
).success(clean_up, api_name=False)
return input_image