Spaces:
Sleeping
Sleeping
File size: 3,251 Bytes
0c3992e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import os.path as osp
import torch
from typing import Any
from src.models.model import ModelForSemiStructQA
from src.models.vss import VSS
from src.tools.api import get_openai_embeddings
from src.tools.process_text import chunk_text
class MultiVSS(ModelForSemiStructQA):
    """Multi-vector similarity search over chunked candidate documents.

    Two-stage retrieval: a coarse single-vector VSS pass selects up to
    ``max_k`` candidates, then each candidate's document is chunked,
    embedded, and re-scored by aggregating chunk-level similarities.
    """

    def __init__(self,
                 kb,
                 query_emb_dir,
                 candidates_emb_dir,
                 chunk_emb_dir,
                 emb_model='text-embedding-ada-002',
                 aggregate='top3_avg',
                 max_k=50,
                 chunk_size=256):
        '''
        Multivector Vector Similarity Search

        Args:
            kb (src.benchmarks.semistruct.SemiStruct): knowledge base
            query_emb_dir (str): directory to query embeddings
            candidates_emb_dir (str): directory to candidate embeddings
            chunk_emb_dir (str): directory caching per-node chunk embeddings
            emb_model (str): embedding model name passed to the API
            aggregate (str): chunk-score aggregation mode — 'max', 'avg',
                or 'top{k}_avg' (e.g. 'top3_avg')
            max_k (int): number of coarse candidates to re-rank
            chunk_size (int): chunk size forwarded to chunk_text
        '''
        super().__init__(kb)
        self.kb = kb
        self.aggregate = aggregate  # 'max', 'avg', 'top{k}_avg'
        self.max_k = max_k
        self.chunk_size = chunk_size
        self.emb_model = emb_model
        self.query_emb_dir = query_emb_dir
        self.chunk_emb_dir = chunk_emb_dir
        self.candidates_emb_dir = candidates_emb_dir
        # First-stage (coarse) retriever over whole-document embeddings.
        self.parent_vss = VSS(kb, query_emb_dir, candidates_emb_dir,
                              emb_model=emb_model)

    def forward(self,
                query,
                query_id,
                **kwargs: Any):
        '''
        Re-rank the coarse top-max_k candidates by chunk-level similarity.

        Args:
            query (str): query text
            query_id: identifier used to locate the cached query embedding

        Returns:
            dict: node_id -> aggregated chunk-similarity score

        Raises:
            ValueError: if ``self.aggregate`` is not a recognized mode.
        '''
        query_emb = self._get_query_emb(query, query_id)
        initial_score_dict = self.parent_vss(query, query_id)
        node_ids = list(initial_score_dict.keys())
        node_scores = torch.FloatTensor(list(initial_score_dict.values()))

        # Indices of the k highest coarse scores (k capped by availability).
        k = min(self.max_k, len(node_ids))
        top_k_idx = torch.topk(node_scores, k, dim=-1).indices.view(-1).tolist()
        top_k_node_ids = [node_ids[i] for i in top_k_idx]

        pred_dict = {}
        for node_id in top_k_node_ids:
            chunk_embs = self._get_chunk_embs(node_id)
            # Device-agnostic: follow the chunk embeddings' device instead
            # of the original unconditional .cuda(), which crashed on
            # CPU-only hosts.
            similarity = torch.matmul(
                query_emb.to(chunk_embs.device),
                chunk_embs.T,
            ).cpu().view(-1)
            pred_dict[node_id] = self._aggregate_similarity(similarity)
        return pred_dict

    def _get_chunk_embs(self, node_id):
        '''Return chunk embeddings for a node, computing and caching them
        under chunk_emb_dir on first use. On a cache hit the document is
        not re-fetched or re-chunked.'''
        chunk_path = osp.join(self.chunk_emb_dir,
                              f'{node_id}_size={self.chunk_size}.pt')
        if osp.exists(chunk_path):
            return torch.load(chunk_path)
        doc = self.kb.get_doc_info(node_id, add_rel=True, compact=True)
        chunks = chunk_text(doc, chunk_size=self.chunk_size)
        chunk_embs = get_openai_embeddings(chunks, model=self.emb_model)
        torch.save(chunk_embs, chunk_path)
        return chunk_embs

    def _aggregate_similarity(self, similarity):
        '''Collapse a 1-D tensor of per-chunk similarities into one score
        according to self.aggregate ('max', 'avg', or 'top{k}_avg').'''
        if self.aggregate == 'max':
            return torch.max(similarity).item()
        if self.aggregate == 'avg':
            return torch.mean(similarity).item()
        if 'top' in self.aggregate:
            k = int(self.aggregate.split('_')[0][len('top'):])
            # similarity has one entry per chunk, so numel() caps k the
            # same way len(chunks) did.
            top_vals = torch.topk(similarity,
                                  k=min(k, similarity.numel()),
                                  dim=-1).values
            return torch.mean(top_vals).item()
        # Previously an unknown mode silently dropped the candidate from
        # the result dict; fail loudly instead.
        raise ValueError(f'Unknown aggregate mode: {self.aggregate!r}')