import opik from llm_engineering.application.networks import CrossEncoderModelSingleton from llm_engineering.domain.embedded_chunks import EmbeddedChunk from llm_engineering.domain.queries import Query from .base import RAGStep class Reranker(RAGStep): def __init__(self, mock: bool = False) -> None: super().__init__(mock=mock) self._model = CrossEncoderModelSingleton() @opik.track(name="Reranker.generate") def generate(self, query: Query, chunks: list[EmbeddedChunk], keep_top_k: int) -> list[EmbeddedChunk]: if self._mock: return chunks query_doc_tuples = [(query.content, chunk.content) for chunk in chunks] scores = self._model(query_doc_tuples) scored_query_doc_tuples = list(zip(scores, chunks, strict=False)) scored_query_doc_tuples.sort(key=lambda x: x[0], reverse=True) reranked_documents = scored_query_doc_tuples[:keep_top_k] reranked_documents = [doc for _, doc in reranked_documents] return reranked_documents