XiaoyangChen committed
Commit 676b3ba · 1 Parent(s): c37d123

Add application file

Files changed (1)
  1. app.py +230 -0
app.py ADDED
import argparse
from transformers import AutoTokenizer  # currently unused; the model loading at the bottom is commented out
import torch
import os
import numpy as np

# Additional imports for the gradio demo
import gradio as gr
import open3d as o3d
import plotly.graph_objects as go
import time

import logging


def farthest_point_sample(point, npoint):
    """
    Input:
        point: pointcloud data, [N, D]
        npoint: number of samples
    Return:
        point: sampled pointcloud, [npoint, D]
    """
    N, D = point.shape
    xyz = point[:, :3]
    centroids = np.zeros((npoint,))
    distance = np.ones((N,)) * 1e10
    farthest = np.random.randint(0, N)
    for i in range(npoint):
        centroids[i] = farthest
        centroid = xyz[farthest, :]
        dist = np.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = np.argmax(distance, -1)
    point = point[centroids.astype(np.int32)]
    return point


def pc_norm(pc):
    """ pc: NxC, return NxC """
    xyz = pc[:, :3]
    other_feature = pc[:, 3:]

    centroid = np.mean(xyz, axis=0)
    xyz = xyz - centroid
    m = np.max(np.sqrt(np.sum(xyz ** 2, axis=1)))
    xyz = xyz / m

    pc = np.concatenate((xyz, other_feature), axis=1)
    return pc


def change_input_method(input_method):
    if input_method == 'File':
        result = [gr.update(visible=True),
                  gr.update(visible=False)]
    elif input_method == 'Object ID':
        result = [gr.update(visible=False),
                  gr.update(visible=True)]
    return result


def start_conversation(args):
    print("[INFO] Starting conversation...")
    logging.warning("Starting conversation...")
    while True:
        print("-" * 80)
        logging.warning("-" * 80)

        # Reset the conversation template
        # conv.reset()

        def confirm_point_cloud(point_cloud_input, answer_time):
            objects = None
            data = None
            # object_id_input = object_id_input.strip()

            print("%" * 80)
            logging.warning("%" * 80)

            file = point_cloud_input.name
            print(f"Uploading file: {file}.")
            logging.warning(f"Uploading file: {file}.")
            print("%" * 80)
            logging.warning("%" * 80)

            # Files whose names contain "no_color" are treated as colorless point clouds
            manual_no_color = "no_color" in file

            try:
                if '.ply' in file:
                    pcd = o3d.io.read_point_cloud(file)
                    points = np.asarray(pcd.points)  # xyz
                    colors = np.asarray(pcd.colors)  # rgb, if available
                    # * open3d returns an empty array if the file has no colors
                    if colors.size == 0:
                        colors = None
                elif '.npy' in file:
                    data = np.load(file)
                    if data.shape[1] >= 3:
                        points = data[:, :3]
                    else:
                        raise ValueError("Input array has the wrong shape. Expected: [N, 3]. Got: {}.".format(data.shape))
                    colors = None if data.shape[1] < 6 else data[:, 3:6]
                else:
                    raise ValueError("Unsupported data format.")
            except Exception as e:
                print(f"[ERROR] {e}")
                logging.warning(f"[ERROR] {e}")
                return None, answer_time, None

            if manual_no_color:
                colors = None

            if colors is not None:
                # * if colors in range (0-1)
                if np.max(colors) <= 1:
                    color_data = np.multiply(colors, 255).astype(int)  # Convert float values (0-1) to integers (0-255)
                # * if colors in range (0-255)
                elif np.max(colors) <= 255:
                    color_data = colors.astype(int)
            else:
                color_data = np.zeros_like(points).astype(int)  # Default to black if RGB information is not available
            colors = color_data.astype(np.float32) / 255  # model input is (0-1)

            # Convert the RGB color data to a list of RGB strings in the format 'rgb(r, g, b)'
            color_strings = ['rgb({},{},{})'.format(r, g, b) for r, g, b in color_data]

            fig = go.Figure(
                data=[
                    go.Scatter3d(
                        x=points[:, 0], y=points[:, 1], z=points[:, 2],
                        mode='markers',
                        marker=dict(
                            size=1.2,
                            color=color_strings,  # Use the list of RGB strings for the marker colors
                        )
                    )
                ],
                layout=dict(
                    scene=dict(
                        xaxis=dict(visible=False),
                        yaxis=dict(visible=False),
                        zaxis=dict(visible=False)
                    ),
                    paper_bgcolor='rgb(255,255,255)'  # White background
                ),
            )

            # Downsample, normalize, and batch the point cloud for the model
            points = np.concatenate((points, colors), axis=1)
            if 8192 < points.shape[0]:
                points = farthest_point_sample(points, 8192)
            point_clouds = pc_norm(points)
            point_clouds = torch.from_numpy(point_clouds).unsqueeze_(0).to(torch.float32)

            answer_time = 0

            return fig, answer_time, point_clouds

        with gr.Blocks() as demo:
            answer_time = gr.State(value=0)
            point_clouds = gr.State(value=None)
            # conv_state = gr.State(value=conv.copy())
            gr.Markdown(
                """
                # PointCloud Visualization 👀
                """
            )
            with gr.Row():
                with gr.Column():
                    point_cloud_input = gr.File(visible=True, label="Upload Point Cloud File (PLY, NPY)")
                    output = gr.Plot()
                    btn = gr.Button(value="Confirm Point Cloud")

            btn.click(confirm_point_cloud, inputs=[point_cloud_input, answer_time], outputs=[output, answer_time, point_clouds])
            # input_choice.change(change_input_method, input_choice, [point_cloud_input, object_id_input])
            # run_button.click(user, [text_input, chatbot], [text_input, chatbot], queue=False).then(answer_generate, [chatbot, answer_time, point_clouds, conv_state], chatbot).then(lambda x: x + 1, answer_time, answer_time)

        demo.queue()
        demo.launch(server_port=args.port, share=True)  # launch() blocks here, so the loop body runs only once


if __name__ == "__main__":
    # ! To release this demo in public, make sure to start it in a place where no important data is stored.
    # ! Please check 1. the launch dir 2. the tmp dir (GRADIO_TEMP_DIR)
    # ! refer to https://www.gradio.app/guides/sharing-your-app#security-and-file-access
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str,
                        default="RunsenXu/PointLLM_7B_v1.2")

    parser.add_argument("--data_path", type=str, default="data/objaverse_data", required=False)
    parser.add_argument("--pointnum", type=int, default=8192)

    parser.add_argument("--log_file", type=str, default="serving_workdirs/serving_log.txt")
    parser.add_argument("--tmp_dir", type=str, default="serving_workdirs/tmp")

    # For gradio
    parser.add_argument("--port", type=int, default=7810)

    args = parser.parse_args()

    # * make serving dirs
    os.makedirs(os.path.dirname(args.log_file), exist_ok=True)
    os.makedirs(args.tmp_dir, exist_ok=True)

    # * add the current time to the log name
    args.log_file = args.log_file.replace(".txt", f"_{time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())}.txt")

    logging.basicConfig(
        filename=args.log_file,
        level=logging.WARNING,  # * gradio's default level is info, so use warning
        format='%(asctime)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    logging.warning("-----New Run-----")
    logging.warning(f"args: {args}")

    print("-----New Run-----")
    print(f"[INFO] Args: {args}")

    # * set env variable GRADIO_TEMP_DIR to args.tmp_dir
    os.environ["GRADIO_TEMP_DIR"] = args.tmp_dir

    # model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv = init_model(args)
    start_conversation(args)
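
For reference, a minimal way to try the demo (a sketch, assuming the dependencies imported above are installed: gradio, open3d, plotly, torch, numpy, transformers): run python app.py (the port defaults to 7810, or pass --port) and upload a .ply or .npy file. The snippet below generates a compatible .npy input; the file name example_cloud.npy and the random points are purely illustrative.

import numpy as np

# [N, 6] array: columns 0-2 are xyz, columns 3-5 are rgb in 0-255 (the demo also accepts 0-1)
xyz = np.random.rand(4096, 3).astype(np.float32)
rgb = np.random.randint(0, 256, size=(4096, 3)).astype(np.float32)
np.save("example_cloud.npy", np.concatenate([xyz, rgb], axis=1))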