# NOTE: removed non-Python HuggingFace upload-page residue that preceded this
# module ("Raju2024's picture / Upload 1072 files / e3278e4 verified").
# Vertex AI multimodal embedding handler for litellm.
import json
from typing import List, Literal, Optional, Union
import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
get_async_httpx_client,
)
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
VertexAIError,
VertexLLM,
)
from litellm.types.llms.vertex_ai import (
Instance,
InstanceImage,
InstanceVideo,
MultimodalPredictions,
VertexMultimodalEmbeddingRequest,
)
from litellm.types.utils import Embedding, EmbeddingResponse
from litellm.utils import is_base64_encoded
class VertexMultimodalEmbedding(VertexLLM):
    """Vertex AI multimodal embedding handler.

    Converts OpenAI-style embedding inputs (plain text, ``gs://`` image/video
    URIs, base64-encoded images) into the Vertex AI ``instances`` request
    format for the ``multimodalembedding`` models, issues the HTTP request
    (sync or async), and maps the returned predictions onto an
    OpenAI-compatible ``EmbeddingResponse``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Only these model ids are routed through the multimodal endpoint.
        self.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS = [
            "multimodalembedding",
            "multimodalembedding@001",
        ]

    def _transform_embedding_http_response(
        self,
        _json_response: dict,
        model: str,
        model_response: EmbeddingResponse,
    ) -> EmbeddingResponse:
        """Validate a raw Vertex JSON response and populate ``model_response``.

        Shared by the sync and async request paths (this logic was previously
        duplicated verbatim in both).

        Args:
            _json_response (dict): Parsed JSON body from the Vertex endpoint.
            model (str): Model name to stamp onto the response.
            model_response (EmbeddingResponse): Response object to populate.

        Returns:
            EmbeddingResponse: ``model_response`` with ``data`` and ``model`` set.

        Raises:
            litellm.InternalServerError: if the response lacks ``predictions``.
        """
        if "predictions" not in _json_response:
            raise litellm.InternalServerError(
                message=f"embedding response does not contain 'predictions', got {_json_response}",
                llm_provider="vertex_ai",
                model=model,
            )
        vertex_predictions = MultimodalPredictions(
            predictions=_json_response["predictions"]
        )
        model_response.data = self.transform_embedding_response_to_openai(
            predictions=vertex_predictions
        )
        model_response.model = model
        return model_response

    def multimodal_embedding(
        self,
        model: str,
        input: Union[list, str],
        print_verbose,
        model_response: EmbeddingResponse,
        custom_llm_provider: Literal["gemini", "vertex_ai"],
        optional_params: dict,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        encoding=None,
        vertex_project=None,
        vertex_location=None,
        vertex_credentials=None,
        aembedding=False,
        timeout=300,
        client=None,
    ) -> EmbeddingResponse:
        """Entry point for a multimodal embedding request.

        Resolves credentials and the endpoint URL, builds the Vertex
        ``instances`` payload from ``input`` / ``optional_params``, then either
        performs a synchronous HTTP POST or — when ``aembedding`` is True —
        returns the coroutine from ``async_multimodal_embedding``.
        """
        _auth_header, vertex_project = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider=custom_llm_provider,
        )
        auth_header, url = self._get_token_and_url(
            model=model,
            auth_header=_auth_header,
            gemini_api_key=api_key,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            stream=None,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            should_use_v1beta1_features=False,
            mode="embedding",
        )

        if client is None:
            _params = {}
            if timeout is not None:
                if isinstance(timeout, (int, float)):
                    _params["timeout"] = httpx.Timeout(timeout)
                else:
                    # FIX: a caller-provided httpx.Timeout was previously
                    # discarded and replaced by a hard-coded 600s default;
                    # honor it, consistent with the async path.
                    _params["timeout"] = timeout
            sync_handler: HTTPHandler = HTTPHandler(**_params)  # type: ignore
        else:
            sync_handler = client  # type: ignore

        optional_params = optional_params or {}

        request_data = VertexMultimodalEmbeddingRequest()
        if "instances" in optional_params:
            # Caller supplied pre-built Vertex instances; pass them through.
            request_data["instances"] = optional_params["instances"]
        elif isinstance(input, list):
            request_data["instances"] = self.process_openai_embedding_input(
                _input=input
            )
        else:
            # Single non-list input: build one instance, either from the
            # string itself or from the remaining optional params.
            if isinstance(input, str):
                vertex_request_instance = self._process_input_element(input)
            else:
                vertex_request_instance = Instance(**optional_params)
            request_data["instances"] = [vertex_request_instance]

        headers = {
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {auth_header}",
        }

        ## LOGGING
        logging_obj.pre_call(
            input=input,
            api_key="",
            additional_args={
                "complete_input_dict": request_data,
                "api_base": url,
                "headers": headers,
            },
        )

        if aembedding is True:
            return self.async_multimodal_embedding(  # type: ignore
                model=model,
                api_base=url,
                data=request_data,
                timeout=timeout,
                headers=headers,
                client=client,
                model_response=model_response,
            )

        response = sync_handler.post(
            url=url,
            headers=headers,
            data=json.dumps(request_data),
        )

        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")

        return self._transform_embedding_http_response(
            _json_response=response.json(),
            model=model,
            model_response=model_response,
        )

    async def async_multimodal_embedding(
        self,
        model: str,
        api_base: str,
        data: VertexMultimodalEmbeddingRequest,
        model_response: litellm.EmbeddingResponse,
        timeout: Optional[Union[float, httpx.Timeout]],
        headers={},  # NOTE: dict default is safe here — it is never mutated
        client: Optional[AsyncHTTPHandler] = None,
    ) -> litellm.EmbeddingResponse:
        """Async counterpart of the synchronous request path.

        Raises:
            VertexAIError: on a non-2xx HTTP status or a request timeout (408).
        """
        if client is None:
            # Accept a bare number or a pre-built httpx.Timeout.
            if isinstance(timeout, (int, float)):
                timeout = httpx.Timeout(timeout)
            client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.VERTEX_AI,
                params={"timeout": timeout},
            )

        try:
            response = await client.post(api_base, headers=headers, json=data)  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            raise VertexAIError(
                status_code=err.response.status_code, message=err.response.text
            )
        except httpx.TimeoutException:
            raise VertexAIError(status_code=408, message="Timeout error occurred.")

        return self._transform_embedding_http_response(
            _json_response=response.json(),
            model=model,
            model_response=model_response,
        )

    def _process_input_element(self, input_element: str) -> Instance:
        """
        Build a Vertex ``Instance`` from a single input string.

        Classifies the string as a GCS video URI, a GCS image URI, a
        base64-encoded image, or plain text.

        Args:
            input_element (str): The input element to process.

        Returns:
            Instance: The processed input element.
        """
        if len(input_element) == 0:
            # Empty string: treat as (empty) text rather than erroring.
            return Instance(text=input_element)
        elif "gs://" in input_element:
            if "mp4" in input_element:
                return Instance(video=InstanceVideo(gcsUri=input_element))
            else:
                return Instance(image=InstanceImage(gcsUri=input_element))
        elif is_base64_encoded(s=input_element):
            return Instance(image=InstanceImage(bytesBase64Encoded=input_element))
        else:
            return Instance(text=input_element)

    def process_openai_embedding_input(
        self, _input: Union[list, str]
    ) -> List[Instance]:
        """
        Process the input for multimodal embedding requests.

        Args:
            _input (Union[list, str]): The input data to process.

        Returns:
            List[Instance]: A list of processed Vertex AI Instance objects.

        Raises:
            ValueError: if a list element is neither a ``str`` nor a ``dict``.
        """
        _input_list = _input if isinstance(_input, list) else [_input]

        processed_instances: List[Instance] = []
        for element in _input_list:
            if isinstance(element, str):
                # _process_input_element already returns a complete Instance;
                # re-unpacking it through Instance(**...) was redundant.
                instance = self._process_input_element(element)
            elif isinstance(element, dict):
                instance = Instance(**element)
            else:
                raise ValueError(f"Unsupported input type: {type(element)}")
            processed_instances.append(instance)

        return processed_instances

    def transform_embedding_response_to_openai(
        self, predictions: MultimodalPredictions
    ) -> List[Embedding]:
        """Map Vertex multimodal predictions to OpenAI ``Embedding`` objects.

        Each prediction may carry a text embedding, an image embedding, or a
        list of video-segment embeddings; a video prediction expands to one
        ``Embedding`` per segment, all sharing the source prediction's index.
        """
        openai_embeddings: List[Embedding] = []
        for idx, _prediction in enumerate(predictions.get("predictions", [])):
            if not _prediction:
                # Skip empty/None prediction entries.
                continue
            if "textEmbedding" in _prediction:
                openai_embeddings.append(
                    Embedding(
                        embedding=_prediction["textEmbedding"],
                        index=idx,
                        object="embedding",
                    )
                )
            elif "imageEmbedding" in _prediction:
                openai_embeddings.append(
                    Embedding(
                        embedding=_prediction["imageEmbedding"],
                        index=idx,
                        object="embedding",
                    )
                )
            elif "videoEmbeddings" in _prediction:
                for video_embedding in _prediction["videoEmbeddings"]:
                    openai_embeddings.append(
                        Embedding(
                            embedding=video_embedding["embedding"],
                            index=idx,
                            object="embedding",
                        )
                    )
        return openai_embeddings