winfred2027 commited on
Commit
daed59e
·
verified ·
1 Parent(s): b88294b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -1
app.py CHANGED
@@ -7,7 +7,8 @@ import openshape
7
  import transformers
8
  from PIL import Image
9
  from huggingface_hub import HfFolder, snapshot_download
10
- from demo_support import retrieval, utils
 
11
 
12
  @st.cache_resource
13
  def load_openclip():
@@ -99,6 +100,14 @@ def demo_retrieval():
99
  col2 = utils.render_pc(pc)
100
  ref_dev = next(model_g14.parameters()).device
101
  enc = model_g14(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev)).cpu()
 
 
 
 
 
 
 
 
102
 
103
  prog.progress(0.7, "Running Retrieval")
104
  retrieval_results(retrieval.retrieve(enc, k, sim_th, filter_fn))
 
7
  import transformers
8
  from PIL import Image
9
  from huggingface_hub import HfFolder, snapshot_download
10
+ from demo_support import retrieval, utils, lvis
11
+ from collections import OrderedDict
12
 
13
  @st.cache_resource
14
  def load_openclip():
 
100
  col2 = utils.render_pc(pc)
101
  ref_dev = next(model_g14.parameters()).device
102
  enc = model_g14(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev)).cpu()
103
+
104
+ sim = torch.matmul(F.normalize(lvis.feats, dim=-1), F.normalize(enc, dim=-1).squeeze())
105
+ argsort = torch.argsort(sim, descending=True)
106
+ pred = OrderedDict((lvis.categories[i], sim[i]) for i in argsort if i < len(lvis.categories))
107
+ with col2:
108
+ for i, (cat, sim) in zip(range(5), pred.items()):
109
+ st.text(cat)
110
+ st.caption("Similarity %.4f" % sim)
111
 
112
  prog.progress(0.7, "Running Retrieval")
113
  retrieval_results(retrieval.retrieve(enc, k, sim_th, filter_fn))