|
from typing import Literal, Optional, Tuple |
|
|
|
from .exceptions import DatabricksError |
|
|
|
|
|
class DatabricksBase:
    """Shared credential and endpoint resolution for Databricks-backed calls."""

    def _get_databricks_credentials(
        self, api_key: Optional[str], api_base: Optional[str], headers: Optional[dict]
    ) -> Tuple[str, dict]:
        """Fill in a missing base URL and/or auth headers via the Databricks SDK.

        Args:
            api_key: Explicit API key. When ``None``, the headers returned by the
                SDK's ``config.authenticate()`` are merged into ``headers``.
            api_base: Explicit base URL. When falsy, it is derived from the SDK
                workspace host as ``<host>/serving-endpoints``.
            headers: Caller-supplied headers. When falsy, a JSON
                ``Content-Type`` header is used instead. On key collisions the
                caller-supplied values win over SDK-provided ones.

        Returns:
            ``(api_base, headers)`` with the resolved values.

        Raises:
            DatabricksError: status 400, when an ``ImportError`` occurs while
                loading or using ``databricks.sdk``.
        """
        headers = headers or {"Content-Type": "application/json"}
        try:
            # Imported lazily: databricks-sdk is only needed for this fallback path.
            from databricks.sdk import WorkspaceClient

            databricks_client = WorkspaceClient()

            api_base = api_base or f"{databricks_client.config.host}/serving-endpoints"

            if api_key is None:
                databricks_auth_headers: dict[str, str] = (
                    databricks_client.config.authenticate()
                )
                # Caller headers are spread last so they take precedence.
                headers = {**databricks_auth_headers, **headers}

            return api_base, headers
        except ImportError as err:
            # NOTE(review): "{LLM_PROVIDER}" is a literal token here (not an
            # f-string); presumably substituted downstream — confirm.
            raise DatabricksError(
                status_code=400,
                message=(
                    "If the Databricks base URL and API key are not set, the databricks-sdk "
                    "Python library must be installed. Please install the databricks-sdk, set "
                    "{LLM_PROVIDER}_API_BASE and {LLM_PROVIDER}_API_KEY environment variables, "
                    "or provide the base URL and API key as arguments."
                ),
            ) from err

    def databricks_validate_environment(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        endpoint_type: Literal["chat_completions", "embeddings"],
        custom_endpoint: Optional[bool],
        headers: Optional[dict],
    ) -> Tuple[str, dict]:
        """Resolve the request URL and headers for a Databricks serving call.

        Missing credentials fall back to the Databricks SDK unless
        ``custom_endpoint`` forbids it (see the conditions below).

        Args:
            api_key: Bearer token. When given, it overrides any
                ``Authorization`` header already present.
            api_base: Base URL of the serving endpoints.
            endpoint_type: Selects the path suffix appended to ``api_base``.
            custom_endpoint: When exactly ``True``, ``api_base`` is used verbatim
                (no suffix appended).
            headers: Extra request headers. This dict is never mutated; a new
                dict is returned.

        Returns:
            ``(url, headers)`` ready to be used for the request.

        Raises:
            DatabricksError: status 400, when neither ``api_key`` nor ``headers``
                is given and ``custom_endpoint`` is not ``None``, or when
                ``api_base`` is ``None`` and ``custom_endpoint`` is truthy. Also
                propagated from the SDK fallback.
        """
        if api_key is None and headers is None:
            # NOTE(review): any non-None custom_endpoint (including False) is
            # rejected here, whereas the api_base check below only rejects a
            # truthy value — confirm this asymmetry is intended.
            if custom_endpoint is not None:
                raise DatabricksError(
                    status_code=400,
                    message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
                )
            api_base, headers = self._get_databricks_credentials(
                api_base=api_base, api_key=api_key, headers=headers
            )

        if api_base is None:
            if custom_endpoint:
                raise DatabricksError(
                    status_code=400,
                    message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_BASE) or via params",
                )
            api_base, headers = self._get_databricks_credentials(
                api_base=api_base, api_key=api_key, headers=headers
            )

        # Work on a copy so the caller's dict is never mutated in place.
        if headers is None:
            headers = {"Content-Type": "application/json"}
        else:
            headers = dict(headers)

        # Set the bearer token once; an explicit key overrides any existing value.
        if api_key is not None:
            headers["Authorization"] = f"Bearer {api_key}"

        if custom_endpoint is not True:
            if endpoint_type == "chat_completions":
                api_base = f"{api_base}/chat/completions"
            elif endpoint_type == "embeddings":
                api_base = f"{api_base}/embeddings"
        return api_base, headers
|
|