test3 / litellm /llms /huggingface /rerank /transformation.py
DesertWolf's picture
Upload folder using huggingface_hub
447ebeb verified
import os
import uuid
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, TypedDict, Union
import httpx
import litellm
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.rerank import (
OptionalRerankParams,
RerankBilledUnits,
RerankResponse,
RerankResponseDocument,
RerankResponseMeta,
RerankResponseResult,
RerankTokens,
)
from litellm.utils import token_counter
from ..common_utils import HuggingFaceError
# The logging class is imported only while static type checkers analyse the
# file; at runtime the alias is plain ``Any``, so the logging module is not
# imported by this file.
if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj

    LoggingClass = LiteLLMLoggingObj
else:
    LoggingClass = Any
class _HuggingFaceRerankResponseItemRequired(TypedDict):
    """Keys carried by every HuggingFace rerank API response item."""

    # Position of the scored document within the request's ``texts`` list.
    index: int
    # Score assigned to that document for the query.
    score: float


class HuggingFaceRerankResponseItem(
    _HuggingFaceRerankResponseItemRequired, total=False
):
    """Type definition for HuggingFace rerank API response items.

    ``index`` and ``score`` are required. ``text`` may be absent, so it is
    declared on a ``total=False`` subclass; read it with ``item.get("text")``.
    """

    # Optional, included when return_text=True
    text: Optional[str]
class HuggingFaceRerankResponse(TypedDict):
    """Type definition for HuggingFace rerank API complete response.

    Declares no keys of its own: the response is a list of
    ``HuggingFaceRerankResponseItem`` entries rather than a mapping.
    """
