Spaces:
Sleeping
Sleeping
import os
import tempfile

import gradio as gr
import pytorch_lightning as pl
import torch as th
import open3d as o3d
import numpy as np
import trimesh as tm

from models.model import Model
# Build the network and restore the trained weights from the bundled
# PyTorch Lightning checkpoint (its "state_dict" entry holds the weights).
# The checkpoint path is resolved relative to this file rather than the
# process's current working directory, so launching from elsewhere still works.
model = Model()
ckpg = th.load(
    os.path.join(os.path.dirname(__file__), "checkpoints", "epoch=99-step=6000.ckpt"),
    map_location=th.device("cpu"),  # weights are loaded onto the CPU
)
model.load_state_dict(ckpg["state_dict"])
# Put the module in inference mode. th.no_grad() at call time only disables
# autograd; without eval(), any dropout / batch-norm layers the model may
# contain would keep their training-time behaviour during prediction.
model.eval()
def process_mesh(mesh_file_name):
    """Run the model on an uploaded mesh and return the path of the result.

    Args:
        mesh_file_name: Path of the mesh file supplied by the ``gr.Model3D``
            input component; may be ``None`` if the user submits without
            uploading anything.

    Returns:
        Path to a newly created ``.obj`` file containing the mesh built from
        the model's output vertices, faces and vertex normals.

    Raises:
        gr.Error: If no mesh file was provided, so the UI shows a readable
            message instead of a trimesh traceback.
    """
    if not mesh_file_name:
        raise gr.Error("Please upload a mesh file first.")

    mesh = tm.load_mesh(mesh_file_name)
    v = th.tensor(mesh.vertices, dtype=th.float)
    n = th.tensor(mesh.vertex_normals, dtype=th.float)

    # Normalise with one global (not per-axis) scale: shift so the smallest
    # coordinate over all axes is 0, scale so the largest is 1, then map into
    # [0.08, 0.08 + 1/1.2] (about [0.08, 0.913]). Presumably this mirrors the
    # normalisation used for the training data -- TODO confirm before changing.
    v -= v.min()
    v /= v.max()
    v /= 1.2
    v += 0.08

    with th.no_grad():
        # unsqueeze(0) adds a batch dimension of size 1. The model is assumed
        # to return (vertices, faces, normals, <unused>) batched the same way,
        # hence the squeeze(0) below -- confirm against models.model.Model.
        v, f, n, _ = model(v.unsqueeze(0), n.unsqueeze(0))

    out_mesh = tm.Trimesh(
        vertices=v.squeeze(0),
        faces=f.squeeze(0),
        vertex_normals=n.squeeze(0),
    )

    # Write to a unique temporary file instead of a fixed "./sample.obj":
    # overlapping requests can no longer overwrite each other's result, and
    # nothing is written into the app's working directory. The file is left
    # in place so Gradio can serve it.
    fd, obj_path = tempfile.mkstemp(suffix=".obj")
    os.close(fd)
    out_mesh.export(obj_path)
    return obj_path
# Gradio UI: a 3D-model upload goes in, and the mesh file returned by
# process_mesh comes out, shown in a Model3D viewer whose clear colour is the
# all-zero RGBA value. Example meshes are resolved next to this file.
demo = gr.Interface(
    fn=process_mesh,
    inputs=gr.Model3D(),
    outputs=gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model"),
    examples=[
        [os.path.join(os.path.dirname(__file__), example_name)]
        for example_name in (
            "files/bunny_n1_hi_50.obj",
            "files/child_n2_80.obj",
            "files/eight_n3_70.obj",
        )
    ],
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    demo.launch()