winfred2027 commited on
Commit
55dcb09
·
verified ·
1 Parent(s): d93c0d6

Upload retrieval.py

Browse files
Files changed (1) hide show
  1. demo_support/retrieval.py +50 -0
demo_support/retrieval.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from huggingface_hub import hf_hub_download
5
+
6
+
7
+ meta = json.load(
8
+ open(hf_hub_download("OpenShape/openshape-objaverse-embeddings", "objaverse_meta.json", repo_type='dataset'))
9
+ )
10
+ # {
11
+ # "u": "94db219c315742909fee67deeeacae15",
12
+ # "name": "knife",
13
+ # "like": 0,
14
+ # "view": 35,
15
+ # "anims": 0,
16
+ # "tags": ["game-ready"],
17
+ # "cats": ["weapons-military"],
18
+ # "img": "https://media.sketchfab.com/models/94db219c315742909fee67deeeacae15/thumbnails/c0bbbd475d264ff2a92972f5115564ee/0cd28a130ebd4d9c9ef73190f24d9a42.jpeg",
19
+ # "desc": "",
20
+ # "faces": 1724,
21
+ # "size": 11955,
22
+ # "lic": "by",
23
+ # "glb": "glbs/000-000/94db219c315742909fee67deeeacae15.glb"
24
+ # }
25
+ meta = {x['u']: x for x in meta['entries']}
26
+ deser = torch.load(
27
+ hf_hub_download("OpenShape/openshape-objaverse-embeddings", "objaverse.pt", repo_type='dataset'), map_location='cpu'
28
+ )
29
+ us = deser['us']
30
+ feats = deser['feats']
31
+
32
+
33
+ def retrieve(embedding, top, sim_th=0.0, filter_fn=None):
34
+ sims = []
35
+ embedding = F.normalize(embedding.detach().cpu(), dim=-1).squeeze()
36
+ for chunk in torch.split(feats, 10240):
37
+ sims.append(embedding @ F.normalize(chunk.float(), dim=-1).T)
38
+ sims = torch.cat(sims)
39
+ sims, idx = torch.sort(sims, descending=True)
40
+ sim_mask = sims > sim_th
41
+ sims = sims[sim_mask]
42
+ idx = idx[sim_mask]
43
+ results = []
44
+ for i, sim in zip(idx, sims):
45
+ if us[i] in meta:
46
+ if filter_fn is None or filter_fn(meta[us[i]]):
47
+ results.append(dict(meta[us[i]], sim=sim))
48
+ if len(results) >= top:
49
+ break
50
+ return results