|
import json |
|
import torch |
|
import torch.nn.functional as F |
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
# Hugging Face dataset repo that hosts both the metadata and the embeddings.
_REPO_ID = "OpenShape/openshape-objaverse-embeddings"

# Per-object metadata. The context manager closes the file handle once parsed,
# and the encoding is pinned to UTF-8 (the JSON encoding) so parsing does not
# depend on the platform's locale default.
with open(
    hf_hub_download(_REPO_ID, "objaverse_meta.json", repo_type='dataset'),
    encoding="utf-8",
) as _meta_file:
    meta = json.load(_meta_file)

# Re-index the entries by their 'u' field so retrieve() can look up an
# object's metadata directly from the ids stored in `us`.
meta = {x['u']: x for x in meta['entries']}

# NOTE(review): torch.load unpickles the downloaded file, so only load it from
# a trusted source. Consider passing weights_only=True if the installed torch
# version supports it for this payload.
deser = torch.load(
    hf_hub_download(_REPO_ID, "objaverse.pt", repo_type='dataset'),
    map_location='cpu',
)

# Row i of `feats` is the embedding of the object whose id is us[i]
# (retrieve() relies on this pairing).
us = deser['us']
feats = deser['feats']
|
|
|
|
|
def retrieve(embedding, top, sim_th=0.0, filter_fn=None):
    """Return metadata for the stored objects most similar to ``embedding``.

    Cosine similarity is computed between ``embedding`` and every row of the
    module-level ``feats`` matrix, processed in chunks of 10240 rows. Rows are
    visited in descending similarity. Each row whose id ``us[i]`` has an entry
    in ``meta`` and passes ``filter_fn`` is collected, until ``top`` results
    have been gathered.

    Args:
        embedding: Query embedding tensor. It is detached, moved to CPU, cast
            to float32, L2-normalized along the last dim and squeezed. This
            assumes a single query vector, e.g. shape ``(D,)`` or ``(1, D)``.
        top: Maximum number of results to return. ``top <= 0`` returns ``[]``.
        sim_th: Only rows whose similarity is strictly greater than this value
            are considered.
        filter_fn: Optional predicate called with a metadata dict. Rows for
            which it returns a falsy value are skipped.

    Returns:
        A list of at most ``top`` dicts, in decreasing-similarity order. Each
        dict is a shallow copy of the row's ``meta`` entry with an added
        ``'sim'`` key holding the similarity as a 0-d tensor.
    """
    if top <= 0:
        # Without this guard the loop below would append one result before
        # its length check ever fires.
        return []

    # Cast the query to float32 because the feats chunks are cast to float32
    # below. matmul requires matching dtypes, so a float16/float64 query
    # would otherwise raise.
    embedding = F.normalize(embedding.detach().cpu().float(), dim=-1).squeeze()

    # Chunking bounds the size of the normalized float32 temporaries.
    sims = torch.cat([
        embedding @ F.normalize(chunk.float(), dim=-1).T
        for chunk in torch.split(feats, 10240)
    ])
    sims, idx = torch.sort(sims, descending=True)

    keep = sims > sim_th
    sims = sims[keep]
    idx = idx[keep]

    results = []
    # idx.tolist() yields plain ints for indexing `us`. `sims` is iterated as
    # a tensor, so each 'sim' value stays a 0-d tensor as before.
    for i, sim in zip(idx.tolist(), sims):
        key = us[i]
        if key not in meta:
            continue
        entry = meta[key]
        if filter_fn is not None and not filter_fn(entry):
            continue
        results.append(dict(entry, sim=sim))
        if len(results) >= top:
            break
    return results
|
|