|
from typing import Any, Dict, List, Optional, Union |
|
|
|
import httpx |
|
|
|
import litellm |
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj |
|
from litellm.llms.base_llm.chat.transformation import BaseLLMException |
|
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig |
|
from litellm.secret_managers.main import get_secret_str |
|
from litellm.types.rerank import OptionalRerankParams, RerankRequest |
|
from litellm.types.utils import RerankResponse |
|
|
|
from ..common_utils import CohereError |
|
|
|
|
|
class CohereRerankConfig(BaseRerankConfig):
    """
    Request/response transformation for Cohere's /v1/rerank endpoint.

    Reference: https://docs.cohere.com/v2/reference/rerank
    """

    def __init__(self) -> None:
        pass

    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        """
        Return the fully-qualified rerank endpoint URL.

        A caller-supplied ``api_base`` is normalized (trailing slashes
        stripped) and suffixed with ``/v1/rerank`` if not already present;
        otherwise the public Cohere endpoint is used.
        """
        if api_base:
            # Strip trailing slashes so the suffix check/append is reliable.
            api_base = api_base.rstrip("/")
            if not api_base.endswith("/v1/rerank"):
                api_base = f"{api_base}/v1/rerank"
            return api_base
        return "https://api.cohere.ai/v1/rerank"

    def get_supported_cohere_rerank_params(self, model: str) -> list:
        """Return the request params supported by Cohere's v1 rerank API."""
        return [
            "query",
            "documents",
            "top_n",
            "max_chunks_per_doc",
            "rank_fields",
            "return_documents",
        ]

    def map_cohere_rerank_params(
        self,
        non_default_params: Optional[dict],
        model: str,
        drop_params: bool,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        custom_llm_provider: Optional[str] = None,
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
    ) -> OptionalRerankParams:
        """
        Map Cohere rerank params

        No mapping required - returns all supported params
        """
        return OptionalRerankParams(
            query=query,
            documents=documents,
            top_n=top_n,
            rank_fields=rank_fields,
            return_documents=return_documents,
            max_chunks_per_doc=max_chunks_per_doc,
        )

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        """
        Build request headers, resolving the API key from the argument,
        environment secrets, or ``litellm.cohere_key`` (in that order).

        Raises:
            ValueError: if no API key can be resolved from any source.
        """
        if api_key is None:
            api_key = (
                get_secret_str("COHERE_API_KEY")
                or get_secret_str("CO_API_KEY")
                or litellm.cohere_key
            )

        if api_key is None:
            raise ValueError(
                "Cohere API key is required. Please set 'COHERE_API_KEY' or 'CO_API_KEY' or 'litellm.cohere_key'"
            )

        default_headers = {
            "Authorization": f"bearer {api_key}",
            "accept": "application/json",
            "content-type": "application/json",
        }

        # Caller-supplied headers take precedence over the defaults —
        # including any caller-provided Authorization header — because the
        # right-hand operand of the merge wins on key collisions.
        return {**default_headers, **headers}

    def transform_rerank_request(
        self,
        model: str,
        optional_rerank_params: OptionalRerankParams,
        headers: dict,
    ) -> dict:
        """
        Build the Cohere rerank request body from the optional params.

        Raises:
            ValueError: if the required ``query`` or ``documents`` keys
                are missing.
        """
        if "query" not in optional_rerank_params:
            raise ValueError("query is required for Cohere rerank")
        if "documents" not in optional_rerank_params:
            raise ValueError("documents is required for Cohere rerank")
        rerank_request = RerankRequest(
            model=model,
            query=optional_rerank_params["query"],
            documents=optional_rerank_params["documents"],
            top_n=optional_rerank_params.get("top_n", None),
            rank_fields=optional_rerank_params.get("rank_fields", None),
            return_documents=optional_rerank_params.get("return_documents", None),
            max_chunks_per_doc=optional_rerank_params.get("max_chunks_per_doc", None),
        )
        # Drop unset fields so the wire payload only carries provided params.
        return rerank_request.model_dump(exclude_none=True)

    def transform_rerank_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: RerankResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> RerankResponse:
        """
        Transform Cohere rerank response

        No transformation required, litellm follows cohere API response format

        Raises:
            CohereError: if the response body is not valid JSON.
        """
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise CohereError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        return RerankResponse(**raw_response_json)

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Return the provider-specific exception for an error response."""
        return CohereError(message=error_message, status_code=status_code)
|
|