huathedev commited on
Commit
af43bf0
·
1 Parent(s): 2848e6d

Update pages/01_🦷 Segment.py

Browse files
Files changed (1) hide show
  1. pages/01_🦷 Segment.py +131 -157
pages/01_🦷 Segment.py CHANGED
@@ -1,27 +1,29 @@
1
- from streamlit import session_state as session
2
  import shutil
 
3
 
4
- import os
5
  import numpy as np
6
- from sklearn import neighbors
7
  from scipy.spatial import distance_matrix
 
8
  from pygco import cut_from_graph
9
  import open3d as o3d
10
  import matplotlib.pyplot as plt
11
  import matplotlib.colors as mcolors
12
- from stqdm import stqdm
13
- import json
14
- from stpyvista import stpyvista
15
  import torch
16
  import torch.nn as nn
17
  from torch.autograd import Variable
18
  import torch.nn.functional as F
19
  import streamlit as st
20
- import pyvista as pv
21
-
 
22
  from PIL import Image
23
 
 
24
  class TeethApp:
 
 
 
25
  def __init__(self):
26
  # Font
27
  with open("utils/style.css") as css:
@@ -48,44 +50,6 @@ class TeethApp:
48
  unsafe_allow_html=True,
49
  )
50
 
51
-
52
- class STN3d(nn.Module):
53
- def __init__(self, channel):
54
- super(STN3d, self).__init__()
55
- self.conv1 = torch.nn.Conv1d(channel, 64, 1)
56
- self.conv2 = torch.nn.Conv1d(64, 128, 1)
57
- self.conv3 = torch.nn.Conv1d(128, 1024, 1)
58
- self.fc1 = nn.Linear(1024, 512)
59
- self.fc2 = nn.Linear(512, 256)
60
- self.fc3 = nn.Linear(256, 9)
61
- self.relu = nn.ReLU()
62
-
63
- self.bn1 = nn.BatchNorm1d(64)
64
- self.bn2 = nn.BatchNorm1d(128)
65
- self.bn3 = nn.BatchNorm1d(1024)
66
- self.bn4 = nn.BatchNorm1d(512)
67
- self.bn5 = nn.BatchNorm1d(256)
68
-
69
- def forward(self, x):
70
- batchsize = x.size()[0]
71
- x = F.relu(self.bn1(self.conv1(x)))
72
- x = F.relu(self.bn2(self.conv2(x)))
73
- x = F.relu(self.bn3(self.conv3(x)))
74
- x = torch.max(x, 2, keepdim=True)[0]
75
- x = x.view(-1, 1024)
76
-
77
- x = F.relu(self.bn4(self.fc1(x)))
78
- x = F.relu(self.bn5(self.fc2(x)))
79
- x = self.fc3(x)
80
-
81
- iden = Variable(torch.from_numpy(np.array([1, 0, 0, 0, 1, 0, 0, 0, 1]).astype(np.float32))).view(1, 9).repeat(
82
- batchsize, 1)
83
- if x.is_cuda:
84
- iden = iden.to(x.get_device())
85
- x = x + iden
86
- x = x.view(-1, 3, 3)
87
- return x
88
-
89
  class STNkd(nn.Module):
90
  def __init__(self, k=64):
91
  super(STNkd, self).__init__()
@@ -133,13 +97,15 @@ class MeshSegNet(nn.Module):
133
  self.with_dropout = with_dropout
134
  self.dropout_p = dropout_p
135
 
136
- # MLP-1 [64, 64]
137
  self.mlp1_conv1 = torch.nn.Conv1d(self.num_channels, 64, 1)
138
  self.mlp1_conv2 = torch.nn.Conv1d(64, 64, 1)
139
  self.mlp1_bn1 = nn.BatchNorm1d(64)
140
  self.mlp1_bn2 = nn.BatchNorm1d(64)
 
141
  # FTM (feature-transformer module)
142
  self.fstn = STNkd(k=64)
 
143
  # GLM-1 (graph-contrained learning modulus)
144
  self.glm1_conv1_1 = torch.nn.Conv1d(64, 32, 1)
145
  self.glm1_conv1_2 = torch.nn.Conv1d(64, 32, 1)
@@ -147,6 +113,7 @@ class MeshSegNet(nn.Module):
147
  self.glm1_bn1_2 = nn.BatchNorm1d(32)
148
  self.glm1_conv2 = torch.nn.Conv1d(32+32, 64, 1)
149
  self.glm1_bn2 = nn.BatchNorm1d(64)
 
150
  # MLP-2
151
  self.mlp2_conv1 = torch.nn.Conv1d(64, 64, 1)
152
  self.mlp2_bn1 = nn.BatchNorm1d(64)
@@ -154,6 +121,7 @@ class MeshSegNet(nn.Module):
154
  self.mlp2_bn2 = nn.BatchNorm1d(128)
155
  self.mlp2_conv3 = torch.nn.Conv1d(128, 512, 1)
156
  self.mlp2_bn3 = nn.BatchNorm1d(512)
 
157
  # GLM-2 (graph-contrained learning modulus)
158
  self.glm2_conv1_1 = torch.nn.Conv1d(512, 128, 1)
159
  self.glm2_conv1_2 = torch.nn.Conv1d(512, 128, 1)
@@ -163,6 +131,7 @@ class MeshSegNet(nn.Module):
163
  self.glm2_bn1_3 = nn.BatchNorm1d(128)
164
  self.glm2_conv2 = torch.nn.Conv1d(128*3, 512, 1)
165
  self.glm2_bn2 = nn.BatchNorm1d(512)
 
166
  # MLP-3
167
  self.mlp3_conv1 = torch.nn.Conv1d(64+512+512+512, 256, 1)
168
  self.mlp3_conv2 = torch.nn.Conv1d(256, 256, 1)
@@ -172,7 +141,8 @@ class MeshSegNet(nn.Module):
172
  self.mlp3_conv4 = torch.nn.Conv1d(128, 128, 1)
173
  self.mlp3_bn2_1 = nn.BatchNorm1d(128)
174
  self.mlp3_bn2_2 = nn.BatchNorm1d(128)
175
- # output
 
176
  self.output_conv = torch.nn.Conv1d(128, self.num_classes, 1)
177
  if self.with_dropout:
178
  self.dropout = nn.Dropout(p=self.dropout_p)
@@ -180,13 +150,16 @@ class MeshSegNet(nn.Module):
180
  def forward(self, x, a_s, a_l):
181
  batchsize = x.size()[0]
182
  n_pts = x.size()[2]
 
183
  # MLP-1
184
  x = F.relu(self.mlp1_bn1(self.mlp1_conv1(x)))
185
  x = F.relu(self.mlp1_bn2(self.mlp1_conv2(x)))
 
186
  # FTM
187
  trans_feat = self.fstn(x)
188
  x = x.transpose(2, 1)
189
  x_ftm = torch.bmm(x, trans_feat)
 
190
  # GLM-1
191
  sap = torch.bmm(a_s, x_ftm)
192
  sap = sap.transpose(2, 1)
@@ -195,12 +168,14 @@ class MeshSegNet(nn.Module):
195
  glm_1_sap = F.relu(self.glm1_bn1_2(self.glm1_conv1_2(sap)))
