yejunliang23 committed on
Commit
b57e1c8
·
verified ·
1 Parent(s): 9a346a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -12
app.py CHANGED
@@ -1,11 +1,12 @@
1
  import gradio as gr
2
  import os
3
- import os
4
  os.environ['SPCONV_ALGO'] = 'native'
5
  import spaces
 
6
  import warp as wp
7
  import subprocess
8
  import torch
 
9
  from threading import Thread
10
  from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor,TextIteratorStreamer,AutoTokenizer
11
  from qwen_vl_utils import process_vision_info
@@ -24,6 +25,8 @@ import open3d as o3d
24
  from huggingface_hub import hf_hub_download
25
  import numpy as np
26
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
 
27
 
28
  def _remove_image_special(text):
29
  text = text.replace('<ref>', '').replace('</ref>', '')
@@ -66,7 +69,7 @@ def save_ply_from_array(verts):
66
  return tmpf.name
67
 
68
  @spaces.GPU(duration=120)
69
- def predict(_chatbot,task_history,viewer_voxel,viewer_mesh,task_new,seed,top_k,top_p,temperature):
70
  torch.manual_seed(seed)
71
  chat_query = _chatbot[-1][0]
72
  query = task_history[-1][0]
@@ -74,7 +77,7 @@ def predict(_chatbot,task_history,viewer_voxel,viewer_mesh,task_new,seed,top_k,t
74
  if len(chat_query) == 0:
75
  _chatbot.pop()
76
  task_history.pop()
77
- return _chatbot,task_history,viewer_voxel,viewer_mesh,task_new
78
  print("User: " + _parse_text(query))
79
  history_cp = copy.deepcopy(task_history)
80
  full_response = ""
@@ -127,10 +130,10 @@ def predict(_chatbot,task_history,viewer_voxel,viewer_mesh,task_new,seed,top_k,t
127
  new_text = f"mesh-start\n{new_text}\nmesh-end"
128
  full_response += new_text
129
  _chatbot[-1] = (_parse_text(chat_query), _parse_text(full_response))
130
- yield _chatbot,viewer_voxel,viewer_mesh,task_new
131
 
132
  task_history[-1] = (chat_query, full_response)
133
- yield _chatbot,viewer_voxel,viewer_mesh,task_new
134
 
135
  if encoding_indices is not None:
136
  print("processing mesh...")
@@ -140,7 +143,7 @@ def predict(_chatbot,task_history,viewer_voxel,viewer_mesh,task_new,seed,top_k,t
140
  indices = torch.nonzero(z_s[0] == 1)
141
  position_recon= (indices.float() + 0.5) / 64 - 0.5
142
  fig = make_pointcloud_figure(position_recon)
143
- yield _chatbot,fig,viewer_mesh,task_new
144
 
145
  position=position_recon
146
  coords = ((position + 0.5) * 64).int().contiguous()
@@ -158,6 +161,14 @@ def predict(_chatbot,task_history,viewer_voxel,viewer_mesh,task_new,seed,top_k,t
158
  cond = pipeline_text.get_cond([prompt])
159
  slat = pipeline_text.sample_slat(cond, coords)
160
  outputs = pipeline_text.decode_slat(slat, ['mesh', 'gaussian'])
 
 
 
 
 
 
 
 
161
 
162
  glb = postprocessing_utils.to_glb(
163
  outputs['gaussian'][0],
@@ -168,7 +179,7 @@ def predict(_chatbot,task_history,viewer_voxel,viewer_mesh,task_new,seed,top_k,t
168
  )
169
  glb.export(f"temper.glb")
170
  print("processing mesh over...")
171
- yield _chatbot,fig,"temper.glb",task_new
172
  else:
173
  # image to 3d
174
  with torch.no_grad():
@@ -176,6 +187,15 @@ def predict(_chatbot,task_history,viewer_voxel,viewer_mesh,task_new,seed,top_k,t
176
  cond = pipeline_image.get_cond([img])
177
  slat = pipeline_image.sample_slat(cond, coords)
178
  outputs = pipeline_image.decode_slat(slat, ['mesh', 'gaussian'])
 
 
 
 
 
 
 
 
 
179
  glb = postprocessing_utils.to_glb(
180
  outputs['gaussian'][0],
181
  outputs['mesh'][0],
@@ -185,10 +205,10 @@ def predict(_chatbot,task_history,viewer_voxel,viewer_mesh,task_new,seed,top_k,t
185
  )
186
  glb.export(f"temper.glb")
187
  print("processing mesh over...")
188
- yield _chatbot,fig,"temper.glb",task_new
189
  except:
190
  print("processing mesh...bug")
191
- yield _chatbot,fig,viewer_mesh,task_new
192
 
193
  def regenerate(_chatbot, task_history):
194
  if not task_history:
@@ -459,7 +479,9 @@ with gr.Blocks() as demo:
459
  task_new = gr.State([])
460
  with gr.Column():
461
  viewer_plot = gr.Plot(label="Voxel Visual",scale=0.5)
462
- viewer_mesh = gr.Model3D(label="Mesh Visual", height=200,scale=1.0)
 
 
463
 
464
  examples_text = gr.Examples(
465
  examples=[
@@ -497,8 +519,8 @@ with gr.Blocks() as demo:
497
 
498
  submit_btn.click(add_text, [chatbot, task_history, query,task_new],\
499
  [chatbot, task_history,task_new]).then(
500
- predict, [chatbot, task_history,viewer_plot,viewer_mesh,task_new,seed,top_k,top_p,temperature],\
501
- [chatbot,viewer_plot,viewer_mesh,task_new], show_progress=True
502
  )
503
  submit_btn.click(reset_user_input, [], [query])
504
  empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
 
1
  import gradio as gr
2
  import os
 
3
  os.environ['SPCONV_ALGO'] = 'native'
4
  import spaces
5
+ from gradio_litmodel3d import LitModel3D
6
  import warp as wp
7
  import subprocess
8
  import torch
9
+ import uuid
10
  from threading import Thread
11
  from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor,TextIteratorStreamer,AutoTokenizer
12
  from qwen_vl_utils import process_vision_info
 
25
  from huggingface_hub import hf_hub_download
26
  import numpy as np
27
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
28
+ TMP_DIR = "/tmp/Trellis-demo"
29
+ os.makedirs(TMP_DIR, exist_ok=True)
30
 
31
  def _remove_image_special(text):
32
  text = text.replace('<ref>', '').replace('</ref>', '')
 
69
  return tmpf.name
70
 
71
  @spaces.GPU(duration=120)
72
+ def predict(_chatbot,task_history,viewer_voxel,viewer_mesh,task_new,seed,top_k,top_p,temperature,video_path):
73
  torch.manual_seed(seed)
74
  chat_query = _chatbot[-1][0]
75
  query = task_history[-1][0]
 
77
  if len(chat_query) == 0:
78
  _chatbot.pop()
79
  task_history.pop()
80
+ return _chatbot,task_history,viewer_voxel,viewer_mesh,task_new,video_path
81
  print("User: " + _parse_text(query))
82
  history_cp = copy.deepcopy(task_history)
83
  full_response = ""
 
130
  new_text = f"mesh-start\n{new_text}\nmesh-end"
131
  full_response += new_text
132
  _chatbot[-1] = (_parse_text(chat_query), _parse_text(full_response))
133
+ yield _chatbot,viewer_voxel,viewer_mesh,task_new,video_path
134
 
135
  task_history[-1] = (chat_query, full_response)
136
+ yield _chatbot,viewer_voxel,viewer_mesh,task_new,video_path
137
 
138
  if encoding_indices is not None:
139
  print("processing mesh...")
 
143
  indices = torch.nonzero(z_s[0] == 1)
144
  position_recon= (indices.float() + 0.5) / 64 - 0.5
145
  fig = make_pointcloud_figure(position_recon)
146
+ yield _chatbot,fig,viewer_mesh,task_new,video_path
147
 
148
  position=position_recon
149
  coords = ((position + 0.5) * 64).int().contiguous()
 
161
  cond = pipeline_text.get_cond([prompt])
162
  slat = pipeline_text.sample_slat(cond, coords)
163
  outputs = pipeline_text.decode_slat(slat, ['mesh', 'gaussian'])
164
+
165
+ video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
166
+ video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
167
+ video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
168
+ trial_id = uuid.uuid4()
169
+ video_path = f"{TMP_DIR}/{trial_id}.mp4"
170
+ os.makedirs(os.path.dirname(video_path), exist_ok=True)
171
+ imageio.mimsave(video_path, video, fps=15)
172
 
173
  glb = postprocessing_utils.to_glb(
174
  outputs['gaussian'][0],
 
179
  )
180
  glb.export(f"temper.glb")
181
  print("processing mesh over...")
182
+ yield _chatbot,fig,"temper.glb",task_new,video_path
183
  else:
184
  # image to 3d
185
  with torch.no_grad():
 
187
  cond = pipeline_image.get_cond([img])
188
  slat = pipeline_image.sample_slat(cond, coords)
189
  outputs = pipeline_image.decode_slat(slat, ['mesh', 'gaussian'])
190
+
191
+ video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
192
+ video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
193
+ video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
194
+ trial_id = uuid.uuid4()
195
+ video_path = f"{TMP_DIR}/{trial_id}.mp4"
196
+ os.makedirs(os.path.dirname(video_path), exist_ok=True)
197
+ imageio.mimsave(video_path, video, fps=15)
198
+
199
  glb = postprocessing_utils.to_glb(
200
  outputs['gaussian'][0],
201
  outputs['mesh'][0],
 
205
  )
206
  glb.export(f"temper.glb")
207
  print("processing mesh over...")
208
+ yield _chatbot,fig,"temper.glb",task_new,video_path,video_path
209
  except:
210
  print("processing mesh...bug")
211
+ yield _chatbot,fig,viewer_mesh,task_new,video_path
212
 
213
  def regenerate(_chatbot, task_history):
214
  if not task_history:
 
479
  task_new = gr.State([])
480
  with gr.Column():
481
  viewer_plot = gr.Plot(label="Voxel Visual",scale=0.5)
482
+ #viewer_mesh = gr.Model3D(label="Mesh Visual", height=200,scale=1.0)
483
+ video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
484
+ viewer_mesh = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
485
 
486
  examples_text = gr.Examples(
487
  examples=[
 
519
 
520
  submit_btn.click(add_text, [chatbot, task_history, query,task_new],\
521
  [chatbot, task_history,task_new]).then(
522
+ predict, [chatbot, task_history,viewer_plot,viewer_mesh,task_new,seed,top_k,top_p,temperature,video_output],\
523
+ [chatbot,viewer_plot,viewer_mesh,task_new,video_output], show_progress=True
524
  )
525
  submit_btn.click(reset_user_input, [], [query])
526
  empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)