Spaces:
Sleeping
Sleeping
""" | |
Rerank API | |
LiteLLM supports the rerank API format; no parameter transformation occurs | |
""" | |
from typing import Any, Dict, List, Optional, Union | |
import litellm | |
from litellm.llms.base import BaseLLM | |
from litellm.llms.custom_httpx.http_handler import ( | |
_get_httpx_client, | |
get_async_httpx_client, | |
) | |
from litellm.llms.together_ai.rerank.transformation import TogetherAIRerankConfig | |
from litellm.types.rerank import RerankRequest, RerankResponse | |
class TogetherAIRerank(BaseLLM):
    """Handler that sends rerank requests to the Together AI rerank endpoint.

    The request body is a ``RerankRequest`` serialised with ``None`` values
    dropped; the decoded JSON reply is handed to
    ``TogetherAIRerankConfig()._transform_response``.
    """

    # Single source of truth for the endpoint used by the sync and async paths.
    _TOGETHER_RERANK_URL = "https://api.together.xyz/v1/rerank"

    @staticmethod
    def _together_rerank_headers(api_key: str) -> Dict[str, str]:
        """Return the JSON and bearer-auth headers sent with a rerank request."""
        return {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {api_key}",
        }

    @staticmethod
    def _together_rerank_result(response: Any) -> RerankResponse:
        """Check the HTTP status and transform the JSON body of ``response``.

        Raises:
            Exception: if the status code is not 200; the message is the raw
                response text.
        """
        if response.status_code != 200:
            raise Exception(response.text)
        return TogetherAIRerankConfig()._transform_response(response.json())

    def rerank(
        self,
        model: str,
        api_key: str,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        _is_async: Optional[bool] = False,
    ) -> RerankResponse:
        """Rerank ``documents`` against ``query`` using Together AI.

        Args:
            model: Model name, forwarded in the request body.
            api_key: Key sent as a bearer token in the ``authorization`` header.
            query: Query text, forwarded in the request body.
            documents: Documents to rerank, forwarded in the request body.
            top_n: Forwarded in the request body when not ``None``.
            rank_fields: Forwarded in the request body when not ``None``.
            return_documents: Forwarded in the request body when not ``None``.
            max_chunks_per_doc: Must be ``None``; any other value is rejected.
            _is_async: When truthy, the un-awaited coroutine from
                ``async_rerank`` is returned instead of a finished response,
                and the caller is expected to await it.

        Returns:
            The result of ``TogetherAIRerankConfig()._transform_response`` on
            the decoded JSON reply (or a coroutine yielding it when
            ``_is_async`` is truthy).

        Raises:
            ValueError: if ``max_chunks_per_doc`` is not ``None``.
            Exception: if the endpoint answers with a non-200 status code.
        """
        # Fail fast: reject the unsupported option before building anything.
        if max_chunks_per_doc is not None:
            raise ValueError("TogetherAI does not support max_chunks_per_doc")

        request_data = RerankRequest(
            model=model,
            query=query,
            top_n=top_n,
            documents=documents,
            rank_fields=rank_fields,
            return_documents=return_documents,
        )
        # Exclude None values so unset optional fields are not sent.
        request_data_dict = request_data.dict(exclude_none=True)

        if _is_async:
            return self.async_rerank(request_data_dict, api_key)  # type: ignore # Call async method

        # Obtain the synchronous client only on the path that actually uses it.
        client = _get_httpx_client()
        response = client.post(
            self._TOGETHER_RERANK_URL,
            headers=self._together_rerank_headers(api_key),
            json=request_data_dict,
        )
        return self._together_rerank_result(response)

    async def async_rerank(
        self,
        request_data_dict: Dict[str, Any],
        api_key: str,
    ) -> RerankResponse:
        """Asynchronously POST an already-serialised rerank request.

        Args:
            request_data_dict: JSON-serialisable request body, as built by
                ``rerank``.
            api_key: Key sent as a bearer token in the ``authorization`` header.

        Returns:
            The result of ``TogetherAIRerankConfig()._transform_response`` on
            the decoded JSON reply.

        Raises:
            Exception: if the endpoint answers with a non-200 status code.
        """
        client = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.TOGETHER_AI
        )  # Use async client
        response = await client.post(
            self._TOGETHER_RERANK_URL,
            headers=self._together_rerank_headers(api_key),
            json=request_data_dict,
        )
        return self._together_rerank_result(response)