import os.path as osp
import torch
from typing import Any

from src.models.model import ModelForSemiStructQA
from src.models.vss import VSS
from src.tools.api import get_openai_embeddings
from src.tools.process_text import chunk_text


class MultiVSS(ModelForSemiStructQA):

    def __init__(self,
                 kb,
                 query_emb_dir,
                 candidates_emb_dir,
                 chunk_emb_dir,
                 emb_model='text-embedding-ada-002',
                 aggregate='top3_avg',
                 max_k=50,
                 chunk_size=256):
        '''
        Multivector Vector Similarity Search

        Args:
            kb (src.benchmarks.semistruct.SemiStruct): knowledge base
            query_emb_dir (str): directory of query embeddings
            candidates_emb_dir (str): directory of candidate embeddings
            chunk_emb_dir (str): directory of chunk embeddings
            emb_model (str): name of the embedding model
            aggregate (str): how chunk similarities are aggregated per node,
                one of 'max', 'avg', or 'top{k}_avg'
            max_k (int): number of top first-stage candidates to re-rank
            chunk_size (int): chunk size passed to chunk_text when splitting documents
        '''
        super().__init__(kb)
        self.kb = kb

        self.aggregate = aggregate  # 'max', 'avg', 'top{k}_avg'
        self.max_k = max_k
        self.chunk_size = chunk_size
        self.emb_model = emb_model
        self.query_emb_dir = query_emb_dir
        self.chunk_emb_dir = chunk_emb_dir
        self.candidates_emb_dir = candidates_emb_dir

        # first stage: single-vector similarity search over all candidates
        self.parent_vss = VSS(kb, query_emb_dir, candidates_emb_dir, emb_model=emb_model)
    def forward(self,
                query,
                query_id,
                **kwargs: Any):

        query_emb = self._get_query_emb(query, query_id)

        # first stage: score all candidates with single-vector VSS
        initial_score_dict = self.parent_vss(query, query_id)
        node_ids = list(initial_score_dict.keys())
        node_scores = list(initial_score_dict.values())

        # keep the ids with the top max_k highest first-stage scores
        top_k_idx = torch.topk(torch.FloatTensor(node_scores),
                               min(self.max_k, len(node_scores)),
                               dim=-1).indices.view(-1).tolist()
        top_k_node_ids = [node_ids[i] for i in top_k_idx]

        # second stage: re-rank the top candidates by chunk-level similarity
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        pred_dict = {}
        for node_id in top_k_node_ids:
            doc = self.kb.get_doc_info(node_id, add_rel=True, compact=True)
            chunks = chunk_text(doc, chunk_size=self.chunk_size)

            # load cached chunk embeddings if available, otherwise compute and cache them
            chunk_path = osp.join(self.chunk_emb_dir, f'{node_id}_size={self.chunk_size}.pt')
            if osp.exists(chunk_path):
                chunk_embs = torch.load(chunk_path)
            else:
                chunk_embs = get_openai_embeddings(chunks, model=self.emb_model)
                torch.save(chunk_embs, chunk_path)
            print(f'chunk_embs.shape: {chunk_embs.shape}')

            # similarity between the query and every chunk of this document
            similarity = torch.matmul(query_emb.to(device), chunk_embs.to(device).T).cpu().view(-1)
            if self.aggregate == 'max':
                pred_dict[node_id] = torch.max(similarity).item()
            elif self.aggregate == 'avg':
                pred_dict[node_id] = torch.mean(similarity).item()
            elif 'top' in self.aggregate:
                # e.g. 'top3_avg' -> mean of the 3 highest chunk similarities
                k = int(self.aggregate.split('_')[0][len('top'):])
                pred_dict[node_id] = torch.mean(
                    torch.topk(similarity, k=min(k, len(chunks)), dim=-1).values).item()
        return pred_dict
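

if __name__ == '__main__':
    # Illustrative sketch only, not part of the original module: shows how the
    # three `aggregate` modes collapse per-chunk similarities into one node score.
    # Random tensors stand in for real query/chunk embeddings; the 1536 dimension
    # matches text-embedding-ada-002 but is otherwise arbitrary here.
    torch.manual_seed(0)
    query_emb = torch.randn(1, 1536)      # stand-in for one query embedding
    chunk_embs = torch.randn(4, 1536)     # stand-in for 4 chunk embeddings
    similarity = torch.matmul(query_emb, chunk_embs.T).view(-1)

    scores = {
        'max': torch.max(similarity).item(),
        'avg': torch.mean(similarity).item(),
        # 'top3_avg': mean of the 3 highest chunk similarities
        'top3_avg': torch.mean(torch.topk(similarity, k=min(3, similarity.numel())).values).item(),
    }
    print(scores)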