import logging
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import httpx

from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj

    LoggingClass = LiteLLMLoggingObj
else:
    LoggingClass = Any

from litellm.llms.base_llm.chat.transformation import BaseLLMException

from ...openai.chat.gpt_transformation import OpenAIGPTConfig
from ..common_utils import HuggingFaceError, _fetch_inference_provider_mapping

logger = logging.getLogger(__name__)
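
# All provider-specific routes below are appended to this router base URL.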
BASE_URL = "https://router.huggingface.co"
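

# Illustrative behavior of the URL normalizer below (endpoint host is hypothetical):
#   "https://my-endpoint.example"                     -> "https://my-endpoint.example/v1/chat/completions"
#   "https://my-endpoint.example/v1"                  -> "https://my-endpoint.example/v1/chat/completions"
#   "https://my-endpoint.example/v1/chat/completions" -> unchanged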
def _build_chat_completion_url(model_url: str) -> str:
    # Drop any trailing slash before inspecting the path
    model_url = model_url.rstrip("/")
    # A base ending in "/v1" only needs the chat route appended
    if model_url.endswith("/v1"):
        model_url += "/chat/completions"
    # Anything not already a chat endpoint gets the full suffix
    if not model_url.endswith("/chat/completions"):
        model_url += "/v1/chat/completions"
    return model_url


class HuggingFaceChatConfig(OpenAIGPTConfig):
"""
Reference: https://huggingface.co/docs/huggingface_hub/guides/inference
"""

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: Dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
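        # Merge caller-supplied headers with the defaults; on conflict the
        # defaults win, so an api_key always populates the Authorization header.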
        default_headers = {
            "content-type": "application/json",
        }
        if api_key is not None:
            default_headers["Authorization"] = f"Bearer {api_key}"
        headers = {**headers, **default_headers}
        return headers

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return HuggingFaceError(
            status_code=status_code, message=error_message, headers=headers
        )

    def get_base_url(self, model: str, base_url: Optional[str]) -> Optional[str]:
        """
        Get the API base for the Hugging Face API.

        Do not add the chat/embedding/rerank extension here; the handler does that.
        """
        if model.startswith(("http://", "https://")):
            base_url = model
        elif base_url is None:
            base_url = os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE", "")
        return base_url

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
"""
Get the complete URL for the API call.
For provider-specific routing through huggingface
"""
        # Check if api_base is provided
        if api_base is not None:
            complete_url = _build_chat_completion_url(api_base)
        elif os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE"):
            complete_url = str(
                os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE")
            )
        elif model.startswith(("http://", "https://")):
            complete_url = _build_chat_completion_url(model)
        # Default construction with provider
        else:
            # "<provider>/<org>/<model>" carries an explicit provider prefix;
            # "<org>/<model>" (or a bare model id) defaults to hf-inference.
            first_part, _, remaining = model.partition("/")
            if "/" in remaining:
                provider = first_part
            else:
                provider = "hf-inference"
            if provider == "hf-inference":
                route = f"{provider}/models/{model}/v1/chat/completions"
            elif provider == "novita":
                route = f"{provider}/v3/openai/chat/completions"
            elif provider == "fireworks-ai":
                route = f"{provider}/inference/v1/chat/completions"
            else:
                route = f"{provider}/v1/chat/completions"
            complete_url = f"{BASE_URL}/{route}"

        # Ensure the URL doesn't end with a slash
        complete_url = complete_url.rstrip("/")
        return complete_url

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
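        # With a custom api_base, forward the request unchanged; otherwise map
        # the model id to its provider-specific id before building the request.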
        if litellm_params.get("api_base"):
            return dict(
                ChatCompletionRequest(model=model, messages=messages, **optional_params)
            )

        if "max_retries" in optional_params:
            logger.warning("`max_retries` is not supported. It will be ignored.")
            optional_params.pop("max_retries", None)
        # Same provider-prefix convention as get_complete_url
        first_part, _, remaining = model.partition("/")
        if "/" in remaining:
            provider = first_part
            model_id = remaining
        else:
            provider = "hf-inference"
            model_id = model
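        # The mapping reports which providers serve this model, each entry
        # carrying a deployment "status" and the provider-specific "providerId".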
        provider_mapping = _fetch_inference_provider_mapping(model_id)
        if provider not in provider_mapping:
            raise HuggingFaceError(
                message=f"Model {model_id} is not supported for provider {provider}",
                status_code=404,
                headers={},
            )
        provider_mapping = provider_mapping[provider]
        if provider_mapping["status"] == "staging":
            logger.warning(
                f"Model {model_id} is in staging mode for provider {provider}. "
                "Meant for test purposes only."
            )
        mapped_model = provider_mapping["providerId"]
        messages = self._transform_messages(messages=messages, model=mapped_model)
        return dict(
            ChatCompletionRequest(
                model=mapped_model, messages=messages, **optional_params
            )
        )
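

# Minimal usage sketch (assumes the HF_API_BASE / HUGGINGFACE_API_BASE env vars
# are unset; the model id is illustrative):
#   config = HuggingFaceChatConfig()
#   url = config.get_complete_url(
#       api_base=None,
#       api_key=None,
#       model="together/deepseek-ai/DeepSeek-R1",
#       optional_params={},
#       litellm_params={},
#   )
#   # -> "https://router.huggingface.co/together/v1/chat/completions"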