Shyamnath's picture
Push core package and essential files
469eae6
raw
history blame
2.88 kB
from typing import Any, Dict, List, Optional, Union
from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
from litellm.types.rerank import OptionalRerankParams, RerankRequest
class CohereRerankV2Config(CohereRerankConfig):
    """
    Rerank configuration targeting the Cohere v2 rerank endpoint.

    Reference: https://docs.cohere.com/v2/reference/rerank
    """

    def __init__(self) -> None:
        """No-op initializer; the parent initializer is not invoked."""

    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        """
        Build the URL of the v2 rerank endpoint.

        When `api_base` is falsy the hard-coded Cohere URL is returned.
        Otherwise trailing slashes are stripped and "/v2/rerank" is appended
        unless the stripped base already ends with it. `model` is not used.
        """
        if not api_base:
            return "https://api.cohere.ai/v2/rerank"

        trimmed_base = api_base.rstrip("/")
        if trimmed_base.endswith("/v2/rerank"):
            # Caller already supplied the full endpoint path.
            return trimmed_base
        return f"{trimmed_base}/v2/rerank"

    def get_supported_cohere_rerank_params(self, model: str) -> list:
        """
        Return the parameter names this config reports as supported.

        A new list is built on every call; `model` is not used.
        """
        supported_params = [
            "query",
            "documents",
            "top_n",
            "max_tokens_per_doc",
            "rank_fields",
            "return_documents",
        ]
        return supported_params

    def map_cohere_rerank_params(
        self,
        non_default_params: Optional[dict],
        model: str,
        drop_params: bool,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        custom_llm_provider: Optional[str] = None,
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        max_tokens_per_doc: Optional[int] = None,
    ) -> OptionalRerankParams:
        """
        Collect the rerank arguments into an `OptionalRerankParams`.

        Values are forwarded unchanged (including `None`). The arguments
        `non_default_params`, `model`, `drop_params`, `custom_llm_provider`
        and `max_chunks_per_doc` are accepted but not forwarded.
        """
        forwarded = {
            "query": query,
            "documents": documents,
            "top_n": top_n,
            "rank_fields": rank_fields,
            "return_documents": return_documents,
            "max_tokens_per_doc": max_tokens_per_doc,
        }
        return OptionalRerankParams(**forwarded)

    def transform_rerank_request(
        self,
        model: str,
        optional_rerank_params: OptionalRerankParams,
        headers: dict,
    ) -> dict:
        """
        Turn the mapped params into the request body dictionary.

        Raises:
            ValueError: if the "query" or "documents" key is absent from
                `optional_rerank_params`.

        Returns:
            The dumped `RerankRequest` with `None`-valued fields excluded.
            `headers` is not used.
        """
        required_keys = (
            ("query", "query is required for Cohere rerank"),
            ("documents", "documents is required for Cohere rerank"),
        )
        # Checked in order, so a missing "query" is reported first.
        for required_key, error_message in required_keys:
            if required_key not in optional_rerank_params:
                raise ValueError(error_message)

        params = optional_rerank_params
        request_model = RerankRequest(
            model=model,
            query=params["query"],
            documents=params["documents"],
            top_n=params.get("top_n"),
            rank_fields=params.get("rank_fields"),
            return_documents=params.get("return_documents"),
            max_tokens_per_doc=params.get("max_tokens_per_doc"),
        )
        request_body = request_model.model_dump(exclude_none=True)
        return request_body