winfred2027 commited on
Commit
104f14f
·
verified ·
1 Parent(s): 3f7b866

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -29
app.py CHANGED
@@ -66,17 +66,6 @@ def retrieval_results(results):
66
 
67
 
68
 
69
- def demo_classification():
70
- with st.form("clsform"):
71
- #load_data = misc_utils.input_3d_shape('cls')
72
- cats = st.text_input("Custom Categories (64 max, separated with comma)")
73
- cats = [a.strip() for a in cats.split(',')]
74
- if len(cats) > 64:
75
- st.error('Maximum 64 custom categories supported in the demo')
76
- return
77
- lvis_run = st.form_submit_button("Run Classification on LVIS Categories")
78
- custom_run = st.form_submit_button("Run Classification on Custom Categories")
79
-
80
  def demo_captioning():
81
  with st.form("capform"):
82
  cond_scale = st.slider('Conditioning Scale', 0.0, 4.0, 2.0, 0.1, key='capcondscl')
@@ -146,6 +135,43 @@ def demo_retrieval():
146
  prog.progress(1.0, "Idle")
147
 
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  def retrieval_pc(load_data, k, sim_th, filter_fn):
150
  pc = load_data(prog)
151
  prog.progress(0.49, "Computing Embeddings")
@@ -208,21 +234,21 @@ try:
208
  'Choose the source of categories',
209
  ("LVIS Categories", "Custom Categories")
210
  )
211
- pc = st.sidebar.text_input("Input pc", key='rtextinput')
212
  if cls_mode == "LVIS Categories":
 
 
213
  if st.sidebar.button("submit"):
214
- st.title("Classification with LVIS Categories")
215
- prog = st.progress(0.0, "Idle")
216
-
217
  elif cls_mode == "Custom Categories":
 
 
218
  cats = st.sidebar.text_input("Custom Categories (64 max, separated with comma)")
219
  cats = [a.strip() for a in cats.split(',')]
220
  if len(cats) > 64:
221
  st.error('Maximum 64 custom categories supported in the demo')
222
  if st.sidebar.button("submit"):
223
- st.title("Classification with Custom Categories")
224
- prog = st.progress(0.0, "Idle")
225
-
226
  elif task == "Cross-modal retrieval":
227
  input_mode = st.sidebar.selectbox(
228
  'Choose an input modality',
@@ -231,22 +257,22 @@ try:
231
  k = st.sidebar.slider("Number of items to retrieve", 1, 100, 16, key='rnum')
232
  sim_th, filter_fn = retrieval_filter_expand()
233
  if input_mode == "Point Cloud":
 
 
234
  load_data = utils.input_3d_shape('rpcinput')
235
  if st.sidebar.button("submit"):
236
- st.title("Retrieval with Point Cloud")
237
- prog = st.progress(0.0, "Idle")
238
  retrieval_pc(load_data, k, sim_th, filter_fn)
239
  elif input_mode == "Image":
 
 
240
  pic = st.sidebar.file_uploader("Upload an Image", key='rimageinput')
241
  if st.sidebar.button("submit"):
242
- st.title("Retrieval with Image")
243
- prog = st.progress(0.0, "Idle")
244
  retrieval_img(pic, k, sim_th, filter_fn)
245
  elif input_mode == "Text":
 
 
246
  text = st.sidebar.text_input("Input Text", key='rtextinput')
247
  if st.sidebar.button("submit"):
248
- st.title("Retrieval with Text")
249
- prog = st.progress(0.0, "Idle")
250
  retrieval_text(text, k, sim_th, filter_fn)
251
  elif task == "Cross-modal generation":
252
  generation_mode = st.sidebar.selectbox(
@@ -255,14 +281,15 @@ try:
255
  )
256
  pc = st.sidebar.text_input("Input pc", key='rtextinput')
257
  if generation_mode == "PointCloud-to-Image":
 
 
258
  if st.sidebar.button("submit"):
259
- st.title("Image Generation")
260
- prog = st.progress(0.0, "Idle")
261
-
262
  elif generation_mode == "PointCloud-to-Text":
 
 
263
  if st.sidebar.button("submit"):
264
- st.title("Text Generation")
265
- prog = st.progress(0.0, "Idle")
266
 
267
  except Exception:
268
  import traceback
 
66
 
67
 
68
 
 
 
 
 
 
 
 
 
 
 
 
69
  def demo_captioning():
70
  with st.form("capform"):
71
  cond_scale = st.slider('Conditioning Scale', 0.0, 4.0, 2.0, 0.1, key='capcondscl')
 
135
  prog.progress(1.0, "Idle")
136
 
137
 
138
+ def classification_lvis(load_data):
139
+ pc = load_data(prog)
140
+ col2 = utils.render_pc(pc)
141
+ prog.progress(0.5, "Running Classification")
142
+ ref_dev = next(model_g14.parameters()).device
143
+ enc = model_g14(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev)).cpu()
144
+
145
+ sim = torch.matmul(torch.nn.functional.normalize(lvis.feats, dim=-1), torch.nn.functional.normalize(enc, dim=-1).squeeze())
146
+ argsort = torch.argsort(sim, descending=True)
147
+ pred = OrderedDict((lvis.categories[i], sim[i]) for i in argsort if i < len(lvis.categories))
148
+ with col2:
149
+ for i, (cat, sim) in zip(range(5), pred.items()):
150
+ st.text(cat)
151
+ st.caption("Similarity %.4f" % sim)
152
+ prog.progress(1.0, "Idle")
153
+
154
+ def classification_custom(load_data, cats):
155
+ pc = load_data(prog)
156
+ col2 = utils.render_pc(pc)
157
+ prog.progress(0.5, "Computing Category Embeddings")
158
+ device = clip_model.device
159
+ tn = clip_prep(text=cats, return_tensors='pt', truncation=True, max_length=76, padding=True).to(device)
160
+ feats = clip_model.get_text_features(**tn).float().cpu()
161
+
162
+ prog.progress(0.5, "Running Classification")
163
+ ref_dev = next(model_g14.parameters()).device
164
+ enc = model_g14(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev)).cpu()
165
+ sim = torch.matmul(torch.nn.functional.normalize(feats, dim=-1), torch.nn.functional.normalize(enc, dim=-1).squeeze())
166
+ argsort = torch.argsort(sim, descending=True)
167
+ pred = OrderedDict((cats[i], sim[i]) for i in argsort if i < len(cats))
168
+ with col2:
169
+ for i, (cat, sim) in zip(range(5), pred.items()):
170
+ st.text(cat)
171
+ st.caption("Similarity %.4f" % sim)
172
+ prog.progress(1.0, "Idle")
173
+
174
+
175
  def retrieval_pc(load_data, k, sim_th, filter_fn):
