File size: 10,815 Bytes
447ebeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
import os
import uuid
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, TypedDict, Union

import httpx

import litellm
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.rerank import (
    OptionalRerankParams,
    RerankBilledUnits,
    RerankResponse,
    RerankResponseDocument,
    RerankResponseMeta,
    RerankResponseResult,
    RerankTokens,
)
from litellm.utils import token_counter

from ..common_utils import HuggingFaceError

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj

    LoggingClass = LiteLLMLoggingObj
else:
    LoggingClass = Any


class HuggingFaceRerankResponseItem(TypedDict):
    """Type definition for HuggingFace rerank API response items.

    One scored document as returned by the text-embeddings-inference
    ``/rerank`` endpoint.
    """

    index: int  # position of the document in the original request's texts list
    score: float  # relevance score assigned by the reranker
    text: Optional[str]  # Optional, included when return_text=True


class HuggingFaceRerankResponse(TypedDict):
    """Type definition for HuggingFace rerank API complete response.

    NOTE: vestigial placeholder — the API returns a bare JSON array, not an
    object, so the usable type is ``HuggingFaceRerankResponseList`` below.
    """

    # The response is a list of HuggingFaceRerankResponseItem
    pass


# Type alias for the actual response structure (the API returns a JSON array)
HuggingFaceRerankResponseList = List[HuggingFaceRerankResponseItem]


