import re
from typing import Any

import torch

from src.models.model import ModelForSemiStructQA
from src.models.vss import VSS
from src.tools.api import get_llm_output


def find_floating_number(text):
    """Extract floating-point scores in [0.0, 1.0] from raw LLM output."""
    pattern = r'0\.\d+|1\.0'
    matches = re.findall(pattern, text)
    return [round(float(match), 4) for match in matches if float(match) <= 1.1]


class LLMReranker(ModelForSemiStructQA):

    def __init__(self,
                 kb,
                 llm_model,
                 emb_model,
                 query_emb_dir,
                 candidates_emb_dir,
                 sim_weight=0.1,
                 max_cnt=3,
                 max_k=100):
        '''
        Rerank the top-k candidates retrieved by VSS using an LLM.

        Args:
            kb (src.benchmarks.semistruct.SemiStruct): knowledge base
            llm_model (str): LLM model name used for reranking
            emb_model (str): embedding model name used by VSS
            query_emb_dir (str): directory to query embeddings
            candidates_emb_dir (str): directory to candidate embeddings
            sim_weight (float): weight of the rank-based similarity bonus
                added to the LLM score
            max_cnt (int): maximum number of retries per LLM call
            max_k (int): number of top VSS candidates to rerank
        '''
        super(LLMReranker, self).__init__(kb)
        self.max_k = max_k
        self.emb_model = emb_model
        self.llm_model = llm_model
        self.sim_weight = sim_weight
        self.max_cnt = max_cnt

        self.query_emb_dir = query_emb_dir
        self.candidates_emb_dir = candidates_emb_dir
        self.parent_vss = VSS(kb, query_emb_dir, candidates_emb_dir,
                              emb_model=emb_model)

    def forward(self, query, query_id=None, **kwargs: Any):
        # First-stage retrieval: vector similarity scores for all candidates.
        initial_score_dict = self.parent_vss(query, query_id)
        node_ids = list(initial_score_dict.keys())
        node_scores = list(initial_score_dict.values())

        # Get the ids with the top-k highest similarity scores.
        top_k_idx = torch.topk(torch.FloatTensor(node_scores),
                               min(self.max_k, len(node_scores)),
                               dim=-1).indices.view(-1).tolist()
        top_k_node_ids = [node_ids[i] for i in top_k_idx]
        cand_len = len(top_k_node_ids)

        pred_dict = {}
        for idx, node_id in enumerate(top_k_node_ids):
            node_type = self.kb.get_node_type_by_id(node_id)
            prompt = (
                f'You are a helpful assistant that examines if a {node_type} '
                f'satisfies a given query and assigns a score from 0.0 to 1.0. '
                f'If the {node_type} does not satisfy the query, the score should be 0.0. '
                f'If there exists explicit and strong evidence supporting that the {node_type} '
                f'satisfies the query, the score should be 1.0. If partial or weak evidence '
                f'exists, the score should be between 0.0 and 1.0.\n'
                f'Here is the query:\n"{query}"\n'
                f'Here is the information about the {node_type}:\n' +
                self.kb.get_doc_info(node_id, add_rel=True) + '\n\n' +
                f'Please score the {node_type} based on how well it satisfies the query. '
                f'ONLY output the floating point score WITHOUT anything else. '
                f'Output: The numeric score of this {node_type} is: '
            )

            # Query the LLM, retrying up to max_cnt times until exactly one
            # score can be parsed from the completion.
            success = False
            for _ in range(self.max_cnt):
                try:
                    answer = get_llm_output(prompt,
                                            self.llm_model,
                                            max_tokens=5)
                    answer = find_floating_number(answer)
                    if len(answer) == 1:
                        answer = answer[0]
                        success = True
                        break
                except Exception as e:
                    print(f'Error: {e}, retrying')

            if success:
                llm_score = float(answer)
                # Add a small rank-based bonus so that candidates with equal
                # LLM scores keep their original VSS ordering.
                sim_score = (cand_len - idx) / cand_len
                score = llm_score + self.sim_weight * sim_score
                pred_dict[node_id] = score
            else:
                # Fall back to the raw VSS scores if the LLM never returned a
                # parsable score for this candidate.
                return initial_score_dict
        return pred_dict
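

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the module). The knowledge base,
# embedding directories, and model names below are placeholders/assumptions;
# they are constructed elsewhere in the repo.
#
#   reranker = LLMReranker(kb=kb,                                  # SemiStruct kb
#                          llm_model='gpt-4',                      # assumed name
#                          emb_model='text-embedding-ada-002',     # assumed name
#                          query_emb_dir='emb/query',              # hypothetical path
#                          candidates_emb_dir='emb/candidates')    # hypothetical path
#   pred_dict = reranker.forward('Which products ...?', query_id=0)
#   best_node_id = max(pred_dict, key=pred_dict.get)               # top-scored candidate
#
# Parsing behaviour of find_floating_number on a raw LLM completion:
#   >>> find_floating_number('The numeric score of this product is: 0.85')
#   [0.85]
# ---------------------------------------------------------------------------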