Spaces:
Build error
Build error
File size: 1,070 Bytes
d660b02 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import opik
from llm_engineering.application.networks import CrossEncoderModelSingleton
from llm_engineering.domain.embedded_chunks import EmbeddedChunk
from llm_engineering.domain.queries import Query
from .base import RAGStep
class Reranker(RAGStep):
    """RAG step that reorders retrieved chunks by their relevance to a query.

    Each (query, chunk) pair is scored with the shared cross-encoder model and
    the chunks are returned best-score-first, truncated to ``keep_top_k``.
    """

    def __init__(self, mock: bool = False) -> None:
        """Initialise the step and obtain the cross-encoder scoring model.

        Args:
            mock: Forwarded to ``RAGStep.__init__``. ``generate`` reads it back
                as ``self._mock`` (presumably stored by the base class) and
                skips scoring entirely when it is truthy.
        """
        super().__init__(mock=mock)
        self._model = CrossEncoderModelSingleton()

    @opik.track(name="Reranker.generate")
    def generate(self, query: Query, chunks: list[EmbeddedChunk], keep_top_k: int) -> list[EmbeddedChunk]:
        """Return the chunks most relevant to ``query``, highest score first.

        Args:
            query: Query whose ``content`` is paired with every chunk.
            chunks: Candidate chunks; each chunk's ``content`` is scored.
            keep_top_k: Number of leading entries of the ranking to keep
                (applied as a plain slice, ``ranking[:keep_top_k]``).

        Returns:
            In mock mode, ``chunks`` unchanged (no scoring, no truncation).
            Otherwise a new list holding the top-ranked chunks; empty when
            ``chunks`` is empty.
        """
        if self._mock:
            return chunks

        # Nothing to rank: answer directly instead of sending a zero-length
        # batch to the cross-encoder, which may not accept empty input.
        if not chunks:
            return []

        query_doc_tuples = [(query.content, chunk.content) for chunk in chunks]
        scores = self._model(query_doc_tuples)

        # The model is expected to yield one score per pair, in input order.
        # sorted() is stable, so chunks with equal scores keep their incoming
        # relative order.
        ranked = sorted(
            zip(scores, chunks, strict=False),
            key=lambda scored: scored[0],
            reverse=True,
        )
        return [chunk for _, chunk in ranked[:keep_top_k]]
|