File size: 1,878 Bytes
55dcb09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
483cfa2
55dcb09
 
 
483cfa2
 
 
 
55dcb09
 
 
 
 
 
 
77e656b
 
55dcb09
1499e69
 
55dcb09
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import json
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download


meta = json.load(
    open(hf_hub_download("OpenShape/openshape-objaverse-embeddings", "objaverse_meta.json", repo_type='dataset'))
)
# {
#     "u": "94db219c315742909fee67deeeacae15",
#     "name": "knife",
#     "like": 0,
#     "view": 35,
#     "anims": 0,
#     "tags": ["game-ready"],
#     "cats": ["weapons-military"],
#     "img": "https://media.sketchfab.com/models/94db219c315742909fee67deeeacae15/thumbnails/c0bbbd475d264ff2a92972f5115564ee/0cd28a130ebd4d9c9ef73190f24d9a42.jpeg",
#     "desc": "",
#     "faces": 1724,
#     "size": 11955,
#     "lic": "by",
#     "glb": "glbs/000-000/94db219c315742909fee67deeeacae15.glb"
# }
meta = {x['u']: x for x in meta['entries']}
"""
deser = torch.load(
    hf_hub_download("OpenShape/openshape-objaverse-embeddings", "objaverse.pt", repo_type='dataset'), map_location='cpu'
)
"""
deser = torch.load(
    hf_hub_download("TripletMix/tripletmix-objaverse-embeddings", "objaverse.pt", repo_type='dataset'), map_location='cpu'
)
us = deser['us']
feats = deser['feats']


def retrieve(embedding, top, sim_th=0.0, filter_fn=None):
    sims = []
    embedding = F.normalize(embedding.detach().cpu(), dim=-1).squeeze()
    #for chunk in torch.split(feats, 10240):
    for chunk in feats:
        sims.append(embedding @ F.normalize(chunk.float(), dim=-1).T)
    #sims = torch.cat(sims)
    sims = torch.tensor(sims) 
    sims, idx = torch.sort(sims, descending=True)
    sim_mask = sims > sim_th
    sims = sims[sim_mask]
    idx = idx[sim_mask]
    results = []
    for i, sim in zip(idx, sims):
        if us[i] in meta:
            if filter_fn is None or filter_fn(meta[us[i]]):
                results.append(dict(meta[us[i]], sim=sim))
                if len(results) >= top:
                    break
    return results