# Source: TestLLM / litellm/llms/databricks/common_utils.py
# Uploaded by Raju2024 ("Upload 1072 files"), commit e3278e4 (verified).
from typing import Literal, Optional, Tuple
from .exceptions import DatabricksError
class DatabricksBase:
    """Shared helpers for resolving Databricks serving-endpoint URLs and headers."""

    def _get_databricks_credentials(
        self, api_key: Optional[str], api_base: Optional[str], headers: Optional[dict]
    ) -> Tuple[str, dict]:
        """Fill in a missing base URL and/or auth headers via the databricks-sdk.

        Args:
            api_key: Explicit API key. When ``None``, authentication headers are
                obtained from ``WorkspaceClient().config.authenticate()``.
            api_base: Explicit base URL. When falsy, it defaults to
                ``<workspace host>/serving-endpoints``.
            headers: Caller-supplied headers. When falsy, they default to a JSON
                ``Content-Type`` header.

        Returns:
            An ``(api_base, headers)`` tuple. On key collisions, caller-supplied
            headers take precedence over SDK-generated auth headers.

        Raises:
            DatabricksError: (status 400) if ``databricks-sdk`` is not installed.
        """
        headers = headers or {"Content-Type": "application/json"}
        # Only the import itself signals "SDK not installed"; keep the try narrow
        # so unrelated ImportErrors are not misreported with this message.
        try:
            from databricks.sdk import WorkspaceClient
        except ImportError as err:
            raise DatabricksError(
                status_code=400,
                message=(
                    "If the Databricks base URL and API key are not set, the databricks-sdk "
                    "Python library must be installed. Please install the databricks-sdk, set "
                    "DATABRICKS_API_BASE and DATABRICKS_API_KEY environment variables, "
                    "or provide the base URL and API key as arguments."
                ),
            ) from err

        databricks_client = WorkspaceClient()
        api_base = api_base or f"{databricks_client.config.host}/serving-endpoints"

        if api_key is None:
            databricks_auth_headers: dict[str, str] = (
                databricks_client.config.authenticate()
            )
            # Later entries win: explicit caller headers override SDK-provided ones.
            headers = {**databricks_auth_headers, **headers}

        return api_base, headers

    def databricks_validate_environment(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        endpoint_type: Literal["chat_completions", "embeddings"],
        custom_endpoint: Optional[bool],
        headers: Optional[dict],
    ) -> Tuple[str, dict]:
        """Resolve the request URL and headers for a Databricks serving call.

        Missing credentials are filled in from the databricks-sdk unless
        ``custom_endpoint`` is truthy. A truthy ``custom_endpoint`` means the
        caller owns the full URL, so an error is raised instead.

        Args:
            api_key: API key. When given, it is sent as ``Authorization: Bearer``.
            api_base: Base URL of the serving endpoints.
            endpoint_type: Selects the path suffix appended to ``api_base``.
            custom_endpoint: When ``True``, ``api_base`` is used verbatim (no
                suffix), and missing credentials raise instead of using the SDK.
            headers: Extra request headers. The caller's dict is never mutated.

        Returns:
            An ``(url, headers)`` tuple ready for the HTTP request.

        Raises:
            DatabricksError: (status 400) if credentials are missing for a
                custom endpoint, or if the SDK fallback is needed but unavailable.
        """
        if api_key is None and headers is None:
            # Truthiness check, consistent with the api_base check below:
            # custom_endpoint=False behaves like None and may use the SDK.
            if custom_endpoint:
                raise DatabricksError(
                    status_code=400,
                    message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables (DATABRICKS_API_KEY) or via params",
                )
            api_base, headers = self._get_databricks_credentials(
                api_base=api_base, api_key=api_key, headers=headers
            )

        if api_base is None:
            if custom_endpoint:
                raise DatabricksError(
                    status_code=400,
                    message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables (DATABRICKS_API_BASE) or via params",
                )
            api_base, headers = self._get_databricks_credentials(
                api_base=api_base, api_key=api_key, headers=headers
            )

        if headers is None:
            # Only reachable with an explicit api_key: when api_key is None the
            # first check above either raises or replaces ``None`` headers.
            headers = {
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            }
        else:
            # Copy so the caller's dict is not modified in place.
            headers = dict(headers)
            if api_key is not None:
                headers["Authorization"] = f"Bearer {api_key}"

        # A custom endpoint is a complete URL; standard ones get a path suffix.
        if custom_endpoint is not True:
            if endpoint_type == "chat_completions":
                api_base = f"{api_base}/chat/completions"
            elif endpoint_type == "embeddings":
                api_base = f"{api_base}/embeddings"
        return api_base, headers