class HuggingFaceRerankConfig(BaseRerankConfig):
    """Rerank transformation config for HuggingFace text-embeddings-inference servers.

    Translates Cohere-style rerank requests into the HuggingFace ``/rerank``
    wire format and converts the raw list-of-scores response back into a
    litellm :class:`RerankResponse`.
    """

    def get_api_base(self, model: str, api_base: Optional[str]) -> str:
        """Resolve the rerank API base URL.

        Precedence: explicit ``api_base`` argument > ``HF_API_BASE`` env var >
        ``HUGGINGFACE_API_BASE`` env var > public HF inference API default.
        """
        if api_base is not None:
            return api_base
        # Read each env var once (the original called os.getenv twice per var).
        hf_api_base = os.getenv("HF_API_BASE")
        if hf_api_base is not None:
            return hf_api_base
        huggingface_api_base = os.getenv("HUGGINGFACE_API_BASE")
        if huggingface_api_base is not None:
            return huggingface_api_base
        return "https://api-inference.huggingface.co"

    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        """
        Get the complete URL for the API call, including the /rerank suffix if necessary.
        """
        # Get base URL from api_base or default
        base_url = self.get_api_base(model=model, api_base=api_base)

        # Remove trailing slashes and ensure we have the /rerank endpoint
        base_url = base_url.rstrip("/")
        if not base_url.endswith("/rerank"):
            base_url = f"{base_url}/rerank"

        return base_url

    def get_supported_cohere_rerank_params(self, model: str) -> list:
        """Cohere rerank params that have a HuggingFace equivalent."""
        return [
            "query",
            "documents",
            "top_n",
            "return_documents",
        ]

    def map_cohere_rerank_params(
        self,
        non_default_params: Optional[dict],
        model: str,
        drop_params: bool,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        custom_llm_provider: Optional[str] = None,
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        max_tokens_per_doc: Optional[int] = None,
    ) -> OptionalRerankParams:
        """Map Cohere-style rerank params to their HuggingFace equivalents.

        ``documents`` -> ``texts`` and ``return_documents`` -> ``return_text``;
        ``query`` and ``top_n`` pass through unchanged. ``None`` values and
        unsupported params are silently dropped.
        """
        optional_rerank_params: dict = {}
        if non_default_params is not None:
            # FIX: the original had a second, unreachable `elif k == "documents"`
            # branch duplicating the first; it has been removed.
            for k, v in non_default_params.items():
                if v is None:
                    continue
                if k == "documents":
                    optional_rerank_params["texts"] = v
                elif k == "return_documents" and isinstance(v, bool):
                    optional_rerank_params["return_text"] = v
                elif k == "top_n":
                    optional_rerank_params["top_n"] = v
                elif k == "query":
                    optional_rerank_params["query"] = v

        return OptionalRerankParams(**optional_rerank_params)  # type: ignore

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Build request headers, injecting a Bearer token when a key is available.

        Caller-supplied ``headers`` take precedence over the defaults,
        including any caller-supplied ``Authorization`` header.
        """
        # Get API credentials
        api_key, api_base = self.get_api_credentials(api_key=api_key, api_base=api_base)

        default_headers = {
            "accept": "application/json",
            "content-type": "application/json",
        }

        if api_key:
            default_headers["Authorization"] = f"Bearer {api_key}"

        if "Authorization" in headers:
            default_headers["Authorization"] = headers["Authorization"]

        return {**default_headers, **headers}

    def transform_rerank_request(
        self,
        model: str,
        optional_rerank_params: Union[OptionalRerankParams, dict],
        headers: dict,
    ) -> dict:
        """Build the HuggingFace /rerank request body.

        Raises:
            ValueError: if ``query`` or ``texts`` is missing from the params.
        """
        if "query" not in optional_rerank_params:
            raise ValueError("query is required for HuggingFace rerank")
        if "texts" not in optional_rerank_params:
            raise ValueError(
                "Cohere 'documents' param is required for HuggingFace rerank"
            )
        # HuggingFace API expects return_text parameter, corresponding to our
        # return_documents parameter; it arrives via optional_rerank_params.
        request_body = {
            "raw_scores": False,
            "truncate": False,
            "truncation_direction": "Right",
        }

        request_body.update(optional_rerank_params)

        return request_body

    def transform_rerank_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: RerankResponse,
        logging_obj: LoggingClass,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> RerankResponse:
        """Convert the raw HuggingFace rerank response into a RerankResponse.

        The raw response is a JSON array of ``{index, score, text?}`` items.
        Token usage is estimated with litellm's token counter (the HF endpoint
        reports no usage itself).

        Raises:
            HuggingFaceError: if the response body is not valid JSON.
        """
        try:
            raw_response_json: HuggingFaceRerankResponseList = raw_response.json()
        except Exception:
            raise HuggingFaceError(
                message=getattr(raw_response, "text", str(raw_response)),
                status_code=getattr(raw_response, "status_code", 500),
            )

        # Use standard litellm token counter for proper token estimation
        input_text = request_data.get("query", "")
        try:
            # Calculate tokens for the raw response JSON string
            response_text = str(raw_response_json)
            estimated_output_tokens = token_counter(model=model, text=response_text)

            # Calculate input tokens from query and documents
            query = request_data.get("query", "")
            documents = request_data.get("texts", [])

            # Convert documents to string if they're not already
            documents_text = ""
            for doc in documents:
                if isinstance(doc, str):
                    documents_text += doc + " "
                elif isinstance(doc, dict) and "text" in doc:
                    documents_text += doc["text"] + " "

            # Calculate input tokens using the same model
            input_text = query + " " + documents_text
            estimated_input_tokens = token_counter(model=model, text=input_text)
        except Exception:
            # Fallback to rough heuristics if token counting fails:
            # ~10 tokens per result item, ~4 characters per token.
            # FIX: the original computed len(input_text) * 4 (4 tokens per
            # character), inverting the standard chars-per-token ratio, and
            # guarded with a dead `"input_text" in locals()` check —
            # input_text is always bound above.
            estimated_output_tokens = (
                len(raw_response_json) * 10 if raw_response_json else 10
            )
            estimated_input_tokens = len(input_text) // 4

        _billed_units = RerankBilledUnits(search_units=1)
        _tokens = RerankTokens(
            input_tokens=estimated_input_tokens, output_tokens=estimated_output_tokens
        )
        rerank_meta = RerankResponseMeta(
            api_version={"version": "1.0"}, billed_units=_billed_units, tokens=_tokens
        )

        # Check if documents should be returned based on request parameters
        should_return_documents = request_data.get(
            "return_text", False
        ) or request_data.get("return_documents", False)
        original_documents = request_data.get("texts", [])

        results = []
        for item in raw_response_json:
            # Extract required fields with defaults to handle None values
            index = item.get("index")
            score = item.get("score")

            # Skip items that don't have required fields
            if index is None or score is None:
                continue

            # Create RerankResponseResult with required fields
            result = RerankResponseResult(index=index, relevance_score=score)

            # Add optional document field if needed
            if should_return_documents:
                text_content = item.get("text", "")

                # 1. First try to use text returned directly from API if available
                if text_content:
                    result["document"] = RerankResponseDocument(text=text_content)
                # 2. If no text in API response but original documents are
                # available, map back via the (already validated) index.
                elif original_documents and 0 <= index < len(original_documents):
                    doc = original_documents[index]
                    if isinstance(doc, str):
                        result["document"] = RerankResponseDocument(text=doc)
                    elif isinstance(doc, dict) and "text" in doc:
                        result["document"] = RerankResponseDocument(text=doc["text"])

            results.append(result)

        return RerankResponse(
            id=str(uuid.uuid4()),
            results=results,
            meta=rerank_meta,
        )

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Return the provider-specific exception for an HTTP error."""
        return HuggingFaceError(message=error_message, status_code=status_code)

    def get_api_credentials(
        self,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> Tuple[Optional[str], Optional[str]]:
        """
        Get API key and base URL from multiple sources.
        Returns tuple of (api_key, api_base).

        Parameters:
            api_key: API key provided directly to this function, takes precedence over all other sources
            api_base: API base provided directly to this function, takes precedence over all other sources
        """
        # Get API key from multiple sources
        final_api_key = (
            api_key or litellm.huggingface_key or get_secret_str("HUGGINGFACE_API_KEY")
        )

        # Get API base from multiple sources
        final_api_base = (
            api_base
            or litellm.api_base
            or get_secret_str("HF_API_BASE")
            or get_secret_str("HUGGINGFACE_API_BASE")
        )

        return final_api_key, final_api_base