File size: 1,552 Bytes
54f5afe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
"""Patched version of FlashRankRanker that fixes incorrect reranking [1].

[1] https://github.com/AnswerDotAI/rerankers/issues/39
"""

import contextlib
from io import StringIO
from typing import Any

from flashrank import RerankRequest

# Suppress rerankers output on import until [1] is fixed.
# [1] https://github.com/AnswerDotAI/rerankers/issues/36
with contextlib.redirect_stdout(StringIO()):
    from rerankers.documents import Document
    from rerankers.models.flashrank_ranker import FlashRankRanker
    from rerankers.results import RankedResults, Result
    from rerankers.utils import prep_docs


class PatchedFlashRankRanker(FlashRankRanker):
    """FlashRankRanker variant that pairs each score with the passage it was computed for.

    Only ``rank`` is overridden; everything else is inherited unchanged.
    """

    def rank(
        self,
        query: str,
        docs: str | list[str] | Document | list[Document],
        doc_ids: list[str] | list[int] | None = None,
        metadata: list[dict[str, Any]] | None = None,
    ) -> RankedResults:
        """Rerank ``docs`` against ``query`` using ``self.model``.

        Args:
            query: The query text handed to the reranking request.
            docs: The document(s) to rerank; normalised through ``prep_docs``.
            doc_ids: Optional ids, forwarded to ``prep_docs``.
            metadata: Optional per-document metadata, forwarded to ``prep_docs``.

        Returns:
            A ``RankedResults`` (``has_scores=True``) holding one ``Result`` per entry
            returned by ``self.model.rerank``, in that order, with ranks starting at 1.
        """
        prepared_docs = prep_docs(docs, doc_ids, metadata)

        # Tag every passage with its position in `prepared_docs` so that the matching
        # document can be recovered from the reranker's output.
        passages = []
        for position, doc in enumerate(prepared_docs):
            passages.append({"id": position, "text": doc.text})

        request = RerankRequest(query=query, passages=passages)
        reranked = self.model.rerank(request)

        results = []
        for rank, hit in enumerate(reranked, start=1):
            # Resolve the document through the id carried by the result, not through the
            # result's position: this lookup patches the incorrect ranking in the original.
            matched_doc = prepared_docs[hit["id"]]
            results.append(Result(document=matched_doc, score=hit["score"], rank=rank))

        return RankedResults(results=results, query=query, has_scores=True)