from typing import Literal, Optional, Tuple

from .exceptions import DatabricksError


class DatabricksBase:
    """Shared credential / endpoint resolution for Databricks serving calls."""

    def _get_databricks_credentials(
        self, api_key: Optional[str], api_base: Optional[str], headers: Optional[dict]
    ) -> Tuple[str, dict]:
        """Resolve the base URL and auth headers using the Databricks SDK.

        Args:
            api_key: Explicit API key. When ``None``, authentication headers
                are obtained from ``WorkspaceClient().config.authenticate()``.
            api_base: Explicit base URL. When falsy, it is derived from the
                SDK workspace host as ``"<host>/serving-endpoints"``.
            headers: Caller-supplied headers. When falsy, defaults to a JSON
                ``Content-Type`` header. Caller-supplied entries take
                precedence over SDK-provided auth headers.

        Returns:
            A ``(api_base, headers)`` tuple.

        Raises:
            DatabricksError: If ``databricks-sdk`` cannot be imported.
        """
        headers = headers or {"Content-Type": "application/json"}
        try:
            from databricks.sdk import WorkspaceClient

            databricks_client = WorkspaceClient()

            api_base = api_base or f"{databricks_client.config.host}/serving-endpoints"

            if api_key is None:
                databricks_auth_headers: dict[str, str] = (
                    databricks_client.config.authenticate()
                )
                # Caller headers win over SDK auth headers on key collisions.
                headers = {**databricks_auth_headers, **headers}

            return api_base, headers
        except ImportError as e:
            raise DatabricksError(
                status_code=400,
                message=(
                    "If the Databricks base URL and API key are not set, the databricks-sdk "
                    "Python library must be installed. Please install the databricks-sdk, set "
                    "{LLM_PROVIDER}_API_BASE and {LLM_PROVIDER}_API_KEY environment variables, "
                    "or provide the base URL and API key as arguments."
                ),
            ) from e

    def databricks_validate_environment(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        endpoint_type: Literal["chat_completions", "embeddings"],
        custom_endpoint: Optional[bool],
        headers: Optional[dict],
    ) -> Tuple[str, dict]:
        """Build the request URL and headers for a Databricks serving call.

        Explicit arguments are used first. If the key/headers or the base URL
        are missing and this is not a custom endpoint, they are resolved via
        ``_get_databricks_credentials`` (the Databricks SDK).

        Args:
            api_key: Bearer token; when set it overrides any ``Authorization``
                entry in ``headers``.
            api_base: Base URL of the serving endpoint(s).
            endpoint_type: Selects the path suffix appended to ``api_base``.
            custom_endpoint: When truthy, missing credentials are an error
                (no SDK fallback); when exactly ``True``, ``api_base`` is used
                verbatim with no path suffix.
            headers: Extra request headers. The caller's dict is not modified.

        Returns:
            A ``(url, headers)`` tuple.

        Raises:
            DatabricksError: If ``custom_endpoint`` is truthy and the API key
                (and headers) or the base URL is missing, or if SDK fallback is
                required but ``databricks-sdk`` is not installed.
        """
        if api_key is None and headers is None:
            # Truthiness check, consistent with the api_base branch below:
            # custom_endpoint=False means "not custom", so fall back to the SDK.
            if custom_endpoint:
                raise DatabricksError(
                    status_code=400,
                    message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
                )
            api_base, headers = self._get_databricks_credentials(
                api_base=api_base, api_key=api_key, headers=headers
            )

        if api_base is None:
            if custom_endpoint:
                raise DatabricksError(
                    status_code=400,
                    message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_BASE) or via params",
                )
            api_base, headers = self._get_databricks_credentials(
                api_base=api_base, api_key=api_key, headers=headers
            )

        if headers is None:
            # Only reachable with api_key set: the first branch either raised
            # or populated headers when both were None.
            headers = {
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            }
        else:
            # Work on a copy so the caller's dict is never mutated.
            headers = dict(headers)
            if api_key is not None:
                headers["Authorization"] = f"Bearer {api_key}"

        if custom_endpoint is not True:
            if endpoint_type == "chat_completions":
                api_base = f"{api_base}/chat/completions"
            elif endpoint_type == "embeddings":
                api_base = f"{api_base}/embeddings"
        return api_base, headers