from typing import Literal, Optional, Union

import httpx

import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObject
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,
    _get_httpx_client,
    get_async_httpx_client,
)
from litellm.llms.vertex_ai.vertex_ai_non_gemini import VertexAIError
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
from litellm.types.llms.vertex_ai import *
from litellm.types.utils import EmbeddingResponse

from .types import *


class VertexEmbedding(VertexBase):
    def __init__(self) -> None:
        super().__init__()

    def embedding(
        self,
        model: str,
        input: Union[list, str],
        print_verbose,
        model_response: EmbeddingResponse,
        optional_params: dict,
        logging_obj: LiteLLMLoggingObject,
        custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
        timeout: Optional[Union[float, httpx.Timeout]],
        api_key: Optional[str] = None,
        encoding=None,
        aembedding=False,
        api_base: Optional[str] = None,
        client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
        vertex_project: Optional[str] = None,
        vertex_location: Optional[str] = None,
        vertex_credentials: Optional[str] = None,
        gemini_api_key: Optional[str] = None,
        extra_headers: Optional[dict] = None,
    ) -> EmbeddingResponse:
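        """
        Synchronous embedding handler for Vertex AI / Gemini text-embedding
        models. When ``aembedding`` is True, dispatches to ``async_embedding``
        and returns the coroutine for the caller to await.
        """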
        if aembedding is True:
            # async path: hand the coroutine back for the caller to await
            return self.async_embedding(  # type: ignore[return-value]
                model=model,
                input=input,
                logging_obj=logging_obj,
                model_response=model_response,
                optional_params=optional_params,
                encoding=encoding,
                custom_llm_provider=custom_llm_provider,
                timeout=timeout,
                api_base=api_base,
                vertex_project=vertex_project,
                vertex_location=vertex_location,
                vertex_credentials=vertex_credentials,
                gemini_api_key=gemini_api_key,
                extra_headers=extra_headers,
            )

        should_use_v1beta1_features = self.is_using_v1beta1_features(
            optional_params=optional_params
        )

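        # Resolve GCP credentials and project, then build the auth header and
        # the fully-qualified endpoint URL for the embedding request.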
        _auth_header, vertex_project = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider=custom_llm_provider,
        )
        auth_header, api_base = self._get_token_and_url(
            model=model,
            gemini_api_key=gemini_api_key,
            auth_header=_auth_header,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            stream=False,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            should_use_v1beta1_features=should_use_v1beta1_features,
            mode="embedding",
        )
        headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
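
        # Translate the OpenAI-style embedding input into Vertex's request schema.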
        vertex_request: VertexEmbeddingRequest = (
            litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
                input=input, optional_params=optional_params, model=model
            )
        )

        # Reuse the caller-supplied sync client when possible; otherwise create one.
        _client_params = {}
        if timeout:
            _client_params["timeout"] = timeout
        if client is None or not isinstance(client, HTTPHandler):
            client = _get_httpx_client(params=_client_params)

        logging_obj.pre_call(
            input=vertex_request,
            api_key="",
            additional_args={
                "complete_input_dict": vertex_request,
                "api_base": api_base,
                "headers": headers,
            },
        )
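
        # POST to the Vertex embedding endpoint; surface HTTP and timeout
        # failures as VertexAIError.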
        try:
            response = client.post(api_base, headers=headers, json=vertex_request)
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise VertexAIError(
                status_code=error_code, message=err.response.text
            ) from err
        except httpx.TimeoutException as err:
            raise VertexAIError(
                status_code=408, message="Timeout error occurred."
            ) from err

        _json_response = response.json()

        logging_obj.post_call(
            input=input, api_key=None, original_response=_json_response
        )
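
        # Map the Vertex response back into an OpenAI-compatible EmbeddingResponse.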
        model_response = (
            litellm.vertexAITextEmbeddingConfig.transform_vertex_response_to_openai(
                response=_json_response, model=model, model_response=model_response
            )
        )

        return model_response

    async def async_embedding(
        self,
        model: str,
        input: Union[list, str],
        model_response: litellm.EmbeddingResponse,
        logging_obj: LiteLLMLoggingObject,
        optional_params: dict,
        custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
        timeout: Optional[Union[float, httpx.Timeout]],
        api_base: Optional[str] = None,
        client: Optional[AsyncHTTPHandler] = None,
        vertex_project: Optional[str] = None,
        vertex_location: Optional[str] = None,
        vertex_credentials: Optional[str] = None,
        gemini_api_key: Optional[str] = None,
        extra_headers: Optional[dict] = None,
        encoding=None,
    ) -> litellm.EmbeddingResponse:
        """
        Async embedding implementation: mirrors ``embedding``, but resolves
        credentials and performs the HTTP call without blocking the event loop.
        """
        should_use_v1beta1_features = self.is_using_v1beta1_features(
            optional_params=optional_params
        )
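        # Resolve GCP credentials and project asynchronously, then build the
        # auth header and endpoint URL.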
        _auth_header, vertex_project = await self._ensure_access_token_async(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider=custom_llm_provider,
        )
        auth_header, api_base = self._get_token_and_url(
            model=model,
            gemini_api_key=gemini_api_key,
            auth_header=_auth_header,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            stream=False,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            should_use_v1beta1_features=should_use_v1beta1_features,
            mode="embedding",
        )
        headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
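
        # Translate the OpenAI-style embedding input into Vertex's request schema.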
        vertex_request: VertexEmbeddingRequest = (
            litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
                input=input, optional_params=optional_params, model=model
            )
        )

        # Reuse the caller-supplied async client when possible; otherwise create one.
        _async_client_params = {}
        if timeout:
            _async_client_params["timeout"] = timeout
        if client is None or not isinstance(client, AsyncHTTPHandler):
            client = get_async_httpx_client(
                params=_async_client_params, llm_provider=litellm.LlmProviders.VERTEX_AI
            )

        logging_obj.pre_call(
            input=vertex_request,
            api_key="",
            additional_args={
                "complete_input_dict": vertex_request,
                "api_base": api_base,
                "headers": headers,
            },
        )
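
        # POST to the Vertex embedding endpoint; surface HTTP and timeout
        # failures as VertexAIError.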
        try:
            response = await client.post(api_base, headers=headers, json=vertex_request)
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise VertexAIError(
                status_code=error_code, message=err.response.text
            ) from err
        except httpx.TimeoutException as err:
            raise VertexAIError(
                status_code=408, message="Timeout error occurred."
            ) from err

        _json_response = response.json()

        logging_obj.post_call(
            input=input, api_key=None, original_response=_json_response
        )
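
        # Map the Vertex response back into an OpenAI-compatible EmbeddingResponse.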
        model_response = (
            litellm.vertexAITextEmbeddingConfig.transform_vertex_response_to_openai(
                response=_json_response, model=model, model_response=model_response
            )
        )

        return model_response