# Type alias for the actual response structure: the decoded JSON body of a
# rerank call is annotated with this type, i.e. a list of per-document items.
HuggingFaceRerankResponseList = List[HuggingFaceRerankResponseItem]
class HuggingFaceRerankConfig(BaseRerankConfig):
    """Map Cohere-style rerank calls to and from a HuggingFace rerank endpoint.

    Request side: ``documents`` becomes ``texts`` and ``return_documents``
    becomes ``return_text``. Response side: the list of ``{index, score, text}``
    items returned by the endpoint is converted into a ``RerankResponse``.
    """

    def get_api_base(self, model: str, api_base: Optional[str]) -> str:
        """Resolve the base URL of the rerank endpoint.

        Precedence: explicit ``api_base``, then the ``HF_API_BASE`` and
        ``HUGGINGFACE_API_BASE`` environment variables, then the public
        inference API. Empty strings are treated as "not set" so that a blank
        argument or environment variable cannot yield a host-less URL.

        Args:
            model: Model name (not used for URL resolution).
            api_base: Caller-supplied base URL, if any.

        Returns:
            The first non-empty base URL found.
        """
        return (
            api_base
            or os.getenv("HF_API_BASE")
            or os.getenv("HUGGINGFACE_API_BASE")
            or "https://api-inference.huggingface.co"
        )

    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        """
        Get the complete URL for the API call, including the /rerank suffix if necessary.

        Trailing slashes are stripped first, so ``.../rerank/`` and ``.../rerank``
        are both returned as ``.../rerank``.
        """
        base_url = self.get_api_base(model=model, api_base=api_base).rstrip("/")
        if base_url.endswith("/rerank"):
            return base_url
        return f"{base_url}/rerank"

    def get_supported_cohere_rerank_params(self, model: str) -> list:
        """Return the Cohere rerank parameter names this provider accepts."""
        return [
            "query",
            "documents",
            "top_n",
            "return_documents",
        ]

    def map_cohere_rerank_params(
        self,
        non_default_params: Optional[dict],
        model: str,
        drop_params: bool,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        custom_llm_provider: Optional[str] = None,
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        max_tokens_per_doc: Optional[int] = None,
    ) -> OptionalRerankParams:
        """Translate Cohere-style parameters into HuggingFace request fields.

        Only ``non_default_params`` is read; the individual keyword arguments
        are accepted for interface compatibility and are not consulted.

        Mapping (entries whose value is ``None`` are skipped):
            documents        -> texts
            return_documents -> return_text (only when the value is a bool)
            top_n            -> top_n
            query            -> query

        Any other key is ignored.
        """
        optional_rerank_params: Dict[str, Any] = {}
        for key, value in (non_default_params or {}).items():
            if value is None:
                continue
            if key == "documents":
                optional_rerank_params["texts"] = value
            elif key == "return_documents":
                # Non-boolean values are dropped rather than coerced.
                if isinstance(value, bool):
                    optional_rerank_params["return_text"] = value
            elif key in ("top_n", "query"):
                optional_rerank_params[key] = value
        return OptionalRerankParams(**optional_rerank_params)  # type: ignore

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Build the HTTP headers for a rerank request.

        Starts from JSON ``accept``/``content-type`` defaults, adds a bearer
        ``Authorization`` header when an API key can be resolved, and then
        overlays ``headers`` so caller-supplied values (including
        ``Authorization``) always win.
        """
        # Only the key is needed here; the resolved base URL is not used.
        resolved_api_key, _ = self.get_api_credentials(
            api_key=api_key, api_base=api_base
        )

        default_headers = {
            "accept": "application/json",
            "content-type": "application/json",
        }
        if resolved_api_key:
            default_headers["Authorization"] = f"Bearer {resolved_api_key}"

        return {**default_headers, **headers}

    def transform_rerank_request(
        self,
        model: str,
        optional_rerank_params: Union[OptionalRerankParams, dict],
        headers: dict,
    ) -> dict:
        """Build the JSON body for the rerank request.

        Args:
            model: Model name (not placed in the body).
            optional_rerank_params: Mapped parameters; must contain ``query``
                and ``texts``.
            headers: Request headers (unused here).

        Returns:
            The provider defaults overlaid with ``optional_rerank_params``.

        Raises:
            ValueError: If ``query`` or ``texts`` is missing.
        """
        if "query" not in optional_rerank_params:
            raise ValueError("query is required for HuggingFace rerank")
        if "texts" not in optional_rerank_params:
            raise ValueError(
                "Cohere 'documents' param is required for HuggingFace rerank"
            )

        # Defaults first so that caller-provided parameters override them.
        request_body: Dict[str, Any] = {
            "raw_scores": False,
            "truncate": False,
            "truncation_direction": "Right",
        }
        request_body.update(optional_rerank_params)
        return request_body

    @staticmethod
    def _invalid_response_error(raw_response: Any) -> HuggingFaceError:
        """Build the error raised when the response body cannot be used."""
        return HuggingFaceError(
            message=getattr(raw_response, "text", str(raw_response)),
            status_code=getattr(raw_response, "status_code", 500),
        )

    @staticmethod
    def _estimate_token_usage(
        model: str, request_data: dict, raw_response_json: list
    ) -> Tuple[int, int]:
        """Return ``(input_tokens, output_tokens)`` estimates for the call."""
        input_text = request_data.get("query", "")
        try:
            # Output tokens are counted over the stringified response payload.
            output_tokens = token_counter(model=model, text=str(raw_response_json))

            document_texts = []
            for doc in request_data.get("texts") or []:
                if isinstance(doc, str):
                    document_texts.append(doc)
                elif isinstance(doc, dict) and "text" in doc:
                    document_texts.append(doc["text"])

            # Each document is followed by a single space, after "<query> ".
            input_text = (
                input_text + " " + "".join(text + " " for text in document_texts)
            )
            input_tokens = token_counter(model=model, text=input_text)
        except Exception:
            # Deliberate best effort: token accounting must never fail the call.
            output_tokens = len(raw_response_json) * 10 if raw_response_json else 10
            # NOTE(review): this multiplies the character count by 4; a
            # chars-per-token heuristic would divide instead - confirm intent.
            input_tokens = len(input_text) * 4 if input_text else 0
        return input_tokens, output_tokens

    @staticmethod
    def _resolve_document(
        item: dict, index: Any, original_documents: list
    ) -> Optional[RerankResponseDocument]:
        """Return the document for a result, or ``None`` if it cannot be found."""
        # 1. Prefer text returned directly by the API.
        text_content = item.get("text", "")
        if text_content:
            return RerankResponseDocument(text=text_content)

        # 2. Otherwise look the document up in the request by its index.
        if isinstance(index, int) and 0 <= index < len(original_documents):
            doc = original_documents[index]
            if isinstance(doc, str):
                return RerankResponseDocument(text=doc)
            if isinstance(doc, dict) and "text" in doc:
                return RerankResponseDocument(text=doc["text"])
        return None

    def transform_rerank_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: RerankResponse,
        logging_obj: LoggingClass,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> RerankResponse:
        """Convert the raw HTTP response into a ``RerankResponse``.

        Args:
            model: Model name, used for token counting.
            raw_response: HTTP response whose JSON body is a list of items
                with ``index``, ``score`` and optionally ``text``.
            model_response: Unused; a new ``RerankResponse`` is returned.
            logging_obj: Unused here.
            api_key: Unused here.
            request_data: Body that was sent (``query``, ``texts``,
                ``return_text``); only read, never mutated.
            optional_params: Unused here.
            litellm_params: Unused here.

        Returns:
            A ``RerankResponse`` with a fresh UUID, one result per usable item,
            and metadata holding estimated token counts and one search unit.

        Raises:
            HuggingFaceError: If the body is not valid JSON or is not a JSON
                list (for example an error object).
        """
        try:
            raw_response_json: HuggingFaceRerankResponseList = raw_response.json()
        except Exception as err:
            raise self._invalid_response_error(raw_response) from err

        # Anything but a list cannot be iterated as result items.
        if not isinstance(raw_response_json, list):
            raise self._invalid_response_error(raw_response)

        estimated_input_tokens, estimated_output_tokens = self._estimate_token_usage(
            model=model,
            request_data=request_data,
            raw_response_json=raw_response_json,
        )

        _billed_units = RerankBilledUnits(search_units=1)
        _tokens = RerankTokens(
            input_tokens=estimated_input_tokens, output_tokens=estimated_output_tokens
        )
        rerank_meta = RerankResponseMeta(
            api_version={"version": "1.0"}, billed_units=_billed_units, tokens=_tokens
        )

        # Documents are attached when either spelling of the flag was sent.
        should_return_documents = bool(
            request_data.get("return_text", False)
            or request_data.get("return_documents", False)
        )
        original_documents = request_data.get("texts") or []

        results: List[RerankResponseResult] = []
        for item in raw_response_json:
            if not isinstance(item, dict):
                continue

            index = item.get("index")
            score = item.get("score")
            # Skip items that don't have required fields
            if index is None or score is None:
                continue

            result = RerankResponseResult(index=index, relevance_score=score)
            if should_return_documents:
                document = self._resolve_document(item, index, original_documents)
                if document is not None:
                    result["document"] = document
            results.append(result)

        return RerankResponse(
            id=str(uuid.uuid4()),
            results=results,
            meta=rerank_meta,
        )

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Wrap an error message and status code in a ``HuggingFaceError``."""
        return HuggingFaceError(message=error_message, status_code=status_code)

    def get_api_credentials(
        self,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> Tuple[Optional[str], Optional[str]]:
        """
        Get API key and base URL from multiple sources.
        Returns tuple of (api_key, api_base).

        Parameters:
            api_key: API key provided directly to this function, takes precedence over all other sources
            api_base: API base provided directly to this function, takes precedence over all other sources

        Fallback order for the key: ``litellm.huggingface_key``, then the
        ``HUGGINGFACE_API_KEY`` secret. Fallback order for the base:
        ``litellm.api_base``, then the ``HF_API_BASE`` and
        ``HUGGINGFACE_API_BASE`` secrets. Either element may be ``None``.
        """
        final_api_key = (
            api_key or litellm.huggingface_key or get_secret_str("HUGGINGFACE_API_KEY")
        )
        final_api_base = (
            api_base
            or litellm.api_base
            or get_secret_str("HF_API_BASE")
            or get_secret_str("HUGGINGFACE_API_BASE")
        )
        return final_api_key, final_api_base