196
  x = torch.cat([x, glm_1_sap], dim=1)
197
  x = F.relu(self.glm1_bn2(self.glm1_conv2(x)))
 
198
  # MLP-2
199
  x = F.relu(self.mlp2_bn1(self.mlp2_conv1(x)))
200
  x = F.relu(self.mlp2_bn2(self.mlp2_conv2(x)))
201
  x_mlp2 = F.relu(self.mlp2_bn3(self.mlp2_conv3(x)))
202
  if self.with_dropout:
203
  x_mlp2 = self.dropout(x_mlp2)
 
204
  # GLM-2
205
  x_mlp2 = x_mlp2.transpose(2, 1)
206
  sap_1 = torch.bmm(a_s, x_mlp2)
@@ -213,12 +188,16 @@ class MeshSegNet(nn.Module):
213
  glm_2_sap_2 = F.relu(self.glm2_bn1_3(self.glm2_conv1_3(sap_2)))
214
  x = torch.cat([x, glm_2_sap_1, glm_2_sap_2], dim=1)
215
  x_glm2 = F.relu(self.glm2_bn2(self.glm2_conv2(x)))
 
216
  # GMP
217
  x = torch.max(x_glm2, 2, keepdim=True)[0]
 
218
  # Upsample
219
  x = torch.nn.Upsample(n_pts)(x)
 
220
  # Dense fusion
221
  x = torch.cat([x, x_ftm, x_mlp2, x_glm2], dim=1)
 
222
  # MLP-3
223
  x = F.relu(self.mlp3_bn1_1(self.mlp3_conv1(x)))
224
  x = F.relu(self.mlp3_bn1_2(self.mlp3_conv2(x)))
@@ -226,6 +205,7 @@ class MeshSegNet(nn.Module):
226
  if self.with_dropout:
227
  x = self.dropout(x)
228
  x = F.relu(self.mlp3_bn2_2(self.mlp3_conv4(x)))
 
229
  # output
230
  x = self.output_conv(x)
231
  x = x.transpose(2,1).contiguous()
@@ -235,20 +215,25 @@ class MeshSegNet(nn.Module):
235
  return x
236
 
237
  def clone_runoob(li1):
 
 
 
238
  li_copy = li1[:]
 
239
  return li_copy
240
 
241
- # 对离群点重新进行分类
242
  def class_inlier_outlier(label_list, mean_points,cloud, ind, label_index, points, labels):
243
  label_change = clone_runoob(labels)
244
  outlier_index = clone_runoob(label_index)
245
  ind_reverse = clone_runoob(ind)
246
- # 得到离群点的label下标
 
247
  ind_reverse.reverse()
248
  for i in ind_reverse:
249
  outlier_index.pop(i)
250
 
251
- # 获取离群点
252
  inlier_cloud = cloud.select_by_index(ind)
253
  outlier_cloud = cloud.select_by_index(ind, invert=True)
254
  outlier_points = np.array(outlier_cloud.points)
