from typing import Any, Dict, List, Optional, Union

from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
from litellm.types.rerank import OptionalRerankParams, RerankRequest


class CohereRerankV2Config(CohereRerankConfig):
    """
    Request configuration for Cohere's v2 rerank endpoint.

    Reference: https://docs.cohere.com/v2/reference/rerank
    """

    def __init__(self) -> None:
        """Create a stateless config; no attributes are set and the parent initializer is not called."""

    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        """
        Build the URL that the rerank request is sent to.

        Args:
            api_base: Optional custom base URL. Trailing slashes are stripped,
                and ``/v2/rerank`` is appended unless the stripped URL already
                ends with it. ``None`` or an empty string selects the public
                Cohere endpoint.
            model: Not used by this implementation.

        Returns:
            The complete rerank endpoint URL.
        """
        if not api_base:
            return "https://api.cohere.ai/v2/rerank"

        endpoint_suffix = "/v2/rerank"
        trimmed_base = api_base.rstrip("/")
        if trimmed_base.endswith(endpoint_suffix):
            return trimmed_base
        # NOTE(review): a base ending in "/v2" (no "/rerank") gets the full
        # suffix appended, yielding ".../v2/v2/rerank" — confirm that callers
        # never pass such a base.
        return f"{trimmed_base}{endpoint_suffix}"

    def get_supported_cohere_rerank_params(self, model: str) -> list:
        """
        List the request parameter names this config accepts.

        Args:
            model: Not used by this implementation.

        Returns:
            A newly built list on every call, so callers may mutate it safely.
        """
        supported = [
            "query",
            "documents",
            "top_n",
            "max_tokens_per_doc",
            "rank_fields",
            "return_documents",
        ]
        return supported

    def map_cohere_rerank_params(
        self,
        non_default_params: Optional[dict],
        model: str,
        drop_params: bool,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        custom_llm_provider: Optional[str] = None,
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        max_tokens_per_doc: Optional[int] = None,
    ) -> OptionalRerankParams:
        """
        Pack the rerank arguments into an ``OptionalRerankParams`` mapping.

        Values are passed through unchanged. Every one of the six keys below is
        always present in the result, even when its value is ``None``.

        ``non_default_params``, ``model``, ``drop_params`` and
        ``custom_llm_provider`` are not consulted. ``max_chunks_per_doc`` is
        accepted but silently dropped (presumably because the v2 endpoint uses
        ``max_tokens_per_doc`` instead — confirm against the Cohere docs).
        """
        params = OptionalRerankParams(
            query=query,
            documents=documents,
            top_n=top_n,
            rank_fields=rank_fields,
            return_documents=return_documents,
            max_tokens_per_doc=max_tokens_per_doc,
        )
        return params

    def transform_rerank_request(
        self,
        model: str,
        optional_rerank_params: OptionalRerankParams,
        headers: dict,
    ) -> dict:
        """
        Convert the mapped params into the JSON-ready request body.

        Args:
            model: Model name placed in the request.
            optional_rerank_params: Output of ``map_cohere_rerank_params``.
            headers: Not used by this implementation.

        Returns:
            The ``RerankRequest`` dumped to a dict with ``None``-valued fields
            omitted.

        Raises:
            ValueError: if the ``query`` key, then the ``documents`` key, is
                absent. A key that is present with a ``None`` value passes
                this check.
        """
        params = optional_rerank_params
        if "query" not in params:
            raise ValueError("query is required for Cohere rerank")
        if "documents" not in params:
            raise ValueError("documents is required for Cohere rerank")

        # Missing optional keys become None and are then removed by exclude_none.
        optional_fields = {
            field_name: params.get(field_name)
            for field_name in (
                "top_n",
                "rank_fields",
                "return_documents",
                "max_tokens_per_doc",
            )
        }
        request = RerankRequest(
            model=model,
            query=params["query"],
            documents=params["documents"],
            **optional_fields,
        )
        return request.model_dump(exclude_none=True)