yejunliang23 commited on
Commit
55a202e
·
verified ·
1 Parent(s): 625315a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -5
app.py CHANGED
@@ -15,7 +15,8 @@ import trimesh
15
  from trimesh.exchange.gltf import export_glb
16
  import tempfile
17
  import copy
18
- import plotly.graph_objs as go
 
19
  from PIL import Image
20
  import plotly.express as px
21
  import random
@@ -283,12 +284,66 @@ def reset_state(task_history):
283
  task_history.clear()
284
  return []
285
 
286
- def make_pointcloud_figure(verts,rotate=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  if rotate:
288
  verts = verts.copy()
289
  verts[:, 0] *= -1.0
290
- print(verts.shape,verts.min(),verts.max())
291
- print(verts)
292
  N = len(verts)
293
  soft_palette = ["#FFEBEE", "#FFF3E0", "#FFFDE7", "#E8F5E9",]
294
  palette = px.colors.qualitative.Set3
@@ -314,7 +369,7 @@ def make_pointcloud_figure(verts,rotate=False):
314
  line=dict(width=1)
315
  )
316
  )
317
- print(scatter)
318
  layout = go.Layout(
319
  width =800,
320
  height=300,
 
15
  from trimesh.exchange.gltf import export_glb
16
  import tempfile
17
  import copy
18
+ #import plotly.graph_objs as go
19
+ import plotly.graph_objects as go
20
  from PIL import Image
21
  import plotly.express as px
22
  import random
 
284
  task_history.clear()
285
  return []
286
 
287
+ def make_pointcloud_figure(verts, rotate=False):
288
+ """
289
+ Simple 3D scatter of point cloud that always shows points.
290
+
291
+ Parameters:
292
+ -------------
293
+ verts : (N, 3) numpy array or torch tensor
294
+ Point cloud coordinates.
295
+ rotate : bool
296
+ If True, reflect X axis.
297
+
298
+ Returns:
299
+ --------
300
+ fig : plotly.graph_objects.Figure
301
+ """
302
+ # Convert to numpy if torch tensor
303
+ try:
304
+ import torch
305
+ if isinstance(verts, torch.Tensor):
306
+ verts = verts.cpu().numpy()
307
+ except ImportError:
308
+ pass
309
+
310
+ if rotate:
311
+ verts = verts.copy()
312
+ verts[:, 0] *= -1.0
313
+
314
+ # Build scatter trace with a fixed color and larger size
315
+ scatter = go.Scatter3d(
316
+ x=verts[:, 0],
317
+ y=verts[:, 1],
318
+ z=verts[:, 2],
319
+ mode='markers',
320
+ marker=dict(
321
+ size=5,
322
+ color='blue',
323
+ opacity=0.8
324
+ )
325
+ )
326
+
327
+ # Default layout: auto ranging, hidden axes, minimal margins
328
+ layout = go.Layout(
329
+ scene=dict(
330
+ xaxis=dict(visible=False),
331
+ yaxis=dict(visible=False),
332
+ zaxis=dict(visible=False),
333
+ aspectmode='auto'
334
+ ),
335
+ margin=dict(l=0, r=0, b=0, t=0)
336
+ )
337
+
338
+ fig = go.Figure(data=[scatter], layout=layout)
339
+ return fig
340
+
341
+
342
+ def make_pointcloud_figure_old(verts,rotate=False):
343
  if rotate:
344
  verts = verts.copy()
345
  verts[:, 0] *= -1.0
346
+
 
347
  N = len(verts)
348
  soft_palette = ["#FFEBEE", "#FFF3E0", "#FFFDE7", "#E8F5E9",]
349
  palette = px.colors.qualitative.Set3
 
369
  line=dict(width=1)
370
  )
371
  )
372
+
373
  layout = go.Layout(
374
  width =800,
375
  height=300,