import numpy as np
from models.law_component import LawComponent
from sentence_transformers.cross_encoder import CrossEncoder


class CrossEncReranker:
    """Rerank retrieval candidates against a query using a cross-encoder."""

    def __init__(self, model_name, max_length=512):
        """Load the cross-encoder model.

        Args:
            model_name: Model identifier or path passed to ``CrossEncoder``.
            max_length: Value assigned to the model's ``max_length`` attribute
                (presumably the token limit per query/candidate pair).
        """
        self.model_name = model_name
        self.reranker = CrossEncoder(self.model_name)
        self.reranker.max_length = max_length

    def rerank(self, query_text: str, candidates: list):
        """Return *candidates* ordered by cross-encoder score, highest first.

        Args:
            query_text: The query every candidate is scored against.
            candidates: Sequence of objects exposing a ``text`` attribute
                (presumably ``LawComponent`` instances).

        Returns:
            A 1-D numpy object array containing the same candidate objects in
            descending score order. Candidates with equal scores keep their
            original relative order. Empty input yields an empty array.
        """
        # Fill an object array element by element so numpy never tries to
        # unpack sequence-like candidates into additional dimensions.
        ranked = np.empty(len(candidates), dtype=object)
        for i, candidate in enumerate(candidates):
            ranked[i] = candidate
        if ranked.size == 0:
            # Nothing to score; avoid calling the model with no inputs.
            return ranked

        pairs = [[query_text, c.text] for c in candidates]
        scores = np.asarray(self.reranker.predict(pairs))
        # A stable sort on negated scores gives descending order while keeping
        # ties in their original order (argsort(...)[::-1] would reverse them).
        order = np.argsort(-scores, kind="stable")
        return ranked[order]