Spaces:
Sleeping
Sleeping
File size: 2,265 Bytes
0c3992e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import os.path as osp
import torch
from typing import Any
from src.models.model import ModelForSemiStructQA
from tqdm import tqdm
class VSS(ModelForSemiStructQA):
    '''Vector Similarity Search retriever.

    Scores every candidate by the dot product between a query embedding
    (obtained via ``_get_query_emb``) and precomputed candidate embeddings
    loaded from ``candidates_emb_dir``.
    '''

    def __init__(self,
                 kb,
                 query_emb_dir,
                 candidates_emb_dir,
                 emb_model='text-embedding-ada-002'):
        '''
        Vector Similarity Search

        Args:
            kb (src.benchmarks.semistruct.SemiStruct): kb
            query_emb_dir (str): directory to query embeddings
            candidates_emb_dir (str): directory to candidate embeddings. Either
                contains a merged ``candidate_emb_dict.pt`` or one
                ``{idx}.pt`` file per candidate id (the merged file is then
                written on first use).
            emb_model (str): embedding model name, forwarded to
                ``_get_query_emb`` when fetching query embeddings.

        Raises:
            ValueError: if the loaded embeddings do not cover exactly the ids in
                ``self.candidate_ids`` (e.g. a stale ``candidate_emb_dict.pt``).
        '''
        super().__init__(kb)
        self.emb_model = emb_model
        self.query_emb_dir = query_emb_dir
        self.candidates_emb_dir = candidates_emb_dir

        # NOTE(review): self.candidate_ids is presumably set by the base class
        # from `kb` — confirm in ModelForSemiStructQA.
        candidate_emb_path = osp.join(candidates_emb_dir, 'candidate_emb_dict.pt')
        if osp.exists(candidate_emb_path):
            # map_location='cpu' lets embeddings saved from a GPU load on
            # CPU-only hosts; forward() moves them to the compute device.
            candidate_emb_dict = torch.load(candidate_emb_path, map_location='cpu')
            print(f'Loaded candidate_emb_dict from {candidate_emb_path}!')
        else:
            print('Loading candidate embeddings...')
            candidate_emb_dict = {}
            for idx in tqdm(self.candidate_ids):
                candidate_emb_dict[idx] = torch.load(
                    osp.join(candidates_emb_dir, f'{idx}.pt'),
                    map_location='cpu')
            # Persist the merged dict so later runs skip the per-file loads.
            torch.save(candidate_emb_dict, candidate_emb_path)
            print(f'Saved candidate_emb_dict to {candidate_emb_path}!')

        # Explicit check instead of `assert` (stripped under -O): a cached dict
        # built for a different candidate set must not be used silently.
        missing = [idx for idx in self.candidate_ids if idx not in candidate_emb_dict]
        if missing or len(candidate_emb_dict) != len(self.candidate_ids):
            raise ValueError(
                f'{candidate_emb_path} has {len(candidate_emb_dict)} embeddings '
                f'but {len(self.candidate_ids)} candidate ids are expected '
                f'({len(missing)} missing); delete it if it is stale.')

        # Rows are stacked in self.candidate_ids order, so row i of
        # candidate_embs belongs to self.candidate_ids[i].
        # Assumes each stored embedding is (1, dim) — TODO confirm.
        self.candidate_embs = torch.cat(
            [candidate_emb_dict[idx] for idx in self.candidate_ids], dim=0)
        # (source tensor, device, on-device copy); filled by _candidate_embs_on.
        self._candidate_embs_cache = None

    def _candidate_embs_on(self, device):
        '''Return ``self.candidate_embs`` on ``device``, reusing a cached copy.

        Avoids a full host-to-device copy of the candidate matrix per query.
        The copy is rebuilt when ``self.candidate_embs`` is reassigned or a
        different device is requested; in-place edits are not detected.
        '''
        cache = getattr(self, '_candidate_embs_cache', None)
        if (cache is None
                or cache[0] is not self.candidate_embs
                or cache[1] != device):
            cache = (self.candidate_embs, device, self.candidate_embs.to(device))
            self._candidate_embs_cache = cache
        return cache[2]

    def forward(self,
                query: str,
                query_id: int,
                **kwargs: Any):
        '''Score every candidate against a single query.

        Args:
            query (str): query text, passed to ``_get_query_emb``.
            query_id (int): query id, passed to ``_get_query_emb``.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            dict: maps each id in ``self.candidate_ids`` to its dot-product
            similarity with the query, as a 0-dim CPU tensor.
        '''
        query_emb = self._get_query_emb(query,
                                        query_id,
                                        emb_model=self.emb_model)
        # Use the GPU when present, otherwise fall back to CPU instead of crashing.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        candidate_embs = self._candidate_embs_on(device)
        similarity = torch.matmul(query_emb.to(device),
                                  candidate_embs.T).cpu().view(-1)
        # Iterating a 1-D tensor yields 0-dim tensors, aligned with candidate_ids.
        return dict(zip(self.candidate_ids, similarity))