import json
from typing import List, Literal, Optional, Union

import httpx

import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,
    get_async_httpx_client,
)
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
    VertexAIError,
    VertexLLM,
)
from litellm.types.llms.vertex_ai import (
    Instance,
    InstanceImage,
    InstanceVideo,
    MultimodalPredictions,
    VertexMultimodalEmbeddingRequest,
)
from litellm.types.utils import Embedding, EmbeddingResponse
from litellm.utils import is_base64_encoded


class VertexMultimodalEmbedding(VertexLLM):
    """
    Handler for Vertex AI multimodal embedding models, exposing sync and
    async entry points that return OpenAI-compatible EmbeddingResponse
    objects.
    """

    def __init__(self) -> None:
        super().__init__()
        self.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS = [
            "multimodalembedding",
            "multimodalembedding@001",
        ]

    def multimodal_embedding(
        self,
        model: str,
        input: Union[list, str],
        print_verbose,
        model_response: EmbeddingResponse,
        custom_llm_provider: Literal["gemini", "vertex_ai"],
        optional_params: dict,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        encoding=None,
        vertex_project=None,
        vertex_location=None,
        vertex_credentials=None,
        aembedding=False,
        timeout=300,
        client=None,
    ) -> EmbeddingResponse:
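        """
        Call a Vertex AI multimodal embedding model and populate the given
        EmbeddingResponse with OpenAI-style embedding objects.

        When aembedding is True, the request is dispatched via
        async_multimodal_embedding and a coroutine is returned for the
        caller to await.
        """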
        # Resolve credentials/project first, then build the bearer token and
        # request URL for the embedding endpoint.
        _auth_header, vertex_project = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider=custom_llm_provider,
        )

        auth_header, url = self._get_token_and_url(
            model=model,
            auth_header=_auth_header,
            gemini_api_key=api_key,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            stream=None,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            should_use_v1beta1_features=False,
            mode="embedding",
        )

        if client is None:
            _params = {}
            if timeout is not None:
                if isinstance(timeout, (int, float)):
                    _httpx_timeout = httpx.Timeout(timeout)
                    _params["timeout"] = _httpx_timeout
                else:
                    _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)

            sync_handler: HTTPHandler = HTTPHandler(**_params)
        else:
            sync_handler = client

        optional_params = optional_params or {}

        request_data = VertexMultimodalEmbeddingRequest()

        # Build the request instances from (in order of precedence):
        # 1. explicit "instances" passed via optional_params,
        # 2. an OpenAI-style list input, converted element by element,
        # 3. a single instance built from optional_params / a raw string input.
        if "instances" in optional_params:
            request_data["instances"] = optional_params["instances"]
        elif isinstance(input, list):
            vertex_instances: List[Instance] = self.process_openai_embedding_input(
                _input=input
            )
            request_data["instances"] = vertex_instances
        else:
            vertex_request_instance = Instance(**optional_params)

            if isinstance(input, str):
                vertex_request_instance = self._process_input_element(input)

            request_data["instances"] = [vertex_request_instance]

        headers = {
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {auth_header}",
        }

        logging_obj.pre_call(
            input=input,
            api_key="",
            additional_args={
                "complete_input_dict": request_data,
                "api_base": url,
                "headers": headers,
            },
        )

        if aembedding is True:
            # Returns a coroutine here; the upstream caller is expected to await it.
            return self.async_multimodal_embedding(  # type: ignore
                model=model,
                api_base=url,
                data=request_data,
                timeout=timeout,
                headers=headers,
                client=client,
                model_response=model_response,
            )

        response = sync_handler.post(
            url=url,
            headers=headers,
            data=json.dumps(request_data),
        )

        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")

        _json_response = response.json()
        if "predictions" not in _json_response:
            raise litellm.InternalServerError(
                message=f"embedding response does not contain 'predictions', got {_json_response}",
                llm_provider="vertex_ai",
                model=model,
            )
        _predictions = _json_response["predictions"]
        vertex_predictions = MultimodalPredictions(predictions=_predictions)
        model_response.data = self.transform_embedding_response_to_openai(
            predictions=vertex_predictions
        )
        model_response.model = model

        return model_response

    async def async_multimodal_embedding(
        self,
        model: str,
        api_base: str,
        data: VertexMultimodalEmbeddingRequest,
        model_response: litellm.EmbeddingResponse,
        timeout: Optional[Union[float, httpx.Timeout]],
        headers={},
        client: Optional[AsyncHTTPHandler] = None,
    ) -> litellm.EmbeddingResponse:
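        """
        Async variant of multimodal_embedding: POST the prepared request to
        Vertex AI and map the returned predictions onto the OpenAI embedding
        format.
        """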
        if client is None:
            if isinstance(timeout, (int, float)):
                timeout = httpx.Timeout(timeout)
            client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.VERTEX_AI,
                params={"timeout": timeout},
            )

        try:
            response = await client.post(api_base, headers=headers, json=data)
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise VertexAIError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise VertexAIError(status_code=408, message="Timeout error occurred.")

        _json_response = response.json()
        if "predictions" not in _json_response:
            raise litellm.InternalServerError(
                message=f"embedding response does not contain 'predictions', got {_json_response}",
                llm_provider="vertex_ai",
                model=model,
            )
        _predictions = _json_response["predictions"]

        vertex_predictions = MultimodalPredictions(predictions=_predictions)
        model_response.data = self.transform_embedding_response_to_openai(
            predictions=vertex_predictions
        )
        model_response.model = model

        return model_response

    def _process_input_element(self, input_element: str) -> Instance:
        """
        Process a single input element for multimodal embedding requests.
        Checks whether the input is a GCS URI, a base64-encoded image, or
        plain text.

        Args:
            input_element (str): The input element to process.

        Returns:
            Instance: The processed input element as a Vertex AI Instance.
        """
        if len(input_element) == 0:
            return Instance(text=input_element)
        elif "gs://" in input_element:
            # Heuristic: treat GCS URIs containing "mp4" as video, else as image.
            if "mp4" in input_element:
                return Instance(video=InstanceVideo(gcsUri=input_element))
            else:
                return Instance(image=InstanceImage(gcsUri=input_element))
        elif is_base64_encoded(s=input_element):
            return Instance(image=InstanceImage(bytesBase64Encoded=input_element))
        else:
            return Instance(text=input_element)

    def process_openai_embedding_input(
        self, _input: Union[list, str]
    ) -> List[Instance]:
        """
        Process the input for multimodal embedding requests.

        Args:
            _input (Union[list, str]): The input data to process.

        Returns:
            List[Instance]: A list of processed Vertex AI Instance objects.
        """
        # Normalize to a list so a single string input is handled uniformly.
        _input_list = _input if isinstance(_input, list) else [_input]

        processed_instances = []
        for element in _input_list:
            if isinstance(element, str):
                instance = self._process_input_element(element)
            elif isinstance(element, dict):
                instance = Instance(**element)
            else:
                raise ValueError(f"Unsupported input type: {type(element)}")
            processed_instances.append(instance)

        return processed_instances

    def transform_embedding_response_to_openai(
        self, predictions: MultimodalPredictions
    ) -> List[Embedding]:
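        """
        Convert Vertex AI multimodal predictions into a list of OpenAI-style
        Embedding objects. Each prediction contributes its textEmbedding or
        imageEmbedding; a prediction with videoEmbeddings contributes one
        Embedding per video segment, all sharing that prediction's index.
        """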
        openai_embeddings: List[Embedding] = []
        if "predictions" in predictions:
            for idx, _prediction in enumerate(predictions["predictions"]):
                if _prediction:
                    if "textEmbedding" in _prediction:
                        openai_embedding_object = Embedding(
                            embedding=_prediction["textEmbedding"],
                            index=idx,
                            object="embedding",
                        )
                        openai_embeddings.append(openai_embedding_object)
                    elif "imageEmbedding" in _prediction:
                        openai_embedding_object = Embedding(
                            embedding=_prediction["imageEmbedding"],
                            index=idx,
                            object="embedding",
                        )
                        openai_embeddings.append(openai_embedding_object)
                    elif "videoEmbeddings" in _prediction:
                        # A single video prediction may contain multiple
                        # per-segment embeddings; all share the same index.
                        for video_embedding in _prediction["videoEmbeddings"]:
                            openai_embedding_object = Embedding(
                                embedding=video_embedding["embedding"],
                                index=idx,
                                object="embedding",
                            )
                            openai_embeddings.append(openai_embedding_object)
        return openai_embeddings
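

# Minimal usage sketch (illustrative only: the model, bucket path, and project
# below are hypothetical, and callers normally reach this handler through
# litellm.embedding() rather than invoking it directly):
#
#     handler = VertexMultimodalEmbedding()
#     response = handler.multimodal_embedding(
#         model="multimodalembedding@001",
#         input=["a photo of a cat", "gs://example-bucket/cat.mp4"],
#         print_verbose=print,
#         model_response=EmbeddingResponse(),
#         custom_llm_provider="vertex_ai",
#         optional_params={},
#         logging_obj=logging_obj,  # a LiteLLMLoggingObj supplied by the caller
#         vertex_project="example-project",
#         vertex_location="us-central1",
#     )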