@@ -256,29 +241,27 @@ def class_inlier_outlier(label_list, mean_points,cloud, ind, label_index, points
256
  for i in range(len(outlier_points)):
257
  distance = []
258
  for j in range(len(mean_points)):
259
- dis = np.linalg.norm(outlier_points[i] - mean_points[j], ord=2) # 计算toothGT质心之间的距离
260
  distance.append(dis)
261
- min_index = distance.index(min(distance)) # 获取和离群点质心最近label的index
262
- outlier_label = label_list[min_index] # 获取离群点应该的label
263
  index = outlier_index[i]
264
  label_change[index] = outlier_label
265
 
266
  return label_change
267
 
268
- # 利用knn算法消除离群点
269
  def remove_outlier(points, labels):
270
- # points = np.array(point_cloud_o3d_orign.points)
271
- # global label_list
272
  same_label_points = {}
273
 
274
  same_label_index = {}
275
 
276
- mean_points = [] # 所有label种类对应点云的质心坐标
277
 
278
  label_list = []
279
  for i in range(len(labels)):
280
  label_list.append(labels[i])
281
- label_list = list(set(label_list)) # 去重获从小到大排序取GT_label=[0, 11, 12, 13, 14, 15, 16, 17, 21, 22, 23, 24, 25, 26, 27]
282
  label_list.sort()
283
  label_list = label_list[1:]
284
 
@@ -289,7 +272,7 @@ def remove_outlier(points, labels):
289
  for j in range(len(labels)):
290
  if labels[j] == i:
291
  points_list.append(points[j].tolist())
292
- all_label_index.append(j) # 得到label i 的点对应的label的下标
293
  same_label_points[key] = points_list
294
  same_label_index[key] = all_label_index
295
 
@@ -299,106 +282,102 @@ def remove_outlier(points, labels):
299
 
300
  for i in label_list:
301
  points_array = same_label_points[i]
302
- # 建立一个o3d的点云对象
303
  pcd = o3d.geometry.PointCloud()
304
- # 使用Vector3dVector方法转换
305
  pcd.points = o3d.utility.Vector3dVector(points_array)
306
 
307
- # label i 对应的点云进行统计离群值去除,找出离群点并显示
308
- # 统计式离群点移除
309
  cl, ind = pcd.remove_statistical_outlier(nb_neighbors=200, std_ratio=2.0) # cl是选中的点,ind是选中点index
310
- # 可视化
311
- # display_inlier_outlier(pcd, ind)
312
 
313
- # 对分出来的离群点重新分类
314
  label_index = same_label_index[i]
315
  labels = class_inlier_outlier(label_list, mean_points, pcd, ind, label_index, points, labels)
316
  # print(f"label_change{labels[4400]}")
317
 
318
  return labels
319
 
320
-
321
- # 消除离群点,保存最后的输出
322
  def remove_outlier_main(jaw, pcd_points, labels, instances_labels):
323
- # point_cloud_o3d_orign = o3d.io.read_point_cloud('E:/tooth/data/MeshSegNet-master/test_upsample_15/upsample_01K17AN8_upper_refined.pcd')
324
- # 原始点
325
  points = pcd_points.copy()
326
  label = remove_outlier(points, labels)
327
 
328
- # 保存json文件
329
  label_dict = {}
330
  label_dict["id_patient"] = ""
331
  label_dict["jaw"] = jaw
332
  label_dict["labels"] = label.tolist()
333
  label_dict["instances"] = instances_labels.tolist()
 
334
  b = json.dumps(label_dict)
335
  with open('dental-labels4' + '.json', 'w') as f_obj:
336
  f_obj.write(b)
337
  f_obj.close()
338
 
339
-
340
  same_points_list = {}
341
 
342
-
343
- # 体素下采样
344
  def voxel_filter(point_cloud, leaf_size):
345
  same_points_list = {}
346
  filtered_points = []
347
- # step1 计算边界点
 
348
  x_max, y_max, z_max = np.amax(point_cloud, axis=0) # 计算 x,y,z三个维度的最值
349
  x_min, y_min, z_min = np.amin(point_cloud, axis=0)
350
 
351
- # step2 确定体素的尺寸
352
  size_r = leaf_size
353
 
354
- # step3 计算每个 volex的维度 voxel grid
355
  Dx = (x_max - x_min) // size_r + 1
356
  Dy = (y_max - y_min) // size_r + 1
357
  Dz = (z_max - z_min) // size_r + 1
358
 
359
  # print("Dx x Dy x Dz is {} x {} x {}".format(Dx, Dy, Dz))
360
 
361
- # step4 计算每个点在volex grid内每一个维度的值
362
- h = list() # h 为保存索引的列表
363
  for i in range(len(point_cloud)):
364
  hx = np.floor((point_cloud[i][0] - x_min) // size_r)
365
  hy = np.floor((point_cloud[i][1] - y_min) // size_r)
366
  hz = np.floor((point_cloud[i][2] - z_min) // size_r)
367
  h.append(hx + hy * Dx + hz * Dx * Dy)
368
- # print(h[60581])
369
 
370
- # step5 h值进行排序
371
  h = np.array(h)
372
- h_indice = np.argsort(h) # 提取索引,返回h里面的元素按从小到大排序的 索引
373
- h_sorted = h[h_indice] # 升序
374
- count = 0 # 用于维度的累计
375
  step = 20
376
- # 将h值相同的点放入到同一个grid中,并进行筛选
377
- for i in range(1, len(h_sorted)): # 0-19999个数据点
378
- # if i == len(h_sorted)-1:
379
- # print("aaa")
380
  if h_sorted[i] == h_sorted[i - 1] and (i != len(h_sorted) - 1):
381
  continue
 
382
  elif h_sorted[i] == h_sorted[i - 1] and (i == len(h_sorted) - 1):
383
  point_idx = h_indice[count:]
384
  key = h_sorted[i - 1]
385
  same_points_list[key] = point_idx
386
- _G = np.mean(point_cloud[point_idx], axis=0) # 所有点的重心
387
- _d = np.linalg.norm(point_cloud[point_idx] - _G, axis=1, ord=2) # 计算到重心的距离
388
  _d.sort()
389
- inx = [j for j in range(0, len(_d), step)] # 获取指定间隔元素下标
390
  for j in inx:
391
  index = point_idx[j]
392
  filtered_points.append(point_cloud[index])
393
  count = i
 
394
  elif h_sorted[i] != h_sorted[i - 1] and (i == len(h_sorted) - 1):
395
  point_idx1 = h_indice[count:i]
396
  key1 = h_sorted[i - 1]
397
  same_points_list[key1] = point_idx1
398
- _G = np.mean(point_cloud[point_idx1], axis=0) # 所有点的重心
399
- _d = np.linalg.norm(point_cloud[point_idx1] - _G, axis=1, ord=2) # 计算到重心的距离
400
  _d.sort()
401
- inx = [j for j in range(0, len(_d), step)] # 获取指定间隔元素下标
402
  for j in inx:
403
  index = point_idx1[j]
404
  filtered_points.append(point_cloud[index])
@@ -406,10 +385,10 @@ def voxel_filter(point_cloud, leaf_size):
406
  point_idx2 = h_indice[i:]
407
  key2 = h_sorted[i]
408
  same_points_list[key2] = point_idx2
409
- _G = np.mean(point_cloud[point_idx2], axis=0) # 所有点的重心
410
- _d = np.linalg.norm(point_cloud[point_idx2] - _G, axis=1, ord=2) # 计算到重心的距离
411
  _d.sort()
412
- inx = [j for j in range(0, len(_d), step)] # 获取指定间隔元素下标
413
  for j in inx:
414
  index = point_idx2[j]
415
  filtered_points.append(point_cloud[index])
@@ -419,38 +398,42 @@ def voxel_filter(point_cloud, leaf_size):
419
  point_idx = h_indice[count: i]
420
  key = h_sorted[i - 1]
421
  same_points_list[key] = point_idx
422
- _G = np.mean(point_cloud[point_idx], axis=0) # 所有点的重心
423
- _d = np.linalg.norm(point_cloud[point_idx] - _G, axis=1, ord=2) # 计算到重心的距离
424
  _d.sort()
425
- inx = [j for j in range(0, len(_d), step)] # 获取指定间隔元素下标
426
  for j in inx:
427
  index = point_idx[j]
428
  filtered_points.append(point_cloud[index])
429
  count = i
430
 
431
- # 把点云格式改成array,并对外返回
432
  # print(f'filtered_points[0]为{filtered_points[0]}')
433
  filtered_points = np.array(filtered_points, dtype=np.float64)
 
434
  return filtered_points,same_points_list
435
 
436
 
437
- # 体素上采样
438
  def voxel_upsample(same_points_list, point_cloud, filtered_points, filter_labels, leaf_size):
439
  upsample_label = []
440
  upsample_point = []
441
  upsample_index = []
442
- # step1 计算边界点
443
- x_max, y_max, z_max = np.amax(point_cloud, axis=0) # 计算 x,y,z三个维度的最值
 
444
  x_min, y_min, z_min = np.amin(point_cloud, axis=0)
445
- # step2 确定体素的尺寸
 
446
  size_r = leaf_size
447
- # step3 计算每个 volex的维度 voxel grid
 
448
  Dx = (x_max - x_min) // size_r + 1
449
  Dy = (y_max - y_min) // size_r + 1
450
  Dz = (z_max - z_min) // size_r + 1
451
  print("Dx x Dy x Dz is {} x {} x {}".format(Dx, Dy, Dz))
452
 
453
- # step4 计算每个点(采样后的点)在volex grid内每一个维度的值
454
  h = list()
455
  for i in range(len(filtered_points)):
456
  hx = np.floor((filtered_points[i][0] - x_min) // size_r)
@@ -458,30 +441,33 @@ def voxel_upsample(same_points_list, point_cloud, filtered_points, filter_labels
458
  hz = np.floor((filtered_points[i][2] - z_min) // size_r)
459
  h.append(hx + hy * Dx + hz * Dx * Dy)
460
 
461
- # step5 根据h值查询字典same_points_list
462
  h = np.array(h)
463
  count = 0
464
  for i in range(1, len(h)):
465
  if h[i] == h[i - 1] and i != (len(h) - 1):
466
  continue
 
467
  elif h[i] == h[i - 1] and i == (len(h) - 1):
468
  label = filter_labels[count:]
469
  key = h[i - 1]
470
  count = i
471
- # 累计label次数,classcount:{‘A’:2,'B':1}
 
472
  classcount = {}
473
  for i in range(len(label)):
474
  vote = label[i]
475
  classcount[vote] = classcount.get(vote, 0) + 1
476
- # 对map的value排序
 
477
  sortedclass = sorted(classcount.items(), key=lambda x: (x[1]), reverse=True)
478
- # key = h[i-1]
479
- point_index = same_points_list[key] # h对应的point index列表
480
  for j in range(len(point_index)):
481
  upsample_label.append(sortedclass[0][0])
482
  index = point_index[j]
483
  upsample_point.append(point_cloud[index])
484
  upsample_index.append(index)
 
485
  elif h[i] != h[i - 1] and (i == len(h) - 1):
486
  label1 = filter_labels[count:i]
487
  key1 = h[i - 1]
@@ -493,8 +479,8 @@ def voxel_upsample(same_points_list, point_cloud, filtered_points, filter_labels
493
  for i in range(len(label1)):
494
  vote = label1[i]
495
  classcount[vote] = classcount.get(vote, 0) + 1
 
496
  sortedclass = sorted(classcount.items(), key=lambda x: (x[1]), reverse=True)
497
- # key1 = h[i-1]
498
  point_index = same_points_list[key1]
499
  for j in range(len(point_index)):
500
  upsample_label.append(sortedclass[0][0])
@@ -502,13 +488,12 @@ def voxel_upsample(same_points_list, point_cloud, filtered_points, filter_labels
502
  upsample_point.append(point_cloud[index])
503
  upsample_index.append(index)
504
 
505
- # label2 = filter_labels[i:]
506
  classcount = {}
507
  for i in range(len(label2)):
508
  vote = label2[i]
509
  classcount[vote] = classcount.get(vote, 0) + 1
 
510
  sortedclass = sorted(classcount.items(), key=lambda x: (x[1]), reverse=True)
511
- # key2 = h[i]
512
  point_index = same_points_list[key2]
513
  for j in range(len(point_index)):
514
  upsample_label.append(sortedclass[0][0])
@@ -523,58 +508,51 @@ def voxel_upsample(same_points_list, point_cloud, filtered_points, filter_labels
523
  for i in range(len(label)):
524
  vote = label[i]
525
  classcount[vote] = classcount.get(vote, 0) + 1
 
526
  sortedclass = sorted(classcount.items(), key=lambda x: (x[1]), reverse=True)
527
- # key = h[i-1]
528
  point_index = same_points_list[key] # h对应的point index列表
529
  for j in range(len(point_index)):
530
  upsample_label.append(sortedclass[0][0])
531
  index = point_index[j]
532
  upsample_point.append(point_cloud[index])
533
  upsample_index.append(index)
534
- # count = i
535
 
536
- # 恢复原始顺序
537
- # print(f'upsample_index[0]的值为{upsample_index[0]}')
538
- # print(f'upsample_index的总长度为{len(upsample_index)}')
539
-
540
- # 恢复index原始顺序
541
  upsample_index = np.array(upsample_index)
542
- upsample_index_indice = np.argsort(upsample_index) # 提取索引,返回h里面的元素按从小到大排序的 索引
543
  upsample_index_sorted = upsample_index[upsample_index_indice]
544
 
545
  upsample_point = np.array(upsample_point)
546
  upsample_label = np.array(upsample_label)
547
- # 恢复point和label的原始顺序
 
548
  upsample_point_sorted = upsample_point[upsample_index_indice]
549
  upsample_label_sorted = upsample_label[upsample_index_indice]
550
 
551
  return upsample_point_sorted, upsample_label_sorted
552
 
553
-
554
- # 利用knn算法上采样
555
  def KNN_sklearn_Load_data(voxel_points, center_points, labels):
556
- # 载入数据
557
- # x_train, x_test, y_train, y_test = train_test_split(center_points, labels, test_size=0.1)
558
- # 构建模型
559
  model = neighbors.KNeighborsClassifier(n_neighbors=3)
560
  model.fit(center_points, labels)
561
  prediction = model.predict(voxel_points.reshape(1, -1))
562
- # meshtopoints_labels = classification_report(voxel_points, prediction)
563
- return prediction[0]
564
 
 
565
 
566
- # 加载点进行knn上采样
567
  def Load_data(voxel_points, center_points, labels):
568
  meshtopoints_labels = []
569
- # meshtopoints_labels.append(SVC_sklearn_Load_data(voxel_points[i], center_points, labels))
570
  for i in range(0, voxel_points.shape[0]):
571
  meshtopoints_labels.append(KNN_sklearn_Load_data(voxel_points[i], center_points, labels))
 
572
  return np.array(meshtopoints_labels)
573
 
574
- # 将三角网格数据上采样回原始点云数据
575
  def mesh_to_points_main(jaw, pcd_points, center_points, labels):
576
  points = pcd_points.copy()
577
- # 下采样
 
578
  voxel_points, same_points_list = voxel_filter(points, 0.6)
579
 
580
  after_labels = Load_data(voxel_points, center_points, labels)
@@ -584,8 +562,8 @@ def mesh_to_points_main(jaw, pcd_points, center_points, labels):
584
  new_pcd = o3d.geometry.PointCloud()
585
  new_pcd.points = o3d.utility.Vector3dVector(upsample_point)
586
  instances_labels = upsample_label.copy()
587
- # '''
588
- # o3d.io.write_point_cloud(os.path.join(save_path, 'upsample_' + name + '.pcd'), new_pcd, write_ascii=True)
589
  for i in stqdm(range(0, upsample_label.shape[0])):
590
  if jaw == 'upper':
591
  if (upsample_label[i] >= 1) and (upsample_label[i] <= 8):
@@ -597,13 +575,14 @@ def mesh_to_points_main(jaw, pcd_points, center_points, labels):
597
  upsample_label[i] = upsample_label[i] + 30
598
  elif (upsample_label[i] >= 9) and (upsample_label[i] <= 16):
599
  upsample_label[i] = upsample_label[i] + 32
 
600
  remove_outlier_main(jaw, pcd_points, upsample_label, instances_labels)
601
 
602
 
603
- # 将原始点云数据转换为三角网格
604
  def mesh_grid(pcd_points):
605
  new_pcd,_ = voxel_filter(pcd_points, 0.6)
606
- # pcd需要有法向量
607
 
608
  # estimate radius for rolling ball
609
  pcd_new = o3d.geometry.PointCloud()
@@ -615,12 +594,10 @@ def mesh_grid(pcd_points):
615
  mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(
616
  pcd_new,
617
  o3d.utility.DoubleVector([radius, radius * 2]))
618
- # o3d.io.write_triangle_mesh("./tooth date/test.ply", mesh)
619
 
620
  return mesh
621
 
622
-
623
- # 读取obj文件内容
624
  def read_obj(obj_path):
625
  jaw = None
626
  with open(obj_path) as file:
@@ -642,14 +619,12 @@ def read_obj(obj_path):
642
 
643
  points = np.array(points)
644
  faces = np.array(faces)
645
-
646
  if jaw is None:
647
  raise ValueError("Jaw type not found in OBJ file")
648
 
649
  return points, faces, jaw
650
 
651
-
652
- # obj文件转为pcd文件
653
  def obj2pcd(obj_path):
654
  if os.path.exists(obj_path):
655
  print('yes')
@@ -661,13 +636,14 @@ def obj2pcd(obj_path):
661
  pcd_list.append(new_line.split())
662
 
663
  pcd_points = np.array(pcd_list).astype(np.float64)
664
- return pcd_points, jaw
665
 
 
666
 
 
667
  def segmentation_main(obj_path):
668
  upsampling_method = 'KNN'
669
 
670
- model_path = 'Mesh_Segementation_MeshSegNet_17_classes_60samples_best.tar'
671
  num_classes = 17
672
  num_channels = 15
673
 
@@ -737,6 +713,7 @@ def segmentation_main(obj_path):
737
  nmeans = normals.mean(axis=0)
738
  nstds = normals.std(axis=0)
739
 
 
740
  for i in range(3):
741
  cells[:, i] = (cells[:, i] - means[i]) / stds[i] # point 1
742
  cells[:, i + 3] = (cells[:, i + 3] - means[i]) / stds[i] # point 2
@@ -744,6 +721,7 @@ def segmentation_main(obj_path):
744
  barycenters[:, i] = (barycenters[:, i] - mins[i]) / (maxs[i] - mins[i])
745
  normals[:, i] = (normals[:, i] - nmeans[i]) / nstds[i]
746
 
 
747
  X = np.column_stack((cells, barycenters, normals))
748
 
749
  # computing A_S and A_L
@@ -794,6 +772,7 @@ def segmentation_main(obj_path):
794
  if i_node < i_nei:
795
  cos_theta = np.dot(normals[i_node, 0:3], normals[i_nei, 0:3]) / np.linalg.norm(
796
  normals[i_node, 0:3]) / np.linalg.norm(normals[i_nei, 0:3])
 
797
  if cos_theta >= 1.0:
798
  cos_theta = 0.9999
799
  theta = np.arccos(cos_theta)
@@ -806,6 +785,7 @@ def segmentation_main(obj_path):
806
  edges = np.concatenate(
807
  (edges, np.array([i_node, i_nei, -beta * np.log10(theta / np.pi) * phi]).reshape(1, 3)),
808
  axis=0)
 
809
  edges = np.delete(edges, 0, 0)
810
  edges[:, 2] *= lambda_c * round_factor
811
  edges = edges.astype(np.int32)
@@ -913,9 +893,9 @@ class Segment(TeethApp):
913
  # Create a pyvista plotter
914
  plotter = pv.Plotter()
915
 
916
- cmap = plt.cm.get_cmap('jet', 27) # Using a colormap with sufficient distinct colors
917
 
918
- colors = cmap(np.linspace(0, 1, 27)) # Generate colors
919
 
920
  # Convert colors to a format acceptable by PyVista
921
  colormap = mcolors.ListedColormap(colors)
@@ -930,8 +910,6 @@ class Segment(TeethApp):
930
  with st.expander("Ground Truth - scroll for zoom", expanded=False):
931
  stpyvista(plotter)
932
 
933
-
934
-
935
  elif inputs == "Upload Scan":
936
  file = st.file_uploader("Please upload an OBJ Object file", type=["OBJ"])
937
  st.markdown("Expected time per prediction: 7-10 min.")
@@ -939,7 +917,7 @@ class Segment(TeethApp):
939
  # save the uploaded file to disk
940
  with open("file.obj", "wb") as buffer:
941
  shutil.copyfileobj(file, buffer)
942
- # 复制数据
943
  obj_path = "file.obj"
944
 
945
  mesh = pv.read(obj_path)
@@ -957,9 +935,5 @@ class Segment(TeethApp):
957
  if segment:
958
  segmentation_main(obj_path)
959
 
960
-
961
-
962
-
963
-
964
  if __name__ == "__main__":
965
  app = Segment()
 
1
+ import os
2
  import shutil
3
+ import json
4
 
 
5
  import numpy as np
 
6
  from scipy.spatial import distance_matrix
7
+ from sklearn import neighbors
8
  from pygco import cut_from_graph
9
  import open3d as o3d
10
  import matplotlib.pyplot as plt
11
  import matplotlib.colors as mcolors
 
 
 
12
  import torch
13
  import torch.nn as nn
14
  from torch.autograd import Variable
15
  import torch.nn.functional as F
16
  import streamlit as st
17
+ from streamlit import session_state as session
18
+ from stpyvista import stpyvista
19
+ from stqdm import stqdm
20
  from PIL import Image
21
 
22
+ # Configure Streamlit page
23
  class TeethApp:
24
+ """
25
+ Base class for Streamlit app
26
+ """
27
  def __init__(self):
28
  # Font
29
  with open("utils/style.css") as css:
 
50
  unsafe_allow_html=True,
51
  )
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  class STNkd(nn.Module):
54
  def __init__(self, k=64):
55
  super(STNkd, self).__init__()
 
97
  self.with_dropout = with_dropout
98
  self.dropout_p = dropout_p
99
 
100
+ # MLP-1 -shape: [64, 64]
101
  self.mlp1_conv1 = torch.nn.Conv1d(self.num_channels, 64, 1)
102
  self.mlp1_conv2 = torch.nn.Conv1d(64, 64, 1)
103
  self.mlp1_bn1 = nn.BatchNorm1d(64)
104
  self.mlp1_bn2 = nn.BatchNorm1d(64)
105
+
106
  # FTM (feature-transformer module)
107
  self.fstn = STNkd(k=64)
108
+
109
  # GLM-1 (graph-contrained learning modulus)
110
  self.glm1_conv1_1 = torch.nn.Conv1d(64, 32, 1)
111
  self.glm1_conv1_2 = torch.nn.Conv1d(64, 32, 1)
 
113
  self.glm1_bn1_2 = nn.BatchNorm1d(32)
114
  self.glm1_conv2 = torch.nn.Conv1d(32+32, 64, 1)
115
  self.glm1_bn2 = nn.BatchNorm1d(64)
116
+
117
  # MLP-2
118
  self.mlp2_conv1 = torch.nn.Conv1d(64, 64, 1)
119
  self.mlp2_bn1 = nn.BatchNorm1d(64)
 
121
  self.mlp2_bn2 = nn.BatchNorm1d(128)
122
  self.mlp2_conv3 = torch.nn.Conv1d(128, 512, 1)
123
  self.mlp2_bn3 = nn.BatchNorm1d(512)
124
+
125
  # GLM-2 (graph-contrained learning modulus)
126
  self.glm2_conv1_1 = torch.nn.Conv1d(512, 128, 1)
127
  self.glm2_conv1_2 = torch.nn.Conv1d(512, 128, 1)
 
131
  self.glm2_bn1_3 = nn.BatchNorm1d(128)
132
  self.glm2_conv2 = torch.nn.Conv1d(128*3, 512, 1)
133
  self.glm2_bn2 = nn.BatchNorm1d(512)
134
+
135
  # MLP-3
136
  self.mlp3_conv1 = torch.nn.Conv1d(64+512+512+512, 256, 1)
137
  self.mlp3_conv2 = torch.nn.Conv1d(256, 256, 1)
 
141
  self.mlp3_conv4 = torch.nn.Conv1d(128, 128, 1)
142
  self.mlp3_bn2_1 = nn.BatchNorm1d(128)
143
  self.mlp3_bn2_2 = nn.BatchNorm1d(128)
144
+
145
+ # Output
146
  self.output_conv = torch.nn.Conv1d(128, self.num_classes, 1)
147
  if self.with_dropout:
148
  self.dropout = nn.Dropout(p=self.dropout_p)
 
150
  def forward(self, x, a_s, a_l):
151
  batchsize = x.size()[0]
152
  n_pts = x.size()[2]
153
+
154
  # MLP-1
155
  x = F.relu(self.mlp1_bn1(self.mlp1_conv1(x)))
156
  x = F.relu(self.mlp1_bn2(self.mlp1_conv2(x)))
157
+
158
  # FTM
159
  trans_feat = self.fstn(x)
160
  x = x.transpose(2, 1)
161
  x_ftm = torch.bmm(x, trans_feat)
162
+
163
  # GLM-1
164
  sap = torch.bmm(a_s, x_ftm)
165
  sap = sap.transpose(2, 1)
 
168
  glm_1_sap = F.relu(self.glm1_bn1_2(self.glm1_conv1_2(sap)))
169
  x = torch.cat([x, glm_1_sap], dim=1)
170
  x = F.relu(self.glm1_bn2(self.glm1_conv2(x)))
171
+
172
  # MLP-2
173
  x = F.relu(self.mlp2_bn1(self.mlp2_conv1(x)))
174
  x = F.relu(self.mlp2_bn2(self.mlp2_conv2(x)))
175
  x_mlp2 = F.relu(self.mlp2_bn3(self.mlp2_conv3(x)))
176
  if self.with_dropout:
177
  x_mlp2 = self.dropout(x_mlp2)
178
+
179
  # GLM-2
180
  x_mlp2 = x_mlp2.transpose(2, 1)
181
  sap_1 = torch.bmm(a_s, x_mlp2)
 
188
  glm_2_sap_2 = F.relu(self.glm2_bn1_3(self.glm2_conv1_3(sap_2)))
189
  x = torch.cat([x, glm_2_sap_1, glm_2_sap_2], dim=1)
190
  x_glm2 = F.relu(self.glm2_bn2(self.glm2_conv2(x)))
191
+
192
  # GMP
193
  x = torch.max(x_glm2, 2, keepdim=True)[0]
194
+
195
  # Upsample
196
  x = torch.nn.Upsample(n_pts)(x)
197
+
198
  # Dense fusion
199
  x = torch.cat([x, x_ftm, x_mlp2, x_glm2], dim=1)
200
+
201
  # MLP-3
202
  x = F.relu(self.mlp3_bn1_1(self.mlp3_conv1(x)))
203
  x = F.relu(self.mlp3_bn1_2(self.mlp3_conv2(x)))
 
205
  if self.with_dropout:
206
  x = self.dropout(x)
207
  x = F.relu(self.mlp3_bn2_2(self.mlp3_conv4(x)))
208
+
209
  # output
210
  x = self.output_conv(x)
211
  x = x.transpose(2,1).contiguous()
 
215
  return x
216
 
217
  def clone_runoob(li1):
218
+ """
219
+ copy list
220
+ """
221
  li_copy = li1[:]
222
+
223
  return li_copy
224
 
225
+ # Reclassify outliers
226
  def class_inlier_outlier(label_list, mean_points,cloud, ind, label_index, points, labels):
227
  label_change = clone_runoob(labels)
228
  outlier_index = clone_runoob(label_index)
229
  ind_reverse = clone_runoob(ind)
230
+
231
+ # Get the label subscript of the outlier point
232
  ind_reverse.reverse()
233
  for i in ind_reverse:
234
  outlier_index.pop(i)
235
 
236
+ # Get outliers
237
  inlier_cloud = cloud.select_by_index(ind)
238
  outlier_cloud = cloud.select_by_index(ind, invert=True)
239
  outlier_points = np.array(outlier_cloud.points)
 
241
  for i in range(len(outlier_points)):
242
  distance = []
243
  for j in range(len(mean_points)):
244
+ dis = np.linalg.norm(outlier_points[i] - mean_points[j], ord=2) # Compute the distance between tooth and GT centroid
245
  distance.append(dis)
246
+ min_index = distance.index(min(distance)) # Get the index of the label closest to the centroid of the outlier point
247
+ outlier_label = label_list[min_index] # Get the label of the outlier point
248
  index = outlier_index[i]
249
  label_change[index] = outlier_label
250
 
251
  return label_change
252
 
253
+ # Use knn algorithm to eliminate outliers
254
  def remove_outlier(points, labels):
 
 
255
  same_label_points = {}
256
 
257
  same_label_index = {}
258
 
259
+ mean_points = [] # All label types correspond to the centroid coordinates of the point cloud.
260
 
261
  label_list = []
262
  for i in range(len(labels)):
263
  label_list.append(labels[i])
264
+ label_list = list(set(label_list)) # To retrieve the order from small to large, take GT_label=[0, 11, 12, 13, 14, 15, 16, 17, 21, 22, 23, 24, 25, 26, 27]
265
  label_list.sort()
266
  label_list = label_list[1:]
267
 
 
272
  for j in range(len(labels)):
273
  if labels[j] == i:
274
  points_list.append(points[j].tolist())
275
+ all_label_index.append(j) # Get the subscript of the label corresponding to the point with label i
276
  same_label_points[key] = points_list
277
  same_label_index[key] = all_label_index
278
 
 
282
 
283
  for i in label_list:
284
  points_array = same_label_points[i]
285
+ # Build one o3d object
286
  pcd = o3d.geometry.PointCloud()
287
+ # UseVector3dVector conversion method
288
  pcd.points = o3d.utility.Vector3dVector(points_array)
289
 
290
+ # Perform statistical outlier removal on the point cloud corresponding to label i, find outliers and display them
291
+ # Statistical outlier removal
292
  cl, ind = pcd.remove_statistical_outlier(nb_neighbors=200, std_ratio=2.0) # cl是选中的点,ind是选中点index
 
 
293
 
294
+ # Reclassify the separated outliers
295
  label_index = same_label_index[i]
296
  labels = class_inlier_outlier(label_list, mean_points, pcd, ind, label_index, points, labels)
297
  # print(f"label_change{labels[4400]}")
298
 
299
  return labels
300
 
301
+ # Eliminate outliers and save the final output
 
302
  def remove_outlier_main(jaw, pcd_points, labels, instances_labels):
303
+ # original point
 
304
  points = pcd_points.copy()
305
  label = remove_outlier(points, labels)
306
 
307
+ # Save json file
308
  label_dict = {}
309
  label_dict["id_patient"] = ""
310
  label_dict["jaw"] = jaw
311
  label_dict["labels"] = label.tolist()
312
  label_dict["instances"] = instances_labels.tolist()
313
+
314
  b = json.dumps(label_dict)
315
  with open('dental-labels4' + '.json', 'w') as f_obj:
316
  f_obj.write(b)
317
  f_obj.close()
318
 
 
319
  same_points_list = {}
320
 
321
+ # voxel downsampling
 
322
  def voxel_filter(point_cloud, leaf_size):
323
  same_points_list = {}
324
  filtered_points = []
325
+
326
+ # step1 Calculate boundary points
327
  x_max, y_max, z_max = np.amax(point_cloud, axis=0) # 计算 x,y,z三个维度的最值
328
  x_min, y_min, z_min = np.amin(point_cloud, axis=0)
329
 
330
+ # step2 Determine the size of the voxel
331
  size_r = leaf_size
332
 
333
+ # step3 Calculate the dimensions of each volex voxel grid
334
  Dx = (x_max - x_min) // size_r + 1
335
  Dy = (y_max - y_min) // size_r + 1
336
  Dz = (z_max - z_min) // size_r + 1
337
 
338
  # print("Dx x Dy x Dz is {} x {} x {}".format(Dx, Dy, Dz))
339
 
340
+ # step4 Calculate the value of each point in each dimension in the volex grid
341
+ h = list() # h is a list of saved indexes
342
  for i in range(len(point_cloud)):
343
  hx = np.floor((point_cloud[i][0] - x_min) // size_r)
344
  hy = np.floor((point_cloud[i][1] - y_min) // size_r)
345
  hz = np.floor((point_cloud[i][2] - z_min) // size_r)
346
  h.append(hx + hy * Dx + hz * Dx * Dy)
 
347
 
348
+ # step5 Sort h values
349
  h = np.array(h)
350
+ h_indice = np.argsort(h) # Extract the index and return the index of the elements in h sorted from small to large.
351
+ h_sorted = h[h_indice] # Ascending order
352
+ count = 0 # used for accumulation of dimensions
353
  step = 20
354
+
355
+ # Put points with the same h value into the same grid and filter them
356
+ for i in range(1, len(h_sorted)): # 0-19999 data points
 
357
  if h_sorted[i] == h_sorted[i - 1] and (i != len(h_sorted) - 1):
358
  continue
359
+
360
  elif h_sorted[i] == h_sorted[i - 1] and (i == len(h_sorted) - 1):
361
  point_idx = h_indice[count:]
362
  key = h_sorted[i - 1]
363
  same_points_list[key] = point_idx
364
+ _G = np.mean(point_cloud[point_idx], axis=0) # center of gravity of all points
365
+ _d = np.linalg.norm(point_cloud[point_idx] - _G, axis=1, ord=2) # Calculate distance to center of gravity
366
  _d.sort()
367
+ inx = [j for j in range(0, len(_d), step)] # Get the index of the specified interval element
368
  for j in inx:
369
  index = point_idx[j]
370
  filtered_points.append(point_cloud[index])
371
  count = i
372
+
373
  elif h_sorted[i] != h_sorted[i - 1] and (i == len(h_sorted) - 1):
374
  point_idx1 = h_indice[count:i]
375
  key1 = h_sorted[i - 1]
376
  same_points_list[key1] = point_idx1
377
+ _G = np.mean(point_cloud[point_idx1], axis=0) # center of gravity of all points
378
+ _d = np.linalg.norm(point_cloud[point_idx1] - _G, axis=1, ord=2) # Calculate distance to center of gravity
379
  _d.sort()
380
+ inx = [j for j in range(0, len(_d), step)] # Get the index of the specified interval element
381
  for j in inx:
382
  index = point_idx1[j]
383
  filtered_points.append(point_cloud[index])
 
385
  point_idx2 = h_indice[i:]
386
  key2 = h_sorted[i]
387
  same_points_list[key2] = point_idx2
388
+ _G = np.mean(point_cloud[point_idx2], axis=0) # center of gravity of all points
389
+ _d = np.linalg.norm(point_cloud[point_idx2] - _G, axis=1, ord=2) # Calculate distance to center of gravity
390
  _d.sort()
391
+ inx = [j for j in range(0, len(_d), step)] # Get the index of the specified interval element
392
  for j in inx:
393
  index = point_idx2[j]
394
  filtered_points.append(point_cloud[index])
 
398
  point_idx = h_indice[count: i]
399
  key = h_sorted[i - 1]
400
  same_points_list[key] = point_idx
401
+ _G = np.mean(point_cloud[point_idx], axis=0) # center of gravity of all points
402
+ _d = np.linalg.norm(point_cloud[point_idx] - _G, axis=1, ord=2) # Calculate distance to center of gravity
403
  _d.sort()
404
+ inx = [j for j in range(0, len(_d), step)] # Get the index of the specified interval element
405
  for j in inx:
406
  index = point_idx[j]
407
  filtered_points.append(point_cloud[index])
408
  count = i
409
 
410
+ # Change the point cloud format to array and return it externally
411
  # print(f'filtered_points[0]为{filtered_points[0]}')
412
  filtered_points = np.array(filtered_points, dtype=np.float64)
413
+
414
  return filtered_points,same_points_list
415
 
416
 
417
+ # voxel upsampling
418
  def voxel_upsample(same_points_list, point_cloud, filtered_points, filter_labels, leaf_size):
419
  upsample_label = []
420
  upsample_point = []
421
  upsample_index = []
422
+
423
+ # step1 Calculate boundary points
424
+ x_max, y_max, z_max = np.amax(point_cloud, axis=0) # Calculate the maximum value of the three dimensions x, y, z
425
  x_min, y_min, z_min = np.amin(point_cloud, axis=0)
426
+
427
+ # step2 Determine the size of the voxel
428
  size_r = leaf_size
429
+
430
+ # step3 Calculate the dimensions of each volex voxel grid
431
  Dx = (x_max - x_min) // size_r + 1
432
  Dy = (y_max - y_min) // size_r + 1
433
  Dz = (z_max - z_min) // size_r + 1
434
  print("Dx x Dy x Dz is {} x {} x {}".format(Dx, Dy, Dz))
435
 
436
+ # step4 Calculate the value of each point (sampled point) in each dimension within the volex grid
437
  h = list()
438
  for i in range(len(filtered_points)):
439
  hx = np.floor((filtered_points[i][0] - x_min) // size_r)
 
441
  hz = np.floor((filtered_points[i][2] - z_min) // size_r)
442
  h.append(hx + hy * Dx + hz * Dx * Dy)
443
 
444
+ # step5 Query the dictionary same_points_list based on the h value
445
  h = np.array(h)
446
  count = 0
447
  for i in range(1, len(h)):
448
  if h[i] == h[i - 1] and i != (len(h) - 1):
449
  continue
450
+
451
  elif h[i] == h[i - 1] and i == (len(h) - 1):
452
  label = filter_labels[count:]
453
  key = h[i - 1]
454
  count = i
455
+
456
+ # Cumulative number of labels, classcount: {‘A’: 2, ‘B’: 1}
457
  classcount = {}
458
  for i in range(len(label)):
459
  vote = label[i]
460
  classcount[vote] = classcount.get(vote, 0) + 1
461
+
462
+ # Sort map values
463
  sortedclass = sorted(classcount.items(), key=lambda x: (x[1]), reverse=True)
464
+ point_index = same_points_list[key] # Point index list corresponding to h
 
465
  for j in range(len(point_index)):
466
  upsample_label.append(sortedclass[0][0])
467
  index = point_index[j]
468
  upsample_point.append(point_cloud[index])
469
  upsample_index.append(index)
470
+
471
  elif h[i] != h[i - 1] and (i == len(h) - 1):
472
  label1 = filter_labels[count:i]
473
  key1 = h[i - 1]
 
479
  for i in range(len(label1)):
480
  vote = label1[i]
481
  classcount[vote] = classcount.get(vote, 0) + 1
482
+
483
  sortedclass = sorted(classcount.items(), key=lambda x: (x[1]), reverse=True)
 
484
  point_index = same_points_list[key1]
485
  for j in range(len(point_index)):
486
  upsample_label.append(sortedclass[0][0])
 
488
  upsample_point.append(point_cloud[index])
489
  upsample_index.append(index)
490
 
 
491
  classcount = {}
492
  for i in range(len(label2)):
493
  vote = label2[i]
494
  classcount[vote] = classcount.get(vote, 0) + 1
495
+
496
  sortedclass = sorted(classcount.items(), key=lambda x: (x[1]), reverse=True)
 
497
  point_index = same_points_list[key2]
498
  for j in range(len(point_index)):
499
  upsample_label.append(sortedclass[0][0])
 
508
  for i in range(len(label)):
509
  vote = label[i]
510
  classcount[vote] = classcount.get(vote, 0) + 1
511
+
512
  sortedclass = sorted(classcount.items(), key=lambda x: (x[1]), reverse=True)
 
513
  point_index = same_points_list[key] # h对应的point index列表
514
  for j in range(len(point_index)):
515
  upsample_label.append(sortedclass[0][0])
516
  index = point_index[j]
517
  upsample_point.append(point_cloud[index])
518
  upsample_index.append(index)
 
519
 
520
+ # Restore the original order of index
 
 
 
 
521
  upsample_index = np.array(upsample_index)
522
+ upsample_index_indice = np.argsort(upsample_index) # Extract the index and return the index of the elements in h sorted from small to large.
523
  upsample_index_sorted = upsample_index[upsample_index_indice]
524
 
525
  upsample_point = np.array(upsample_point)
526
  upsample_label = np.array(upsample_label)
527
+
528
+ # Restore the original order of points and labels
529
  upsample_point_sorted = upsample_point[upsample_index_indice]
530
  upsample_label_sorted = upsample_label[upsample_index_indice]
531
 
532
  return upsample_point_sorted, upsample_label_sorted
533
 
534
+ # Upsampling using knn algorithm
 
535
  def KNN_sklearn_Load_data(voxel_points, center_points, labels):
536
+ # Build model
 
 
537
  model = neighbors.KNeighborsClassifier(n_neighbors=3)
538
  model.fit(center_points, labels)
539
  prediction = model.predict(voxel_points.reshape(1, -1))
 
 
540
 
541
+ return prediction[0]
542
 
543
+ # Loading points for knn upsampling
544
  def Load_data(voxel_points, center_points, labels):
545
  meshtopoints_labels = []
 
546
  for i in range(0, voxel_points.shape[0]):
547
  meshtopoints_labels.append(KNN_sklearn_Load_data(voxel_points[i], center_points, labels))
548
+
549
  return np.array(meshtopoints_labels)
550
 
551
+ # Upsample triangular mesh data back to original point cloud data
552
  def mesh_to_points_main(jaw, pcd_points, center_points, labels):
553
  points = pcd_points.copy()
554
+
555
+ # Downsampling
556
  voxel_points, same_points_list = voxel_filter(points, 0.6)
557
 
558
  after_labels = Load_data(voxel_points, center_points, labels)
 
562
  new_pcd = o3d.geometry.PointCloud()
563
  new_pcd.points = o3d.utility.Vector3dVector(upsample_point)
564
  instances_labels = upsample_label.copy()
565
+
566
+ # Reclassify the label of the upper and lower jaws
567
  for i in stqdm(range(0, upsample_label.shape[0])):
568
  if jaw == 'upper':
569
  if (upsample_label[i] >= 1) and (upsample_label[i] <= 8):
 
575
  upsample_label[i] = upsample_label[i] + 30
576
  elif (upsample_label[i] >= 9) and (upsample_label[i] <= 16):
577
  upsample_label[i] = upsample_label[i] + 32
578
+
579
  remove_outlier_main(jaw, pcd_points, upsample_label, instances_labels)
580
 
581
 
582
+ # Convert raw point cloud data to triangular mesh
583
  def mesh_grid(pcd_points):
584
  new_pcd,_ = voxel_filter(pcd_points, 0.6)
585
+ # pcd needs to have a normal vector
586
 
587
  # estimate radius for rolling ball
588
  pcd_new = o3d.geometry.PointCloud()
 
594
  mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(
595
  pcd_new,
596
  o3d.utility.DoubleVector([radius, radius * 2]))
 
597
 
598
  return mesh
599
 
600
+ # Read the contents of obj file
 
601
  def read_obj(obj_path):
602
  jaw = None
603
  with open(obj_path) as file:
 
619
 
620
  points = np.array(points)
621
  faces = np.array(faces)
 
622
  if jaw is None:
623
  raise ValueError("Jaw type not found in OBJ file")
624
 
625
  return points, faces, jaw
626
 
627
+ # Convert obj file to pcd file
 
628
  def obj2pcd(obj_path):
629
  if os.path.exists(obj_path):
630
  print('yes')
 
636
  pcd_list.append(new_line.split())
637
 
638
  pcd_points = np.array(pcd_list).astype(np.float64)
 
639
 
640
+ return pcd_points, jaw
641
 
642
+ # Main function for segment
643
  def segmentation_main(obj_path):
644
  upsampling_method = 'KNN'
645
 
646
+ model_path = 'model.tar'
647
  num_classes = 17
648
  num_channels = 15
649
 
 
713
  nmeans = normals.mean(axis=0)
714
  nstds = normals.std(axis=0)
715
 
716
+ # normalization
717
  for i in range(3):
718
  cells[:, i] = (cells[:, i] - means[i]) / stds[i] # point 1
719
  cells[:, i + 3] = (cells[:, i + 3] - means[i]) / stds[i] # point 2
 
721
  barycenters[:, i] = (barycenters[:, i] - mins[i]) / (maxs[i] - mins[i])
722
  normals[:, i] = (normals[:, i] - nmeans[i]) / nstds[i]
723
 
724
+ # concatenate
725
  X = np.column_stack((cells, barycenters, normals))
726
 
727
  # computing A_S and A_L
 
772
  if i_node < i_nei:
773
  cos_theta = np.dot(normals[i_node, 0:3], normals[i_nei, 0:3]) / np.linalg.norm(
774
  normals[i_node, 0:3]) / np.linalg.norm(normals[i_nei, 0:3])
775
+
776
  if cos_theta >= 1.0:
777
  cos_theta = 0.9999
778
  theta = np.arccos(cos_theta)
 
785
  edges = np.concatenate(
786
  (edges, np.array([i_node, i_nei, -beta * np.log10(theta / np.pi) * phi]).reshape(1, 3)),
787
  axis=0)
788
+
789
  edges = np.delete(edges, 0, 0)
790
  edges[:, 2] *= lambda_c * round_factor
791
  edges = edges.astype(np.int32)
 
893
  # Create a pyvista plotter
894
  plotter = pv.Plotter()
895
 
896
+ cmap = plt.cm.get_cmap('jet', 27) # Using a colormap with sufficient distinct colors
897
 
898
+ colors = cmap(np.linspace(0, 1, 27)) # Generate colors
899
 
900
  # Convert colors to a format acceptable by PyVista
901
  colormap = mcolors.ListedColormap(colors)
 
910
  with st.expander("Ground Truth - scroll for zoom", expanded=False):
911
  stpyvista(plotter)
912
 
 
 
913
  elif inputs == "Upload Scan":
914
  file = st.file_uploader("Please upload an OBJ Object file", type=["OBJ"])
915
  st.markdown("Expected time per prediction: 7-10 min.")
 
917
  # save the uploaded file to disk
918
  with open("file.obj", "wb") as buffer:
919
  shutil.copyfileobj(file, buffer)
920
+
921
  obj_path = "file.obj"
922
 
923
  mesh = pv.read(obj_path)
 
935
  if segment:
936
  segmentation_main(obj_path)
937
 
 
 
 
 
938
  if __name__ == "__main__":
939
  app = Segment()