huathedev committed
Commit d58d621 · 1 Parent(s): 59ebb86

Update ⓘ_Introduction.py

Files changed (1): ⓘ_Introduction.py (+890 -10)
ⓘ_Introduction.py CHANGED
@@ -1,5 +1,24 @@
- import streamlit as st
  from streamlit import session_state as session
+ import shutil
+
+ import os
+ import numpy as np
+ from sklearn import neighbors
+ from scipy.spatial import distance_matrix
+ from pygco import cut_from_graph
+ import streamlit_ext as ste
+ import open3d as o3d
+ import matplotlib.pyplot as plt
+ import matplotlib.colors as mcolors
+ from stqdm import stqdm
+ import json
+ from stpyvista import stpyvista
+ import torch
+ import torch.nn as nn
+ from torch.autograd import Variable
+ import torch.nn.functional as F
+ import streamlit as st
+ import pyvista as pv

  from PIL import Image

@@ -8,7 +27,7 @@ class TeethApp:
          # Font
          with open("utils/style.css") as css:
              st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)
-
+
          # Logo
          self.image_path = "utils/teeth-295404_1280.png"
          self.image = Image.open(self.image_path)

@@ -30,20 +49,881 @@ class TeethApp:
              unsafe_allow_html=True,
          )

+
+ class STN3d(nn.Module):
+     def __init__(self, channel):
+         super(STN3d, self).__init__()
+         self.conv1 = torch.nn.Conv1d(channel, 64, 1)
+         self.conv2 = torch.nn.Conv1d(64, 128, 1)
+         self.conv3 = torch.nn.Conv1d(128, 1024, 1)
+         self.fc1 = nn.Linear(1024, 512)
+         self.fc2 = nn.Linear(512, 256)
+         self.fc3 = nn.Linear(256, 9)
+         self.relu = nn.ReLU()
+
+         self.bn1 = nn.BatchNorm1d(64)
+         self.bn2 = nn.BatchNorm1d(128)
+         self.bn3 = nn.BatchNorm1d(1024)
+         self.bn4 = nn.BatchNorm1d(512)
+         self.bn5 = nn.BatchNorm1d(256)
+
+     def forward(self, x):
+         batchsize = x.size()[0]
+         x = F.relu(self.bn1(self.conv1(x)))
+         x = F.relu(self.bn2(self.conv2(x)))
+         x = F.relu(self.bn3(self.conv3(x)))
+         x = torch.max(x, 2, keepdim=True)[0]
+         x = x.view(-1, 1024)
+
+         x = F.relu(self.bn4(self.fc1(x)))
+         x = F.relu(self.bn5(self.fc2(x)))
+         x = self.fc3(x)
+
+         iden = Variable(torch.from_numpy(np.array([1, 0, 0, 0, 1, 0, 0, 0, 1]).astype(np.float32))).view(1, 9).repeat(
+             batchsize, 1)
+         if x.is_cuda:
+             iden = iden.to(x.get_device())
+         x = x + iden
+         x = x.view(-1, 3, 3)
+         return x
+
+ class STNkd(nn.Module):
+     def __init__(self, k=64):
+         super(STNkd, self).__init__()
+         self.conv1 = torch.nn.Conv1d(k, 64, 1)
+         self.conv2 = torch.nn.Conv1d(64, 128, 1)
+         self.conv3 = torch.nn.Conv1d(128, 512, 1)
+         self.fc1 = nn.Linear(512, 256)
+         self.fc2 = nn.Linear(256, 128)
+         self.fc3 = nn.Linear(128, k * k)
+         self.relu = nn.ReLU()
+
+         self.bn1 = nn.BatchNorm1d(64)
+         self.bn2 = nn.BatchNorm1d(128)
+         self.bn3 = nn.BatchNorm1d(512)
+         self.bn4 = nn.BatchNorm1d(256)
+         self.bn5 = nn.BatchNorm1d(128)
+
+         self.k = k
+
+     def forward(self, x):
+         batchsize = x.size()[0]
+         x = F.relu(self.bn1(self.conv1(x)))
+         x = F.relu(self.bn2(self.conv2(x)))
+         x = F.relu(self.bn3(self.conv3(x)))
+         x = torch.max(x, 2, keepdim=True)[0]
+         x = x.view(-1, 512)
+
+         x = F.relu(self.bn4(self.fc1(x)))
+         x = F.relu(self.bn5(self.fc2(x)))
+         x = self.fc3(x)
+
+         iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1, self.k * self.k).repeat(
+             batchsize, 1)
+         if x.is_cuda:
+             iden = iden.to(x.get_device())
+         x = x + iden
+         x = x.view(-1, self.k, self.k)
+         return x
+
+ class MeshSegNet(nn.Module):
+     def __init__(self, num_classes=17, num_channels=15, with_dropout=True, dropout_p=0.5):
+         super(MeshSegNet, self).__init__()
+         self.num_classes = num_classes
+         self.num_channels = num_channels
+         self.with_dropout = with_dropout
+         self.dropout_p = dropout_p
+
+         # MLP-1 [64, 64]
+         self.mlp1_conv1 = torch.nn.Conv1d(self.num_channels, 64, 1)
+         self.mlp1_conv2 = torch.nn.Conv1d(64, 64, 1)
+         self.mlp1_bn1 = nn.BatchNorm1d(64)
+         self.mlp1_bn2 = nn.BatchNorm1d(64)
+         # FTM (feature-transformer module)
+         self.fstn = STNkd(k=64)
+         # GLM-1 (graph-constrained learning module)
+         self.glm1_conv1_1 = torch.nn.Conv1d(64, 32, 1)
+         self.glm1_conv1_2 = torch.nn.Conv1d(64, 32, 1)
+         self.glm1_bn1_1 = nn.BatchNorm1d(32)
+         self.glm1_bn1_2 = nn.BatchNorm1d(32)
+         self.glm1_conv2 = torch.nn.Conv1d(32 + 32, 64, 1)
+         self.glm1_bn2 = nn.BatchNorm1d(64)
+         # MLP-2
+         self.mlp2_conv1 = torch.nn.Conv1d(64, 64, 1)
+         self.mlp2_bn1 = nn.BatchNorm1d(64)
+         self.mlp2_conv2 = torch.nn.Conv1d(64, 128, 1)
+         self.mlp2_bn2 = nn.BatchNorm1d(128)
+         self.mlp2_conv3 = torch.nn.Conv1d(128, 512, 1)
+         self.mlp2_bn3 = nn.BatchNorm1d(512)
+         # GLM-2 (graph-constrained learning module)
+         self.glm2_conv1_1 = torch.nn.Conv1d(512, 128, 1)
+         self.glm2_conv1_2 = torch.nn.Conv1d(512, 128, 1)
+         self.glm2_conv1_3 = torch.nn.Conv1d(512, 128, 1)
+         self.glm2_bn1_1 = nn.BatchNorm1d(128)
+         self.glm2_bn1_2 = nn.BatchNorm1d(128)
+         self.glm2_bn1_3 = nn.BatchNorm1d(128)
+         self.glm2_conv2 = torch.nn.Conv1d(128 * 3, 512, 1)
+         self.glm2_bn2 = nn.BatchNorm1d(512)
+         # MLP-3
+         self.mlp3_conv1 = torch.nn.Conv1d(64 + 512 + 512 + 512, 256, 1)
+         self.mlp3_conv2 = torch.nn.Conv1d(256, 256, 1)
+         self.mlp3_bn1_1 = nn.BatchNorm1d(256)
+         self.mlp3_bn1_2 = nn.BatchNorm1d(256)
+         self.mlp3_conv3 = torch.nn.Conv1d(256, 128, 1)
+         self.mlp3_conv4 = torch.nn.Conv1d(128, 128, 1)
+         self.mlp3_bn2_1 = nn.BatchNorm1d(128)
+         self.mlp3_bn2_2 = nn.BatchNorm1d(128)
+         # output
+         self.output_conv = torch.nn.Conv1d(128, self.num_classes, 1)
+         if self.with_dropout:
+             self.dropout = nn.Dropout(p=self.dropout_p)
+
+     def forward(self, x, a_s, a_l):
+         batchsize = x.size()[0]
+         n_pts = x.size()[2]
+         # MLP-1
+         x = F.relu(self.mlp1_bn1(self.mlp1_conv1(x)))
+         x = F.relu(self.mlp1_bn2(self.mlp1_conv2(x)))
+         # FTM
+         trans_feat = self.fstn(x)
+         x = x.transpose(2, 1)
+         x_ftm = torch.bmm(x, trans_feat)
+         # GLM-1
+         sap = torch.bmm(a_s, x_ftm)
+         sap = sap.transpose(2, 1)
+         x_ftm = x_ftm.transpose(2, 1)
+         x = F.relu(self.glm1_bn1_1(self.glm1_conv1_1(x_ftm)))
+         glm_1_sap = F.relu(self.glm1_bn1_2(self.glm1_conv1_2(sap)))
+         x = torch.cat([x, glm_1_sap], dim=1)
+         x = F.relu(self.glm1_bn2(self.glm1_conv2(x)))
+         # MLP-2
+         x = F.relu(self.mlp2_bn1(self.mlp2_conv1(x)))
+         x = F.relu(self.mlp2_bn2(self.mlp2_conv2(x)))
+         x_mlp2 = F.relu(self.mlp2_bn3(self.mlp2_conv3(x)))
+         if self.with_dropout:
+             x_mlp2 = self.dropout(x_mlp2)
+         # GLM-2
+         x_mlp2 = x_mlp2.transpose(2, 1)
+         sap_1 = torch.bmm(a_s, x_mlp2)
+         sap_2 = torch.bmm(a_l, x_mlp2)
+         x_mlp2 = x_mlp2.transpose(2, 1)
+         sap_1 = sap_1.transpose(2, 1)
+         sap_2 = sap_2.transpose(2, 1)
+         x = F.relu(self.glm2_bn1_1(self.glm2_conv1_1(x_mlp2)))
+         glm_2_sap_1 = F.relu(self.glm2_bn1_2(self.glm2_conv1_2(sap_1)))
+         glm_2_sap_2 = F.relu(self.glm2_bn1_3(self.glm2_conv1_3(sap_2)))
+         x = torch.cat([x, glm_2_sap_1, glm_2_sap_2], dim=1)
+         x_glm2 = F.relu(self.glm2_bn2(self.glm2_conv2(x)))
+         # GMP
+         x = torch.max(x_glm2, 2, keepdim=True)[0]
+         # Upsample
+         x = torch.nn.Upsample(n_pts)(x)
+         # Dense fusion
+         x = torch.cat([x, x_ftm, x_mlp2, x_glm2], dim=1)
+         # MLP-3
+         x = F.relu(self.mlp3_bn1_1(self.mlp3_conv1(x)))
+         x = F.relu(self.mlp3_bn1_2(self.mlp3_conv2(x)))
+         x = F.relu(self.mlp3_bn2_1(self.mlp3_conv3(x)))
+         if self.with_dropout:
+             x = self.dropout(x)
+         x = F.relu(self.mlp3_bn2_2(self.mlp3_conv4(x)))
+         # output
+         x = self.output_conv(x)
+         x = x.transpose(2, 1).contiguous()
+         x = torch.nn.Softmax(dim=-1)(x.view(-1, self.num_classes))
+         x = x.view(batchsize, n_pts, self.num_classes)
+
+         return x
+
+ def clone_runoob(li1):
+     li_copy = li1[:]
+     return li_copy
+
+ # Reclassify the outlier points
+ def class_inlier_outlier(label_list, mean_points, cloud, ind, label_index, points, labels):
+     label_change = clone_runoob(labels)
+     outlier_index = clone_runoob(label_index)
+     ind_reverse = clone_runoob(ind)
+     # get the label indices of the outlier points
+     ind_reverse.reverse()
+     for i in ind_reverse:
+         outlier_index.pop(i)
+
+     # extract the outlier points
+     inlier_cloud = cloud.select_by_index(ind)
+     outlier_cloud = cloud.select_by_index(ind, invert=True)
+     outlier_points = np.array(outlier_cloud.points)
+
+     for i in range(len(outlier_points)):
+         distance = []
+         for j in range(len(mean_points)):
+             dis = np.linalg.norm(outlier_points[i] - mean_points[j], ord=2)  # distance between the outlier and each label centroid
+             distance.append(dis)
+         min_index = distance.index(min(distance))  # index of the label whose centroid is closest to the outlier
+         outlier_label = label_list[min_index]  # the label the outlier should take
+         index = outlier_index[i]
+         label_change[index] = outlier_label
+
+     return label_change
+
+ # Eliminate outlier points (statistical removal per label, then reclassification)
+ def remove_outlier(points, labels):
+     # points = np.array(point_cloud_o3d_orign.points)
+     # global label_list
+     same_label_points = {}
+
+     same_label_index = {}
+
+     mean_points = []  # centroid of the point cloud for each label class
+
+     label_list = []
+     for i in range(len(labels)):
+         label_list.append(labels[i])
+     label_list = list(set(label_list))  # deduplicate and sort ascending, e.g. GT_label = [0, 11, 12, 13, 14, 15, 16, 17, 21, 22, 23, 24, 25, 26, 27]
+     label_list.sort()
+     label_list = label_list[1:]
+
+     for i in label_list:
+         key = i
+         points_list = []
+         all_label_index = []
+         for j in range(len(labels)):
+             if labels[j] == i:
+                 points_list.append(points[j].tolist())
+                 all_label_index.append(j)  # indices of the points whose label is i
+         same_label_points[key] = points_list
+         same_label_index[key] = all_label_index
+
+         tooth_mean = np.mean(points_list, axis=0)
+         mean_points.append(tooth_mean)
+     # print(mean_points)
+
+     for i in label_list:
+         points_array = same_label_points[i]
+         # build an Open3D point cloud object
+         pcd = o3d.geometry.PointCloud()
+         # convert with Vector3dVector
+         pcd.points = o3d.utility.Vector3dVector(points_array)
+
+         # run statistical outlier removal on the points of label i to find the outliers
+         cl, ind = pcd.remove_statistical_outlier(nb_neighbors=200, std_ratio=2.0)  # cl: selected points, ind: their indices
+         # visualization
+         # display_inlier_outlier(pcd, ind)
+
+         # reclassify the separated outliers
+         label_index = same_label_index[i]
+         labels = class_inlier_outlier(label_list, mean_points, pcd, ind, label_index, points, labels)
+         # print(f"label_change{labels[4400]}")
+
+     return labels
+
+
+ # Remove outliers and save the final output
+ def remove_outlier_main(jaw, pcd_points, labels, instances_labels):
+     # point_cloud_o3d_orign = o3d.io.read_point_cloud('E:/tooth/data/MeshSegNet-master/test_upsample_15/upsample_01K17AN8_upper_refined.pcd')
+     # original points
+     points = pcd_points.copy()
+     label = remove_outlier(points, labels)
+
+     # save the labels to a JSON file
+     label_dict = {}
+     label_dict["id_patient"] = ""
+     label_dict["jaw"] = jaw
+     label_dict["labels"] = label.tolist()
+     label_dict["instances"] = instances_labels.tolist()
+     b = json.dumps(label_dict)
+     with open('dental-labels4' + '.json', 'w') as f_obj:
+         f_obj.write(b)
+         f_obj.close()
+
+
+ same_points_list = {}
+
+
+ # Voxel downsampling
+ def voxel_filter(point_cloud, leaf_size):
+     same_points_list = {}
+     filtered_points = []
+     # step 1: compute the bounding box
+     x_max, y_max, z_max = np.amax(point_cloud, axis=0)  # extrema along the x, y, z dimensions
+     x_min, y_min, z_min = np.amin(point_cloud, axis=0)
+
+     # step 2: set the voxel size
+     size_r = leaf_size
+
+     # step 3: compute the dimensions of the voxel grid
+     Dx = (x_max - x_min) // size_r + 1
+     Dy = (y_max - y_min) // size_r + 1
+     Dz = (z_max - z_min) // size_r + 1
+
+     # print("Dx x Dy x Dz is {} x {} x {}".format(Dx, Dy, Dz))
+
+     # step 4: compute each point's coordinates within the voxel grid
+     h = list()  # h stores the flattened voxel index of each point
+     for i in range(len(point_cloud)):
+         hx = np.floor((point_cloud[i][0] - x_min) // size_r)
+         hy = np.floor((point_cloud[i][1] - y_min) // size_r)
+         hz = np.floor((point_cloud[i][2] - z_min) // size_r)
+         h.append(hx + hy * Dx + hz * Dx * Dy)
+     # print(h[60581])
+
+     # step 5: sort the h values
+     h = np.array(h)
+     h_indice = np.argsort(h)  # indices that sort h in ascending order
+     h_sorted = h[h_indice]  # ascending
+     count = 0  # start index of the current voxel group
+     step = 20
+     # put points with the same h value into the same grid cell, then subsample them
+     for i in range(1, len(h_sorted)):
+         # if i == len(h_sorted)-1:
+         #     print("aaa")
+         if h_sorted[i] == h_sorted[i - 1] and (i != len(h_sorted) - 1):
+             continue
+         elif h_sorted[i] == h_sorted[i - 1] and (i == len(h_sorted) - 1):
+             point_idx = h_indice[count:]
+             key = h_sorted[i - 1]
+             same_points_list[key] = point_idx
+             _G = np.mean(point_cloud[point_idx], axis=0)  # centroid of the group's points
+             _d = np.linalg.norm(point_cloud[point_idx] - _G, axis=1, ord=2)  # distances to the centroid
+             _d.sort()
+             inx = [j for j in range(0, len(_d), step)]  # take element indices at the given interval
+             for j in inx:
+                 index = point_idx[j]
+                 filtered_points.append(point_cloud[index])
+             count = i
+         elif h_sorted[i] != h_sorted[i - 1] and (i == len(h_sorted) - 1):
+             point_idx1 = h_indice[count:i]
+             key1 = h_sorted[i - 1]
+             same_points_list[key1] = point_idx1
+             _G = np.mean(point_cloud[point_idx1], axis=0)  # centroid of the group's points
+             _d = np.linalg.norm(point_cloud[point_idx1] - _G, axis=1, ord=2)  # distances to the centroid
+             _d.sort()
+             inx = [j for j in range(0, len(_d), step)]  # take element indices at the given interval
+             for j in inx:
+                 index = point_idx1[j]
+                 filtered_points.append(point_cloud[index])
+
+             point_idx2 = h_indice[i:]
+             key2 = h_sorted[i]
+             same_points_list[key2] = point_idx2
+             _G = np.mean(point_cloud[point_idx2], axis=0)  # centroid of the group's points
+             _d = np.linalg.norm(point_cloud[point_idx2] - _G, axis=1, ord=2)  # distances to the centroid
+             _d.sort()
+             inx = [j for j in range(0, len(_d), step)]  # take element indices at the given interval
+             for j in inx:
+                 index = point_idx2[j]
+                 filtered_points.append(point_cloud[index])
+             count = i
+
+         else:
+             point_idx = h_indice[count: i]
+             key = h_sorted[i - 1]
+             same_points_list[key] = point_idx
+             _G = np.mean(point_cloud[point_idx], axis=0)  # centroid of the group's points
+             _d = np.linalg.norm(point_cloud[point_idx] - _G, axis=1, ord=2)  # distances to the centroid
+             _d.sort()
+             inx = [j for j in range(0, len(_d), step)]  # take element indices at the given interval
+             for j in inx:
+                 index = point_idx[j]
+                 filtered_points.append(point_cloud[index])
+             count = i
+
+     # convert the point cloud to an array and return it
+     # print(f'filtered_points[0] is {filtered_points[0]}')
+     filtered_points = np.array(filtered_points, dtype=np.float64)
+     return filtered_points, same_points_list
+
+
+ # Voxel upsampling
+ def voxel_upsample(same_points_list, point_cloud, filtered_points, filter_labels, leaf_size):
+     upsample_label = []
+     upsample_point = []
+     upsample_index = []
+     # step 1: compute the bounding box
+     x_max, y_max, z_max = np.amax(point_cloud, axis=0)  # extrema along the x, y, z dimensions
+     x_min, y_min, z_min = np.amin(point_cloud, axis=0)
+     # step 2: set the voxel size
+     size_r = leaf_size
+     # step 3: compute the dimensions of the voxel grid
+     Dx = (x_max - x_min) // size_r + 1
+     Dy = (y_max - y_min) // size_r + 1
+     Dz = (z_max - z_min) // size_r + 1
+     print("Dx x Dy x Dz is {} x {} x {}".format(Dx, Dy, Dz))
+
+     # step 4: compute each (downsampled) point's coordinates within the voxel grid
+     h = list()
+     for i in range(len(filtered_points)):
+         hx = np.floor((filtered_points[i][0] - x_min) // size_r)
+         hy = np.floor((filtered_points[i][1] - y_min) // size_r)
+         hz = np.floor((filtered_points[i][2] - z_min) // size_r)
+         h.append(hx + hy * Dx + hz * Dx * Dy)
+
+     # step 5: look up the dictionary same_points_list by h value
+     h = np.array(h)
+     count = 0
+     for i in range(1, len(h)):
+         if h[i] == h[i - 1] and i != (len(h) - 1):
+             continue
+         elif h[i] == h[i - 1] and i == (len(h) - 1):
+             label = filter_labels[count:]
+             key = h[i - 1]
+             count = i
+             # tally the label votes, e.g. classcount: {'A': 2, 'B': 1}
+             classcount = {}
+             for i in range(len(label)):
+                 vote = label[i]
+                 classcount[vote] = classcount.get(vote, 0) + 1
+             # sort by vote count
+             sortedclass = sorted(classcount.items(), key=lambda x: (x[1]), reverse=True)
+             # key = h[i-1]
+             point_index = same_points_list[key]  # list of point indices for this h value
+             for j in range(len(point_index)):
+                 upsample_label.append(sortedclass[0][0])
+                 index = point_index[j]
+                 upsample_point.append(point_cloud[index])
+                 upsample_index.append(index)
+         elif h[i] != h[i - 1] and (i == len(h) - 1):
+             label1 = filter_labels[count:i]
+             key1 = h[i - 1]
+             label2 = filter_labels[i:]
+             key2 = h[i]
+             count = i
+
+             classcount = {}
+             for i in range(len(label1)):
+                 vote = label1[i]
+                 classcount[vote] = classcount.get(vote, 0) + 1
+             sortedclass = sorted(classcount.items(), key=lambda x: (x[1]), reverse=True)
+             # key1 = h[i-1]
+             point_index = same_points_list[key1]
+             for j in range(len(point_index)):
+                 upsample_label.append(sortedclass[0][0])
+                 index = point_index[j]
+                 upsample_point.append(point_cloud[index])
+                 upsample_index.append(index)
+
+             # label2 = filter_labels[i:]
+             classcount = {}
+             for i in range(len(label2)):
+                 vote = label2[i]
+                 classcount[vote] = classcount.get(vote, 0) + 1
+             sortedclass = sorted(classcount.items(), key=lambda x: (x[1]), reverse=True)
+             # key2 = h[i]
+             point_index = same_points_list[key2]
+             for j in range(len(point_index)):
+                 upsample_label.append(sortedclass[0][0])
+                 index = point_index[j]
+                 upsample_point.append(point_cloud[index])
+                 upsample_index.append(index)
+         else:
+             label = filter_labels[count:i]
+             key = h[i - 1]
+             count = i
+             classcount = {}
+             for i in range(len(label)):
+                 vote = label[i]
+                 classcount[vote] = classcount.get(vote, 0) + 1
+             sortedclass = sorted(classcount.items(), key=lambda x: (x[1]), reverse=True)
+             # key = h[i-1]
+             point_index = same_points_list[key]  # list of point indices for this h value
+             for j in range(len(point_index)):
+                 upsample_label.append(sortedclass[0][0])
+                 index = point_index[j]
+                 upsample_point.append(point_cloud[index])
+                 upsample_index.append(index)
+             # count = i
+
+     # restore the original order
+     # print(f'upsample_index[0] is {upsample_index[0]}')
+     # print(f'total length of upsample_index is {len(upsample_index)}')
+
+     # restore the original index order
+     upsample_index = np.array(upsample_index)
+     upsample_index_indice = np.argsort(upsample_index)  # indices that sort upsample_index in ascending order
+     upsample_index_sorted = upsample_index[upsample_index_indice]
+
+     upsample_point = np.array(upsample_point)
+     upsample_label = np.array(upsample_label)
+     # restore the original order of the points and labels
+     upsample_point_sorted = upsample_point[upsample_index_indice]
+     upsample_label_sorted = upsample_label[upsample_index_indice]
+
+     return upsample_point_sorted, upsample_label_sorted
+
+
+ # Upsample with the KNN algorithm
+ def KNN_sklearn_Load_data(voxel_points, center_points, labels):
+     # load the data
+     # x_train, x_test, y_train, y_test = train_test_split(center_points, labels, test_size=0.1)
+     # build the model
+     model = neighbors.KNeighborsClassifier(n_neighbors=3)
+     model.fit(center_points, labels)
+     prediction = model.predict(voxel_points.reshape(1, -1))
+     # meshtopoints_labels = classification_report(voxel_points, prediction)
+     return prediction[0]
+
+
+ # Load the points for KNN upsampling
+ def Load_data(voxel_points, center_points, labels):
+     meshtopoints_labels = []
+     # meshtopoints_labels.append(SVC_sklearn_Load_data(voxel_points[i], center_points, labels))
+     for i in range(0, voxel_points.shape[0]):
+         meshtopoints_labels.append(KNN_sklearn_Load_data(voxel_points[i], center_points, labels))
+     return np.array(meshtopoints_labels)
+
+ # Upsample the triangle-mesh labels back onto the original point cloud
+ def mesh_to_points_main(jaw, pcd_points, center_points, labels):
+     points = pcd_points.copy()
+     # downsample
+     voxel_points, same_points_list = voxel_filter(points, 0.6)
+
+     after_labels = Load_data(voxel_points, center_points, labels)
+
+     upsample_point, upsample_label = voxel_upsample(same_points_list, points, voxel_points, after_labels, 0.6)
+
+     new_pcd = o3d.geometry.PointCloud()
+     new_pcd.points = o3d.utility.Vector3dVector(upsample_point)
+     instances_labels = upsample_label.copy()
+     # '''
+     # o3d.io.write_point_cloud(os.path.join(save_path, 'upsample_' + name + '.pcd'), new_pcd, write_ascii=True)
+     for i in stqdm(range(0, upsample_label.shape[0])):
+         if jaw == 'upper':
+             if (upsample_label[i] >= 1) and (upsample_label[i] <= 8):
+                 upsample_label[i] = upsample_label[i] + 10
+             elif (upsample_label[i] >= 9) and (upsample_label[i] <= 16):
+                 upsample_label[i] = upsample_label[i] + 12
+         else:
+             if (upsample_label[i] >= 1) and (upsample_label[i] <= 8):
+                 upsample_label[i] = upsample_label[i] + 30
+             elif (upsample_label[i] >= 9) and (upsample_label[i] <= 16):
+                 upsample_label[i] = upsample_label[i] + 32
+     remove_outlier_main(jaw, pcd_points, upsample_label, instances_labels)
+
+
+ # Convert the raw point cloud into a triangle mesh
+ def mesh_grid(pcd_points):
+     new_pcd, _ = voxel_filter(pcd_points, 0.6)
+     # the point cloud needs normals
+
+     # estimate radius for rolling ball
+     pcd_new = o3d.geometry.PointCloud()
+     pcd_new.points = o3d.utility.Vector3dVector(new_pcd)
+     pcd_new.estimate_normals()
+     distances = pcd_new.compute_nearest_neighbor_distance()
+     avg_dist = np.mean(distances)
+     radius = 6 * avg_dist
+     mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(
+         pcd_new,
+         o3d.utility.DoubleVector([radius, radius * 2]))
+     # o3d.io.write_triangle_mesh("./tooth date/test.ply", mesh)
+
+     return mesh
+
+
+ # Read the contents of an OBJ file
+ def read_obj(obj_path):
+     jaw = None
+     with open(obj_path) as file:
+         points = []
+         faces = []
+         while 1:
+             line = file.readline()
+             if not line:
+                 break
+             strs = line.split(" ")
+             if strs[0] == "v":
+                 points.append((float(strs[1]), float(strs[2]), float(strs[3])))
+             elif strs[0] == "f":
+                 faces.append((int(strs[1]), int(strs[2]), int(strs[3])))
+             elif strs[1][0:5] == 'lower':
+                 jaw = 'lower'
+             elif strs[1][0:5] == 'upper':
+                 jaw = 'upper'
+
+     points = np.array(points)
+     faces = np.array(faces)
+
+     if jaw is None:
+         raise ValueError("Jaw type not found in OBJ file")
+
+     return points, faces, jaw
+
+
+ # Convert an OBJ file to point cloud data
+ def obj2pcd(obj_path):
+     if os.path.exists(obj_path):
+         print('yes')
+     points, _, jaw = read_obj(obj_path)
+     pcd_list = []
+     num_points = np.shape(points)[0]
+     for i in range(num_points):
+         new_line = str(points[i, 0]) + ' ' + str(points[i, 1]) + ' ' + str(points[i, 2])
+         pcd_list.append(new_line.split())
+
+     pcd_points = np.array(pcd_list).astype(np.float64)
+     return pcd_points, jaw
+
+
+ def segmentation_main(obj_path):
+     upsampling_method = 'KNN'
+
+     model_path = 'Mesh_Segementation_MeshSegNet_17_classes_60samples_best.tar'
+     num_classes = 17
+     num_channels = 15
+
+     # set model
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     model = MeshSegNet(num_classes=num_classes, num_channels=num_channels).to(device, dtype=torch.float)
+
+     # load trained model
+     # checkpoint = torch.load(os.path.join(model_path, model_name), map_location='cpu')
+     checkpoint = torch.load(model_path, map_location='cpu')
+     model.load_state_dict(checkpoint['model_state_dict'])
+     del checkpoint
+     model = model.to(device, dtype=torch.float)
+
+     # cudnn
+     torch.backends.cudnn.benchmark = True
+     torch.backends.cudnn.enabled = True
+
+     # Predicting
+     model.eval()
+     with torch.no_grad():
+         pcd_points, jaw = obj2pcd(obj_path)
+         mesh = mesh_grid(pcd_points)
+
+         # move mesh to origin
+         with st.spinner("Patience please, AI at work. Grab a coffee while you wait ☕."):
+             vertices_points = np.asarray(mesh.vertices)
+             triangles_points = np.asarray(mesh.triangles)
+             N = triangles_points.shape[0]
+             cells = np.zeros((triangles_points.shape[0], 9))
+             cells = vertices_points[triangles_points].reshape(triangles_points.shape[0], 9)
+
+             mean_cell_centers = mesh.get_center()
+             cells[:, 0:3] -= mean_cell_centers[0:3]
+             cells[:, 3:6] -= mean_cell_centers[0:3]
+             cells[:, 6:9] -= mean_cell_centers[0:3]
+
+             v1 = np.zeros([triangles_points.shape[0], 3], dtype='float32')
+             v2 = np.zeros([triangles_points.shape[0], 3], dtype='float32')
+             v1[:, 0] = cells[:, 0] - cells[:, 3]
+             v1[:, 1] = cells[:, 1] - cells[:, 4]
+             v1[:, 2] = cells[:, 2] - cells[:, 5]
+             v2[:, 0] = cells[:, 3] - cells[:, 6]
+             v2[:, 1] = cells[:, 4] - cells[:, 7]
+             v2[:, 2] = cells[:, 5] - cells[:, 8]
+             mesh_normals = np.cross(v1, v2)
+             mesh_normal_length = np.linalg.norm(mesh_normals, axis=1)
+             mesh_normals[:, 0] /= mesh_normal_length[:]
+             mesh_normals[:, 1] /= mesh_normal_length[:]
+             mesh_normals[:, 2] /= mesh_normal_length[:]
+
+             # prepare input
+             points = vertices_points.copy()
+             points[:, 0:3] -= mean_cell_centers[0:3]
+             normals = np.nan_to_num(mesh_normals).copy()
+             barycenters = np.zeros((triangles_points.shape[0], 3))
+             s = np.sum(vertices_points[triangles_points], 1)
+             barycenters = 1 / 3 * s
+             center_points = barycenters.copy()
+             barycenters -= mean_cell_centers[0:3]
+
+             # normalized data
+             maxs = points.max(axis=0)
+             mins = points.min(axis=0)
+             means = points.mean(axis=0)
+             stds = points.std(axis=0)
+             nmeans = normals.mean(axis=0)
+             nstds = normals.std(axis=0)
+
+             for i in range(3):
+                 cells[:, i] = (cells[:, i] - means[i]) / stds[i]  # point 1
+                 cells[:, i + 3] = (cells[:, i + 3] - means[i]) / stds[i]  # point 2
+                 cells[:, i + 6] = (cells[:, i + 6] - means[i]) / stds[i]  # point 3
+                 barycenters[:, i] = (barycenters[:, i] - mins[i]) / (maxs[i] - mins[i])
+                 normals[:, i] = (normals[:, i] - nmeans[i]) / nstds[i]
+
+             X = np.column_stack((cells, barycenters, normals))
+
+             # computing A_S and A_L
+             A_S = np.zeros([X.shape[0], X.shape[0]], dtype='float32')
+             A_L = np.zeros([X.shape[0], X.shape[0]], dtype='float32')
+             D = distance_matrix(X[:, 9:12], X[:, 9:12])
+             A_S[D < 0.1] = 1.0
+             A_S = A_S / np.dot(np.sum(A_S, axis=1, keepdims=True), np.ones((1, X.shape[0])))
+
+             A_L[D < 0.2] = 1.0
+             A_L = A_L / np.dot(np.sum(A_L, axis=1, keepdims=True), np.ones((1, X.shape[0])))
+
+             # numpy -> torch.tensor
+             X = X.transpose(1, 0)
+             X = X.reshape([1, X.shape[0], X.shape[1]])
+             X = torch.from_numpy(X).to(device, dtype=torch.float)
+             A_S = A_S.reshape([1, A_S.shape[0], A_S.shape[1]])
+             A_L = A_L.reshape([1, A_L.shape[0], A_L.shape[1]])
+             A_S = torch.from_numpy(A_S).to(device, dtype=torch.float)
+             A_L = torch.from_numpy(A_L).to(device, dtype=torch.float)
+
+             tensor_prob_output = model(X, A_S, A_L).to(device, dtype=torch.float)
+             patch_prob_output = tensor_prob_output.cpu().numpy()
+
+         # refinement
+         with st.spinner("Refining..."):
+             round_factor = 100
+             patch_prob_output[patch_prob_output < 1.0e-6] = 1.0e-6
+
+             # unaries
+             unaries = -round_factor * np.log10(patch_prob_output)
+             unaries = unaries.astype(np.int32)
+             unaries = unaries.reshape(-1, num_classes)
+
+             # pairwise
+             pairwise = (1 - np.eye(num_classes, dtype=np.int32))
+
+             cells = cells.copy()
+
+             cell_ids = np.asarray(triangles_points)
+             lambda_c = 20
+             edges = np.empty([1, 3], order='C')
+             for i_node in stqdm(range(cells.shape[0])):
+                 # Find neighbors
+                 nei = np.sum(np.isin(cell_ids, cell_ids[i_node, :]), axis=1)
+                 nei_id = np.where(nei == 2)
+                 for i_nei in nei_id[0][:]:
+                     if i_node < i_nei:
+                         cos_theta = np.dot(normals[i_node, 0:3], normals[i_nei, 0:3]) / np.linalg.norm(
+                             normals[i_node, 0:3]) / np.linalg.norm(normals[i_nei, 0:3])
+                         if cos_theta >= 1.0:
+                             cos_theta = 0.9999
+                         theta = np.arccos(cos_theta)
+                         phi = np.linalg.norm(barycenters[i_node, :] - barycenters[i_nei, :])
+                         if theta > np.pi / 2.0:
+                             edges = np.concatenate(
+                                 (edges, np.array([i_node, i_nei, -np.log10(theta / np.pi) * phi]).reshape(1, 3)), axis=0)
+                         else:
+                             beta = 1 + np.linalg.norm(np.dot(normals[i_node, 0:3], normals[i_nei, 0:3]))
+                             edges = np.concatenate(
+                                 (edges, np.array([i_node, i_nei, -beta * np.log10(theta / np.pi) * phi]).reshape(1, 3)),
+                                 axis=0)
+             edges = np.delete(edges, 0, 0)
+             edges[:, 2] *= lambda_c * round_factor
+             edges = edges.astype(np.int32)
+
+             refine_labels = cut_from_graph(edges, unaries, pairwise)
+             refine_labels = refine_labels.reshape([-1, 1])
+
+             predicted_labels_3 = refine_labels.reshape(refine_labels.shape[0])
+             mesh_to_points_main(jaw, pcd_points, center_points, predicted_labels_3)
+
+         import pyvista as pv
+         with st.spinner("Rendering..."):
+             # Load the .obj file
+             mesh = pv.read('file.obj')
+
+             # Load the JSON file
+             with open('dental-labels4.json', 'r') as file:
+                 labels_data = json.load(file)
+
+             # Assuming labels_data['labels'] is a list of labels
+             labels = labels_data['labels']
+
+             # Make sure the number of labels matches the number of vertices or faces
+             assert len(labels) == mesh.n_points or len(labels) == mesh.n_cells
+
+             # If labels correspond to vertices
+             if len(labels) == mesh.n_points:
+                 mesh.point_data['Labels'] = labels
+             # If labels correspond to faces
+             elif len(labels) == mesh.n_cells:
+                 mesh.cell_data['Labels'] = labels
+
+             # Create a pyvista plotter
+             plotter = pv.Plotter()
+
+             cmap = plt.cm.get_cmap('jet', 27)  # Using a colormap with sufficient distinct colors
+
+             colors = cmap(np.linspace(0, 1, 27))  # Generate colors
+
+             # Convert colors to a format acceptable by PyVista
+             colormap = mcolors.ListedColormap(colors)
+
+             # Add the mesh to the plotter with labels as a scalar field
+             # plotter.add_mesh(mesh, scalars='Labels', show_scalar_bar=True, cmap='jet')
+             plotter.add_mesh(mesh, scalars='Labels', show_scalar_bar=True, cmap=colormap, clim=[0, 27])
+
+             # Show the plot
+             # plotter.show()
+             ## Send to streamlit
+             with st.expander("**View Segmentation Result** - ", expanded=False):
+                 stpyvista(plotter)
+
  # Configure Streamlit page
- st.set_page_config(page_title="Teeth Segmentation", page_icon="")
+ st.set_page_config(page_title="Teeth Segmentation", page_icon="🦷")

- class Intro(TeethApp):
+ class Segment(TeethApp):
      def __init__(self):
          TeethApp.__init__(self)
          self.build_app()

      def build_app(self):
-         st.title("AI-assited Tooth Segmentation")
-         st.markdown("This app automatically segments intra-oral scans of teeth using machine learning.")
-         st.markdown("Head to the 'Segment' tab to try it out!")
-         st.markdown("**Example:**")
-         st.image("illu.png")
+
+         st.title("Segment Intra-oral Scans")
+         st.markdown("Identify and segment teeth. Segmentation is performed using MeshSegNet, a deep learning model trained on both upper and lower jaws.")
+
+         inputs = st.radio(
+             "Select scan for segmentation:",
+             ("Upload Scan", "Example Scan"),
+         )
+         import pyvista as pv
+         if inputs == "Example Scan":
+             st.markdown("Expected time per prediction: 7-10 min.")
+             mesh = pv.read("ZOUIF2W4_upper.obj")
+             plotter = pv.Plotter()
+
+             # Add the mesh to the plotter
+             plotter.add_mesh(mesh, color='white', show_edges=False)
+             segment = st.button(
+                 "✔️ Submit",
+                 help="Submit 3D scan for segmentation",
+             )
+             with st.expander("View Scan", expanded=False):
+                 stpyvista(plotter)
+
+             if segment:
+                 segmentation_main("ZOUIF2W4_upper.obj")
+
+         elif inputs == "Upload Scan":
+             file = st.file_uploader("Please upload an OBJ Object file", type=["OBJ"])
+             st.markdown("Expected time per prediction: 7-10 min.")
+             if file is not None:
+                 # save the uploaded file to disk
+                 with open("file.obj", "wb") as buffer:
+                     shutil.copyfileobj(file, buffer)
+                 # copy the data
+                 obj_path = "file.obj"
+
+                 mesh = pv.read(obj_path)
+                 plotter = pv.Plotter()
+
+                 # Add the mesh to the plotter
+                 plotter.add_mesh(mesh, color='white', show_edges=False)
+                 segment = st.button(
+                     "✔️ Submit",
+                     help="Submit 3D scan for segmentation",
+                 )
+                 with st.expander("View Scan", expanded=False):
+                     stpyvista(plotter)
+
+                 if segment:
+                     segmentation_main(obj_path)
+

  if __name__ == "__main__":
-     app = Intro()
+     app = Segment()