176
  pc = load_data(prog)
177
  prog.progress(0.49, "Computing Embeddings")
 
234
  'Choose the source of categories',
235
  ("LVIS Categories", "Custom Categories")
236
  )
237
+ load_data = utils.input_3d_shape('rpcinput')
238
  if cls_mode == "LVIS Categories":
239
+ st.title("Classification with LVIS Categories")
240
+ prog = st.progress(0.0, "Idle")
241
  if st.sidebar.button("submit"):
242
+ classification_lvis(load_data)
 
 
243
  elif cls_mode == "Custom Categories":
244
+ st.title("Classification with Custom Categories")
245
+ prog = st.progress(0.0, "Idle")
246
  cats = st.sidebar.text_input("Custom Categories (64 max, separated with comma)")
247
  cats = [a.strip() for a in cats.split(',')]
248
  if len(cats) > 64:
249
  st.error('Maximum 64 custom categories supported in the demo')
250
  if st.sidebar.button("submit"):
251
+ classification_custom(load_data, cats)
 
 
252
  elif task == "Cross-modal retrieval":
253
  input_mode = st.sidebar.selectbox(
254
  'Choose an input modality',
 
257
  k = st.sidebar.slider("Number of items to retrieve", 1, 100, 16, key='rnum')
258
  sim_th, filter_fn = retrieval_filter_expand()
259
  if input_mode == "Point Cloud":
260
+ st.title("Retrieval with Point Cloud")
261
+ prog = st.progress(0.0, "Idle")
262
  load_data = utils.input_3d_shape('rpcinput')
263
  if st.sidebar.button("submit"):
 
 
264
  retrieval_pc(load_data, k, sim_th, filter_fn)
265
  elif input_mode == "Image":
266
+ st.title("Retrieval with Image")
267
+ prog = st.progress(0.0, "Idle")
268
  pic = st.sidebar.file_uploader("Upload an Image", key='rimageinput')
269
  if st.sidebar.button("submit"):
 
 
270
  retrieval_img(pic, k, sim_th, filter_fn)
271
  elif input_mode == "Text":
272
+ st.title("Retrieval with Text")
273
+ prog = st.progress(0.0, "Idle")
274
  text = st.sidebar.text_input("Input Text", key='rtextinput')
275
  if st.sidebar.button("submit"):
 
 
276
  retrieval_text(text, k, sim_th, filter_fn)
277
  elif task == "Cross-modal generation":
278
  generation_mode = st.sidebar.selectbox(
 
281
  )
282
  pc = st.sidebar.text_input("Input pc", key='rtextinput')
283
  if generation_mode == "PointCloud-to-Image":
284
+ st.title("Image Generation")
285
+ prog = st.progress(0.0, "Idle")
286
  if st.sidebar.button("submit"):
287
+ pc = st.text_input("Input pc", key='rtextinput')
 
 
288
  elif generation_mode == "PointCloud-to-Text":
289
+ st.title("Text Generation")
290
+ prog = st.progress(0.0, "Idle")
291
  if st.sidebar.button("submit"):
292
+ pc = st.text_input("Input pc", key='rtextinput')
 
293
 
294
  except Exception:
295
  import traceback