import re
from typing import Any

import torch

from src.models.model import ModelForSemiStructQA
from src.models.vss import VSS
from src.tools.api import get_llm_output


def find_floating_number(text):
    """Extract scores in [0.0, 1.0] (of the form 0.x or 1.0) from an LLM reply."""
    pattern = r'0\.\d+|1\.0'
    matches = re.findall(pattern, text)
    return [round(float(match), 4) for match in matches if float(match) <= 1.1]
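
# Illustrative sanity check (not part of the original module); shows the
# helper's expected behavior on typical replies:
#   find_floating_number('The numeric score of this paper is: 0.85')  -> [0.85]
#   find_floating_number('Score: 1.0')                                -> [1.0]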


class LLMReranker(ModelForSemiStructQA):

    def __init__(self,
                 kb,
                 llm_model,
                 emb_model,
                 query_emb_dir,
                 candidates_emb_dir,
                 sim_weight=0.1,
                 max_cnt=3,
                 max_k=100):
        '''
        Rerank the top VSS candidates by prompting an LLM to score each one.

        Args:
            kb (src.benchmarks.semistruct.SemiStruct): knowledge base
            llm_model (str): name of the LLM used for reranking
            emb_model (str): name of the embedding model used by VSS
            query_emb_dir (str): directory of query embeddings
            candidates_emb_dir (str): directory of candidate embeddings
            sim_weight (float): weight of the rank-based similarity bonus
            max_cnt (int): maximum retries per LLM call
            max_k (int): number of top candidates to rerank
        '''
        super().__init__(kb)
        self.max_k = max_k
        self.emb_model = emb_model
        self.llm_model = llm_model
        self.sim_weight = sim_weight
        self.max_cnt = max_cnt
        self.query_emb_dir = query_emb_dir
        self.candidates_emb_dir = candidates_emb_dir
        self.parent_vss = VSS(kb, query_emb_dir, candidates_emb_dir,
                              emb_model=emb_model)

    def forward(self,
                query,
                query_id=None,
                **kwargs: Any):
        # First stage: vector similarity scores over all candidates.
        initial_score_dict = self.parent_vss(query, query_id)
        node_ids = list(initial_score_dict.keys())
        node_scores = list(initial_score_dict.values())

        # Keep the ids with the top-k highest similarity scores.
        top_k_idx = torch.topk(torch.FloatTensor(node_scores),
                               min(self.max_k, len(node_scores)),
                               dim=-1).indices.view(-1).tolist()
        top_k_node_ids = [node_ids[i] for i in top_k_idx]
        cand_len = len(top_k_node_ids)

        pred_dict = {}
        for idx, node_id in enumerate(top_k_node_ids):
            node_type = self.kb.get_node_type_by_id(node_id)
            prompt = (
                f'You are a helpful assistant that examines whether a {node_type} '
                f'satisfies a given query and assigns a score from 0.0 to 1.0. '
                f'If the {node_type} does not satisfy the query, the score should be 0.0. '
                f'If there is explicit and strong evidence that the {node_type} satisfies '
                f'the query, the score should be 1.0. If the evidence is partial or weak, '
                f'the score should be between 0.0 and 1.0.\n'
                f'Here is the query:\n"{query}"\n'
                f'Here is the information about the {node_type}:\n' +
                self.kb.get_doc_info(node_id, add_rel=True) + '\n\n' +
                f'Please score the {node_type} based on how well it satisfies the query. '
                f'ONLY output the floating point score WITHOUT anything else. '
                f'Output: The numeric score of this {node_type} is: '
            )
            success = False
            for _ in range(self.max_cnt):
                try:
                    answer = get_llm_output(prompt,
                                            self.llm_model,
                                            max_tokens=5)
                    # Accept the reply only if it contains exactly one score.
                    answer = find_floating_number(answer)
                    if len(answer) == 1:
                        answer = answer[0]
                        success = True
                        break
                except Exception as e:
                    print(f'Error: {e}, retrying')

            if success:
                llm_score = float(answer)
                # Rank-based bonus: earlier (higher-similarity) candidates
                # receive a larger boost.
                sim_score = (cand_len - idx) / cand_len
                score = llm_score + self.sim_weight * sim_score
                pred_dict[node_id] = score
            else:
                # Fall back to the raw VSS scores if the LLM never returns
                # a parsable score for this candidate.
                return initial_score_dict
        return pred_dict
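

# Illustrative usage sketch (not part of the original file). The knowledge
# base loader, model names, and embedding paths below are assumptions and
# will differ per setup:
#
#     kb = ...  # a src.benchmarks.semistruct.SemiStruct knowledge base
#     reranker = LLMReranker(kb,
#                            llm_model='gpt-4',  # assumed model name
#                            emb_model='text-embedding-ada-002',
#                            query_emb_dir='emb/query',
#                            candidates_emb_dir='emb/candidates')
#     scores = reranker.forward('Which items satisfy the query?')
#     best = max(scores, key=scores.get)  # highest combined LLM + rank score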