from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import httpx

from litellm.types.rerank import OptionalRerankParams, RerankResponse

from ..chat.transformation import BaseLLMException

# The concrete logging class is only imported for static type checkers; at
# runtime the alias is plain ``Any`` (presumably to avoid an import cycle —
# confirm before changing).
if not TYPE_CHECKING:
    LiteLLMLoggingObj = Any
else:
    from litellm.litellm_core_utils.litellm_logging import (
        Logging as _LiteLLMLoggingObj,
    )

    LiteLLMLoggingObj = _LiteLLMLoggingObj


class BaseRerankConfig(ABC):
    """Abstract interface implemented by provider-specific rerank configs.

    Every hook below is declared with ``@abstractmethod``, so a concrete
    subclass must override all of them before it can be instantiated. A few
    hooks carry a trivial default body that an override can still reach via
    ``super()``; the rest have an empty body and return ``None``.
    """

    # ------------------------------------------------------------------ #
    # Request side
    # ------------------------------------------------------------------ #

    @abstractmethod
    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        """Return the headers dict to use for a rerank call to ``model``.

        ``headers`` and the optional ``api_key`` come from the caller;
        presumably implementations verify credentials and add auth headers
        here — confirm against concrete subclasses. The base body is empty
        and returns ``None``.
        """

    @abstractmethod
    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        """Return the full endpoint URL for a rerank request.

        Despite being abstract (subclasses must override it), the default
        body is usable through ``super()``: it yields ``api_base`` unchanged,
        or ``""`` when ``api_base`` is ``None`` or empty. Override it for
        providers that need ``model`` embedded in the URL.
        """
        return api_base if api_base else ""

    @abstractmethod
    def get_supported_cohere_rerank_params(self, model: str) -> list:
        """Return the Cohere-style rerank parameters supported for ``model``.

        Presumably a list of parameter names — confirm in subclasses. The
        base body is empty and returns ``None``.
        """

    @abstractmethod
    def map_cohere_rerank_params(
        self,
        non_default_params: Optional[dict],
        model: str,
        drop_params: bool,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        custom_llm_provider: Optional[str] = None,
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
    ) -> OptionalRerankParams:
        """Translate Cohere-style rerank arguments into ``OptionalRerankParams``.

        ``documents`` may mix plain strings and dicts. How ``drop_params``
        treats unsupported parameters is up to the subclass (presumably it
        drops them instead of raising — confirm). The base body is empty and
        returns ``None``.
        """

    @abstractmethod
    def transform_rerank_request(
        self,
        model: str,
        optional_rerank_params: OptionalRerankParams,
        headers: dict,
    ) -> dict:
        """Build the provider request body from the mapped rerank params.

        Default body (reachable via ``super()``): an empty dict.
        """
        return {}

    # ------------------------------------------------------------------ #
    # Response side
    # ------------------------------------------------------------------ #

    @abstractmethod
    def transform_rerank_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: RerankResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        # NOTE(review): each ``{}`` default below is one dict object shared
        # across calls. The base body never reads or mutates them, so this is
        # harmless here. They are kept because their declared ``dict`` types
        # are part of the interface that overrides are checked against.
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> RerankResponse:
        """Populate ``model_response`` from the provider's ``raw_response``.

        Default body (reachable via ``super()``): returns ``model_response``
        unchanged and ignores every other argument.
        """
        return model_response

    @abstractmethod
    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Return the ``BaseLLMException`` describing a failed rerank call.

        Receives the error message, the status code and the response headers
        (a plain dict or ``httpx.Headers``). The base body is empty and
        returns ``None``.
        """