import os
import uuid
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, TypedDict, Union

import httpx

import litellm
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.rerank import (
    OptionalRerankParams,
    RerankBilledUnits,
    RerankResponse,
    RerankResponseDocument,
    RerankResponseMeta,
    RerankResponseResult,
    RerankTokens,
)
from litellm.utils import token_counter

from ..common_utils import HuggingFaceError

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj

    LoggingClass = LiteLLMLoggingObj
else:
    LoggingClass = Any


class HuggingFaceRerankResponseItem(TypedDict):
    """Type definition for a single HuggingFace rerank API response item."""

    index: int
    score: float
    text: Optional[str]  # Optional, included when return_text=True


class HuggingFaceRerankResponse(TypedDict):
    """Type definition for the complete HuggingFace rerank API response."""

    # The response body is a bare JSON list of HuggingFaceRerankResponseItem.
    pass


# Type alias for the actual response structure
HuggingFaceRerankResponseList = List[HuggingFaceRerankResponseItem]


class HuggingFaceRerankConfig(BaseRerankConfig):
    def get_api_base(self, model: str, api_base: Optional[str]) -> str:
        if api_base is not None:
            return api_base
        elif os.getenv("HF_API_BASE") is not None:
            return os.getenv("HF_API_BASE", "")
        elif os.getenv("HUGGINGFACE_API_BASE") is not None:
            return os.getenv("HUGGINGFACE_API_BASE", "")
        else:
            return "https://api-inference.huggingface.co"

    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        """
        Get the complete URL for the API call, including the /rerank suffix if necessary.
        """
        # Get the base URL from api_base or the default
        base_url = self.get_api_base(model=model, api_base=api_base)

        # Remove trailing slashes and ensure the URL ends with the /rerank endpoint
        base_url = base_url.rstrip("/")
        if not base_url.endswith("/rerank"):
            base_url = f"{base_url}/rerank"

        return base_url

    def get_supported_cohere_rerank_params(self, model: str) -> list:
        return [
            "query",
            "documents",
            "top_n",
            "return_documents",
        ]

    def map_cohere_rerank_params(
        self,
        non_default_params: Optional[dict],
        model: str,
        drop_params: bool,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        custom_llm_provider: Optional[str] = None,
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        max_tokens_per_doc: Optional[int] = None,
    ) -> OptionalRerankParams:
        optional_rerank_params: Dict[str, Any] = {}
        if non_default_params is not None:
            for k, v in non_default_params.items():
                if k == "documents" and v is not None:
                    # Cohere's "documents" maps to HuggingFace's "texts"
                    optional_rerank_params["texts"] = v
                elif k == "return_documents" and v is not None and isinstance(v, bool):
                    # Cohere's "return_documents" maps to HuggingFace's "return_text"
                    optional_rerank_params["return_text"] = v
                elif k == "top_n" and v is not None:
                    optional_rerank_params["top_n"] = v
                elif k == "query" and v is not None:
                    optional_rerank_params["query"] = v
        return OptionalRerankParams(**optional_rerank_params)  # type: ignore

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        # Get API credentials
        api_key, api_base = self.get_api_credentials(
            api_key=api_key, api_base=api_base
        )

        default_headers = {
            "accept": "application/json",
            "content-type": "application/json",
        }
        if api_key:
            default_headers["Authorization"] = f"Bearer {api_key}"
        if "Authorization" in headers:
            default_headers["Authorization"] = headers["Authorization"]

        return {**default_headers, **headers}

    def transform_rerank_request(
        self,
        model: str,
        optional_rerank_params: Union[OptionalRerankParams, dict],
        headers: dict,
    ) -> dict:
        if "query" not in optional_rerank_params:
            raise ValueError("query is required for HuggingFace rerank")
        if "texts" not in optional_rerank_params:
            raise ValueError(
                "Cohere 'documents' param is required for HuggingFace rerank"
            )

        # The HuggingFace rerank API expects the return_text parameter
        # (mapped from Cohere's return_documents); set defaults for the
        # remaining request fields and overlay the mapped params.
        request_body = {
            "raw_scores": False,
            "truncate": False,
            "truncation_direction": "Right",
        }
        request_body.update(optional_rerank_params)
        return request_body

    def transform_rerank_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: RerankResponse,
        logging_obj: LoggingClass,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> RerankResponse:
        try:
            raw_response_json: HuggingFaceRerankResponseList = raw_response.json()
        except Exception:
            raise HuggingFaceError(
                message=getattr(raw_response, "text", str(raw_response)),
                status_code=getattr(raw_response, "status_code", 500),
            )

        # Use the standard litellm token counter for token estimation
        input_text = request_data.get("query", "")
        try:
            # Estimate output tokens from the raw response JSON string
            response_text = str(raw_response_json)
            estimated_output_tokens = token_counter(model=model, text=response_text)

            # Estimate input tokens from the query and documents
            query = request_data.get("query", "")
            documents = request_data.get("texts", [])

            # Convert documents to a single string if they're not already
            documents_text = ""
            for doc in documents:
                if isinstance(doc, str):
                    documents_text += doc + " "
                elif isinstance(doc, dict) and "text" in doc:
                    documents_text += doc["text"] + " "

            # Count input tokens using the same model
            input_text = query + " " + documents_text
            estimated_input_tokens = token_counter(model=model, text=input_text)
        except Exception:
            # Fall back to rough estimates if token counting fails
            estimated_output_tokens = (
                len(raw_response_json) * 10 if raw_response_json else 10
            )
            # Roughly four characters per token
            estimated_input_tokens = len(input_text) // 4

        _billed_units = RerankBilledUnits(search_units=1)
        _tokens = RerankTokens(
            input_tokens=estimated_input_tokens, output_tokens=estimated_output_tokens
        )
        rerank_meta = RerankResponseMeta(
            api_version={"version": "1.0"}, billed_units=_billed_units, tokens=_tokens
        )

        # Check whether documents should be returned, based on request parameters
        should_return_documents = request_data.get(
            "return_text", False
        ) or request_data.get("return_documents", False)
        original_documents = request_data.get("texts", [])

        results = []
        for item in raw_response_json:
            # Extract required fields, skipping items that lack them
            index = item.get("index")
            score = item.get("score")
            if index is None or score is None:
                continue

            # Create a RerankResponseResult with the required fields
            result = RerankResponseResult(index=index, relevance_score=score)

            # Add the optional document field if needed
            if should_return_documents:
                text_content = item.get("text", "")
                # 1. Prefer text returned directly by the API, if available
                if text_content:
                    result["document"] = RerankResponseDocument(text=text_content)
                # 2. Otherwise fall back to the original request documents
                elif original_documents and 0 <= index < len(original_documents):
                    doc = original_documents[index]
                    if isinstance(doc, str):
                        result["document"] = RerankResponseDocument(text=doc)
                    elif isinstance(doc, dict) and "text" in doc:
                        result["document"] = RerankResponseDocument(text=doc["text"])

            results.append(result)

        return RerankResponse(
            id=str(uuid.uuid4()),
            results=results,
            meta=rerank_meta,
        )

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return HuggingFaceError(message=error_message, status_code=status_code)

    def get_api_credentials(
        self,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> Tuple[Optional[str], Optional[str]]:
        """
        Get the API key and base URL from multiple sources.
        Returns a tuple of (api_key, api_base).

        Parameters:
            api_key: API key provided directly to this function; takes precedence over all other sources
            api_base: API base provided directly to this function; takes precedence over all other sources
        """
        # Get the API key from multiple sources
        final_api_key = (
            api_key
            or litellm.huggingface_key
            or get_secret_str("HUGGINGFACE_API_KEY")
        )

        # Get the API base from multiple sources
        final_api_base = (
            api_base
            or litellm.api_base
            or get_secret_str("HF_API_BASE")
            or get_secret_str("HUGGINGFACE_API_BASE")
        )

        return final_api_key, final_api_base
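

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the provider implementation).
# A minimal, hedged example of exercising this config directly; the model
# name, query, and documents are hypothetical placeholders, and the expected
# request shape simply follows the mapping implemented above.
#
#   config = HuggingFaceRerankConfig()
#   params = config.map_cohere_rerank_params(
#       non_default_params={
#           "query": "What is the capital of France?",
#           "documents": ["Paris is the capital of France.", "Berlin is in Germany."],
#           "top_n": 1,
#           "return_documents": True,
#       },
#       model="my-reranker",  # hypothetical model name
#       drop_params=False,
#       query="What is the capital of France?",
#       documents=["Paris is the capital of France.", "Berlin is in Germany."],
#   )
#   url = config.get_complete_url(api_base=None, model="my-reranker")
#   body = config.transform_rerank_request(
#       model="my-reranker", optional_rerank_params=params, headers={}
#   )
#   # body contains query/texts/top_n/return_text plus the raw_scores,
#   # truncate, and truncation_direction defaults set above.
# ---------------------------------------------------------------------------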