Spaces:
Sleeping
Sleeping
File size: 1,878 Bytes
55dcb09 483cfa2 55dcb09 483cfa2 55dcb09 77e656b 55dcb09 1499e69 55dcb09 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
import json
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
# Objaverse model metadata. Open with an explicit UTF-8 encoding (JSON's
# mandated encoding) and a context manager so the handle is closed promptly.
with open(
    hf_hub_download("OpenShape/openshape-objaverse-embeddings", "objaverse_meta.json", repo_type='dataset'),
    encoding="utf-8",
) as _meta_fp:
    meta = json.load(_meta_fp)
# Each item of meta['entries'] looks like:
# {
# "u": "94db219c315742909fee67deeeacae15",
# "name": "knife",
# "like": 0,
# "view": 35,
# "anims": 0,
# "tags": ["game-ready"],
# "cats": ["weapons-military"],
# "img": "https://media.sketchfab.com/models/94db219c315742909fee67deeeacae15/thumbnails/c0bbbd475d264ff2a92972f5115564ee/0cd28a130ebd4d9c9ef73190f24d9a42.jpeg",
# "desc": "",
# "faces": 1724,
# "size": 11955,
# "lic": "by",
# "glb": "glbs/000-000/94db219c315742909fee67deeeacae15.glb"
# }
# Re-key the entries by their "u" field so they can be looked up directly by UID.
meta = {x['u']: x for x in meta['entries']}
# Precomputed Objaverse embeddings, loaded onto the CPU.
# An earlier revision loaded "objaverse.pt" from the
# "OpenShape/openshape-objaverse-embeddings" dataset repo instead.
# NOTE(review): torch.load unpickles the downloaded file, and unpickling can
# execute arbitrary code. Only point this at a trusted dataset repo. Consider
# weights_only=True once the installed torch supports it and the file's
# contents (presumably plain lists/tensors) are confirmed to load that way.
deser = torch.load(
    hf_hub_download("TripletMix/tripletmix-objaverse-embeddings", "objaverse.pt", repo_type='dataset'), map_location='cpu'
)
# `us` and `feats` are used as parallel containers: retrieval indexes `us`
# with the position of a row in `feats` (presumably us[i] is the model UID of
# feats[i] — confirm against the dataset).
us = deser['us']
feats = deser['feats']
def _cosine_sims(query, rows, chunk_size):
    """Return a 1-D float32 tensor of ``query`` dotted with each L2-normalised row of ``rows``.

    ``rows`` may be a 2-D tensor or a sequence of 1-D row tensors. Rows are
    scored ``chunk_size`` at a time, so only one chunk is converted to float32
    at once.
    """
    if isinstance(rows, torch.Tensor):
        chunks = torch.split(rows, chunk_size)
    else:
        # Sequence of row tensors: stack slices so each matmul covers many rows.
        chunks = (
            torch.stack(list(rows[start:start + chunk_size]))
            for start in range(0, len(rows), chunk_size)
        )
    parts = [query @ F.normalize(chunk.float(), dim=-1).T for chunk in chunks]
    if not parts:
        return torch.empty(0)
    return torch.cat(parts)


def retrieve(embedding, top, sim_th=0.0, filter_fn=None, chunk_size=10240):
    """Return the metadata entries whose embeddings are most similar to ``embedding``.

    Cosine similarity is computed between ``embedding`` and every entry of the
    module-level ``feats``. Candidates are visited from most to least similar.
    A candidate is kept when its UID ``us[i]`` has an entry in ``meta`` and
    that entry passes ``filter_fn``.

    Args:
        embedding: query tensor. It is normalised over its last dim and then
            squeezed to a single vector (assumes one query, e.g. (D,) or
            (1, D) — TODO confirm with callers).
        top: maximum number of results; ``top <= 0`` returns ``[]``.
        sim_th: only candidates with similarity strictly greater than this
            are considered.
        filter_fn: optional predicate on a metadata dict; entries for which it
            returns a falsy value are skipped.
        chunk_size: number of feature rows scored per matrix multiply.

    Returns:
        At most ``top`` dicts in descending-similarity order. Each is a shallow
        copy of a ``meta`` entry plus a ``"sim"`` key holding the similarity as
        a 0-d tensor.
    """
    if top <= 0:
        return []
    # Cast to float32 so the matmul dtype matches the float32 feature chunks
    # even for half/double query embeddings.
    query = F.normalize(embedding.detach().cpu().float(), dim=-1).squeeze()
    sims = _cosine_sims(query, feats, chunk_size)
    sims, idx = torch.sort(sims, descending=True)
    keep = sims > sim_th
    sims, idx = sims[keep], idx[keep]
    results = []
    for i, sim in zip(idx.tolist(), sims):
        entry = meta.get(us[i])
        if entry is None:
            continue
        if filter_fn is not None and not filter_fn(entry):
            continue
        results.append(dict(entry, sim=sim))
        if len(results) >= top:
            break
    return results
|