Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -1,11 +1,12 @@
|
|
1 |
import gradio as gr
|
2 |
import os
|
3 |
-
import os
|
4 |
os.environ['SPCONV_ALGO'] = 'native'
|
5 |
import spaces
|
|
|
6 |
import warp as wp
|
7 |
import subprocess
|
8 |
import torch
|
|
|
9 |
from threading import Thread
|
10 |
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor,TextIteratorStreamer,AutoTokenizer
|
11 |
from qwen_vl_utils import process_vision_info
|
@@ -24,6 +25,8 @@ import open3d as o3d
|
|
24 |
from huggingface_hub import hf_hub_download
|
25 |
import numpy as np
|
26 |
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
|
|
|
|
27 |
|
28 |
def _remove_image_special(text):
|
29 |
text = text.replace('<ref>', '').replace('</ref>', '')
|
@@ -66,7 +69,7 @@ def save_ply_from_array(verts):
|
|
66 |
return tmpf.name
|
67 |
|
68 |
@spaces.GPU(duration=120)
|
69 |
-
def predict(_chatbot,task_history,viewer_voxel,viewer_mesh,task_new,seed,top_k,top_p,temperature):
|
70 |
torch.manual_seed(seed)
|
71 |
chat_query = _chatbot[-1][0]
|
72 |
query = task_history[-1][0]
|
@@ -74,7 +77,7 @@ def predict(_chatbot,task_history,viewer_voxel,viewer_mesh,task_new,seed,top_k,t
|
|
74 |
if len(chat_query) == 0:
|
75 |
_chatbot.pop()
|
76 |
task_history.pop()
|
77 |
-
return _chatbot,task_history,viewer_voxel,viewer_mesh,task_new
|
78 |
print("User: " + _parse_text(query))
|
79 |
history_cp = copy.deepcopy(task_history)
|
80 |
full_response = ""
|
@@ -127,10 +130,10 @@ def predict(_chatbot,task_history,viewer_voxel,viewer_mesh,task_new,seed,top_k,t
|
|
127 |
new_text = f"mesh-start\n{new_text}\nmesh-end"
|
128 |
full_response += new_text
|
129 |
_chatbot[-1] = (_parse_text(chat_query), _parse_text(full_response))
|
130 |
-
yield _chatbot,viewer_voxel,viewer_mesh,task_new
|
131 |
|
132 |
task_history[-1] = (chat_query, full_response)
|
133 |
-
yield _chatbot,viewer_voxel,viewer_mesh,task_new
|
134 |
|
135 |
if encoding_indices is not None:
|
136 |
print("processing mesh...")
|
@@ -140,7 +143,7 @@ def predict(_chatbot,task_history,viewer_voxel,viewer_mesh,task_new,seed,top_k,t
|
|
140 |
indices = torch.nonzero(z_s[0] == 1)
|
141 |
position_recon= (indices.float() + 0.5) / 64 - 0.5
|
142 |
fig = make_pointcloud_figure(position_recon)
|
143 |
-
yield _chatbot,fig,viewer_mesh,task_new
|
144 |
|
145 |
position=position_recon
|
146 |
coords = ((position + 0.5) * 64).int().contiguous()
|
@@ -158,6 +161,14 @@ def predict(_chatbot,task_history,viewer_voxel,viewer_mesh,task_new,seed,top_k,t
|
|
158 |
cond = pipeline_text.get_cond([prompt])
|
159 |
slat = pipeline_text.sample_slat(cond, coords)
|
160 |
outputs = pipeline_text.decode_slat(slat, ['mesh', 'gaussian'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
161 |
|
162 |
glb = postprocessing_utils.to_glb(
|
163 |
outputs['gaussian'][0],
|
@@ -168,7 +179,7 @@ def predict(_chatbot,task_history,viewer_voxel,viewer_mesh,task_new,seed,top_k,t
|
|
168 |
)
|
169 |
glb.export(f"temper.glb")
|
170 |
print("processing mesh over...")
|
171 |
-
yield _chatbot,fig,"temper.glb",task_new
|
172 |
else:
|
173 |
# image to 3d
|
174 |
with torch.no_grad():
|
@@ -176,6 +187,15 @@ def predict(_chatbot,task_history,viewer_voxel,viewer_mesh,task_new,seed,top_k,t
|
|
176 |
cond = pipeline_image.get_cond([img])
|
177 |
slat = pipeline_image.sample_slat(cond, coords)
|
178 |
outputs = pipeline_image.decode_slat(slat, ['mesh', 'gaussian'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
glb = postprocessing_utils.to_glb(
|
180 |
outputs['gaussian'][0],
|
181 |
outputs['mesh'][0],
|
@@ -185,10 +205,10 @@ def predict(_chatbot,task_history,viewer_voxel,viewer_mesh,task_new,seed,top_k,t
|
|
185 |
)
|
186 |
glb.export(f"temper.glb")
|
187 |
print("processing mesh over...")
|
188 |
-
yield _chatbot,fig,"temper.glb",task_new
|
189 |
except:
|
190 |
print("processing mesh...bug")
|
191 |
-
yield _chatbot,fig,viewer_mesh,task_new
|
192 |
|
193 |
def regenerate(_chatbot, task_history):
|
194 |
if not task_history:
|
@@ -459,7 +479,9 @@ with gr.Blocks() as demo:
|
|
459 |
task_new = gr.State([])
|
460 |
with gr.Column():
|
461 |
viewer_plot = gr.Plot(label="Voxel Visual",scale=0.5)
|
462 |
-
viewer_mesh = gr.Model3D(label="Mesh Visual", height=200,scale=1.0)
|
|
|
|
|
463 |
|
464 |
examples_text = gr.Examples(
|
465 |
examples=[
|
@@ -497,8 +519,8 @@ with gr.Blocks() as demo:
|
|
497 |
|
498 |
submit_btn.click(add_text, [chatbot, task_history, query,task_new],\
|
499 |
[chatbot, task_history,task_new]).then(
|
500 |
-
predict, [chatbot, task_history,viewer_plot,viewer_mesh,task_new,seed,top_k,top_p,temperature],\
|
501 |
-
[chatbot,viewer_plot,viewer_mesh,task_new], show_progress=True
|
502 |
)
|
503 |
submit_btn.click(reset_user_input, [], [query])
|
504 |
empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
|
|
|
1 |
import gradio as gr
|
2 |
import os
|
|
|
3 |
os.environ['SPCONV_ALGO'] = 'native'
|
4 |
import spaces
|
5 |
+
from gradio_litmodel3d import LitModel3D
|
6 |
import warp as wp
|
7 |
import subprocess
|
8 |
import torch
|
9 |
+
import uuid
|
10 |
from threading import Thread
|
11 |
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor,TextIteratorStreamer,AutoTokenizer
|
12 |
from qwen_vl_utils import process_vision_info
|
|
|
25 |
from huggingface_hub import hf_hub_download
|
26 |
import numpy as np
|
27 |
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
28 |
+
TMP_DIR = "/tmp/Trellis-demo"
|
29 |
+
os.makedirs(TMP_DIR, exist_ok=True)
|
30 |
|
31 |
def _remove_image_special(text):
|
32 |
text = text.replace('<ref>', '').replace('</ref>', '')
|
|
|
69 |
return tmpf.name
|
70 |
|
71 |
@spaces.GPU(duration=120)
|
72 |
+
def predict(_chatbot,task_history,viewer_voxel,viewer_mesh,task_new,seed,top_k,top_p,temperature,video_path):
|
73 |
torch.manual_seed(seed)
|
74 |
chat_query = _chatbot[-1][0]
|
75 |
query = task_history[-1][0]
|
|
|
77 |
if len(chat_query) == 0:
|
78 |
_chatbot.pop()
|
79 |
task_history.pop()
|
80 |
+
return _chatbot,task_history,viewer_voxel,viewer_mesh,task_new,video_path
|
81 |
print("User: " + _parse_text(query))
|
82 |
history_cp = copy.deepcopy(task_history)
|
83 |
full_response = ""
|
|
|
130 |
new_text = f"mesh-start\n{new_text}\nmesh-end"
|
131 |
full_response += new_text
|
132 |
_chatbot[-1] = (_parse_text(chat_query), _parse_text(full_response))
|
133 |
+
yield _chatbot,viewer_voxel,viewer_mesh,task_new,video_path
|
134 |
|
135 |
task_history[-1] = (chat_query, full_response)
|
136 |
+
yield _chatbot,viewer_voxel,viewer_mesh,task_new,video_path
|
137 |
|
138 |
if encoding_indices is not None:
|
139 |
print("processing mesh...")
|
|
|
143 |
indices = torch.nonzero(z_s[0] == 1)
|
144 |
position_recon= (indices.float() + 0.5) / 64 - 0.5
|
145 |
fig = make_pointcloud_figure(position_recon)
|
146 |
+
yield _chatbot,fig,viewer_mesh,task_new,video_path
|
147 |
|
148 |
position=position_recon
|
149 |
coords = ((position + 0.5) * 64).int().contiguous()
|
|
|
161 |
cond = pipeline_text.get_cond([prompt])
|
162 |
slat = pipeline_text.sample_slat(cond, coords)
|
163 |
outputs = pipeline_text.decode_slat(slat, ['mesh', 'gaussian'])
|
164 |
+
|
165 |
+
video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
|
166 |
+
video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
|
167 |
+
video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
|
168 |
+
trial_id = uuid.uuid4()
|
169 |
+
video_path = f"{TMP_DIR}/{trial_id}.mp4"
|
170 |
+
os.makedirs(os.path.dirname(video_path), exist_ok=True)
|
171 |
+
imageio.mimsave(video_path, video, fps=15)
|
172 |
|
173 |
glb = postprocessing_utils.to_glb(
|
174 |
outputs['gaussian'][0],
|
|
|
179 |
)
|
180 |
glb.export(f"temper.glb")
|
181 |
print("processing mesh over...")
|
182 |
+
yield _chatbot,fig,"temper.glb",task_new,video_path
|
183 |
else:
|
184 |
# image to 3d
|
185 |
with torch.no_grad():
|
|
|
187 |
cond = pipeline_image.get_cond([img])
|
188 |
slat = pipeline_image.sample_slat(cond, coords)
|
189 |
outputs = pipeline_image.decode_slat(slat, ['mesh', 'gaussian'])
|
190 |
+
|
191 |
+
video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
|
192 |
+
video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
|
193 |
+
video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
|
194 |
+
trial_id = uuid.uuid4()
|
195 |
+
video_path = f"{TMP_DIR}/{trial_id}.mp4"
|
196 |
+
os.makedirs(os.path.dirname(video_path), exist_ok=True)
|
197 |
+
imageio.mimsave(video_path, video, fps=15)
|
198 |
+
|
199 |
glb = postprocessing_utils.to_glb(
|
200 |
outputs['gaussian'][0],
|
201 |
outputs['mesh'][0],
|
|
|
205 |
)
|
206 |
glb.export(f"temper.glb")
|
207 |
print("processing mesh over...")
|
208 |
+
yield _chatbot,fig,"temper.glb",task_new,video_path,video_path
|
209 |
except:
|
210 |
print("processing mesh...bug")
|
211 |
+
yield _chatbot,fig,viewer_mesh,task_new,video_path
|
212 |
|
213 |
def regenerate(_chatbot, task_history):
|
214 |
if not task_history:
|
|
|
479 |
task_new = gr.State([])
|
480 |
with gr.Column():
|
481 |
viewer_plot = gr.Plot(label="Voxel Visual",scale=0.5)
|
482 |
+
#viewer_mesh = gr.Model3D(label="Mesh Visual", height=200,scale=1.0)
|
483 |
+
video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
|
484 |
+
viewer_mesh = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
|
485 |
|
486 |
examples_text = gr.Examples(
|
487 |
examples=[
|
|
|
519 |
|
520 |
submit_btn.click(add_text, [chatbot, task_history, query,task_new],\
|
521 |
[chatbot, task_history,task_new]).then(
|
522 |
+
predict, [chatbot, task_history,viewer_plot,viewer_mesh,task_new,seed,top_k,top_p,temperature,video_output],\
|
523 |
+
[chatbot,viewer_plot,viewer_mesh,task_new,video_output], show_progress=True
|
524 |
)
|
525 |
submit_btn.click(reset_user_input, [], [query])
|
526 |
empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
|