# TestLLM / litellm / llms / vertex_ai / vertex_ai_non_gemini.py
# Uploaded by Raju2024 ("Upload 1072 files"), commit e3278e4 (verified).
import json
import os
import time
from typing import Any, Callable, Optional, cast
import httpx
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.bedrock.common_utils import ModelResponseIterator
from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
from litellm.types.llms.vertex_ai import *
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
class VertexAIError(Exception):
    """Error raised by the non-Gemini Vertex AI handlers.

    Besides ``status_code`` and ``message`` it exposes synthetic ``httpx``
    ``request`` / ``response`` objects, so it carries the same attributes as an
    HTTP-level provider error even though these calls go through the Vertex AI
    SDK rather than a raw HTTP request.
    """

    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        # Placeholder request describing the Vertex AI service. The URL must not
        # have leading whitespace: with a leading space httpx parses it as a
        # relative path ("%20https://...") instead of an absolute https URL.
        self.request = httpx.Request(
            method="POST", url="https://cloud.google.com/vertex-ai/"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs
class TextStreamer:
    """
    Fake streaming iterator for Vertex AI Model Garden calls.

    The complete response text is split on whitespace and emitted one word per
    step. The sync (``__next__``) and async (``__anext__``) protocols share a
    single cursor, so consuming through one advances the other as well.
    """

    def __init__(self, text):
        # Whitespace-separated words are the unit of "streaming".
        self.text = text.split()
        self.index = 0

    def _exhausted(self):
        """Return True once every word has been handed out."""
        return self.index >= len(self.text)

    def _advance(self):
        """Return the word under the cursor and move the cursor forward."""
        word = self.text[self.index]
        self.index += 1
        return word

    def __iter__(self):
        return self

    def __next__(self):
        if self._exhausted():
            raise StopIteration
        return self._advance()

    def __aiter__(self):
        return self

    async def __anext__(self):
        if self._exhausted():
            # once we run out of data to stream, we raise this error
            raise StopAsyncIteration
        return self._advance()
def _get_client_cache_key(
    model: str, vertex_project: Optional[str], vertex_location: Optional[str]
):
    """Return the in-memory client cache key for a model/project/location triple.

    The three parts are joined with ``-``; ``None`` parts render as ``"None"``.
    """
    return "{}-{}-{}".format(model, vertex_project, vertex_location)
def _get_client_from_cache(client_cache_key: str):
    """Look up a cached Vertex AI client/model object by key.

    Returns whatever ``litellm.in_memory_llm_clients_cache.get_cache`` yields;
    the caller in this file treats ``None`` as a cache miss.
    """
    clients_cache = litellm.in_memory_llm_clients_cache
    return clients_cache.get_cache(client_cache_key)
def _set_client_in_cache(client_cache_key: str, vertex_llm_model: Any):
    """Store a Vertex AI client/model object in litellm's in-memory client cache.

    The entry expires after ``_DEFAULT_TTL_FOR_HTTPX_CLIENTS``.
    """
    clients_cache = litellm.in_memory_llm_clients_cache
    clients_cache.set_cache(
        key=client_cache_key,
        value=vertex_llm_model,
        ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
    )
def completion(  # noqa: PLR0915
    model: str,
    messages: list,
    model_response: ModelResponse,
    print_verbose: Callable,
    encoding,
    logging_obj,
    optional_params: dict,
    vertex_project=None,
    vertex_location=None,
    vertex_credentials=None,
    litellm_params=None,
    logger_fn=None,
    acompletion: bool = False,
):
    """
    NON-GEMINI/ANTHROPIC CALLS.

    This is the handler for OLDER PALM MODELS and VERTEX AI MODEL GARDEN
    For Vertex AI Anthropic: `vertex_anthropic.py`
    For Gemini: `vertex_httpx.py`

    Args:
        model: a Vertex chat/text/code model name; ``"private"`` for a private
            Model Garden endpoint (endpoint id taken from
            ``optional_params["model_id"]``); anything else is treated as a
            public Model Garden endpoint id.
        messages: chat messages. Only string ``content`` values are used; they
            are joined with single spaces into one prompt.
        model_response: response object that is filled in and returned.
        print_verbose: debug print callback.
        encoding: tokenizer used to estimate token usage.
        logging_obj: litellm logging object (``pre_call`` / ``post_call``).
        optional_params: request params. May be mutated in place: config
            defaults are merged in and ``safety_settings``, ``model_id``
            (private mode) and ``stream`` (sync path) are popped.
        vertex_project, vertex_location: passed to ``vertexai.init`` and used
            to build Model Garden endpoint paths.
        vertex_credentials: optional service-account JSON string; otherwise
            ``google.auth.default`` credentials are used.
        litellm_params, logger_fn: accepted for interface compatibility; unused.
        acompletion: when True, return the (un-awaited) coroutine from
            ``async_completion`` or ``async_streaming``.

    Returns:
        The populated ``model_response``. For streaming requests it instead
        returns the SDK's streaming iterator (chat/text models), a
        ``TextStreamer`` (Model Garden endpoints), or a
        ``ModelResponseIterator`` (code text models, which are fake-streamed).

    Raises:
        VertexAIError: the ``vertexai`` SDK is missing or too old, or ``model``
            is a Gemini/vision model that this handler cannot serve.
        litellm.APIConnectionError: any other failure.
    """
    try:
        import vertexai
    except Exception as e:
        raise VertexAIError(
            status_code=400,
            message="vertexai import failed please run `pip install google-cloud-aiplatform`. This is required for the 'vertex_ai/' route on LiteLLM",
        ) from e
    # Only the presence of `vertexai.preview` can be checked safely here. The
    # previous `hasattr(vertexai, "preview") or hasattr(vertexai.preview, ...)`
    # evaluated `vertexai.preview` exactly when it was missing. That raised an
    # AttributeError instead of this error. A missing `language_models` module
    # still surfaces when it is imported below.
    if not hasattr(vertexai, "preview"):
        raise VertexAIError(
            status_code=400,
            message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
        )
    try:
        import google.auth  # type: ignore
        from google.cloud import aiplatform  # type: ignore
        from google.cloud.aiplatform_v1beta1.types import (
            content as gapic_content_types,  # type: ignore
        )
        from google.protobuf import json_format  # type: ignore
        from google.protobuf.struct_pb2 import Value  # type: ignore
        from vertexai.language_models import CodeGenerationModel, TextGenerationModel
        from vertexai.preview.language_models import ChatModel, CodeChatModel

        ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
        print_verbose(
            f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
        )
        _cache_key = _get_client_cache_key(
            model=model, vertex_project=vertex_project, vertex_location=vertex_location
        )
        # NOTE(review): this function never calls _set_client_in_cache, so the
        # lookup only hits if another code path stored an object under this key.
        _vertex_llm_model_object = _get_client_from_cache(client_cache_key=_cache_key)

        if _vertex_llm_model_object is None:
            from google.auth.credentials import Credentials

            if vertex_credentials is not None and isinstance(vertex_credentials, str):
                import google.oauth2.service_account

                json_obj = json.loads(vertex_credentials)

                creds = (
                    google.oauth2.service_account.Credentials.from_service_account_info(
                        json_obj,
                        scopes=["https://www.googleapis.com/auth/cloud-platform"],
                    )
                )
            else:
                creds, _ = google.auth.default(quota_project_id=vertex_project)
            print_verbose(
                f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
            )
            vertexai.init(
                project=vertex_project,
                location=vertex_location,
                credentials=cast(Credentials, creds),
            )

        ## Load Config
        # Config defaults only fill keys the caller did not set explicitly.
        config = litellm.VertexAIConfig.get_config()
        for k, v in config.items():
            if k not in optional_params:
                optional_params[k] = v

        ## Process safety settings into format expected by vertex AI
        safety_settings = None
        if "safety_settings" in optional_params:
            safety_settings = optional_params.pop("safety_settings")
            if not isinstance(safety_settings, list):
                raise ValueError("safety_settings must be a list")
            if len(safety_settings) > 0 and not isinstance(safety_settings[0], dict):
                raise ValueError("safety_settings must be a list of dicts")
            safety_settings = [
                gapic_content_types.SafetySetting(x) for x in safety_settings
            ]

        # vertexai does not use an API key, it looks for credentials.json in the environment

        prompt = " ".join(
            [
                message.get("content")
                for message in messages
                if isinstance(message.get("content", None), str)
            ]
        )

        mode = ""

        request_str = ""
        instances = None
        client_options = {
            "api_endpoint": f"{vertex_location}-aiplatform.googleapis.com"
        }
        fake_stream = False
        llm_model: Any = None

        if (
            model in litellm.vertex_language_models
            or model in litellm.vertex_vision_models
        ):
            # GenerativeModel (Gemini / vision) models have no request path in
            # this handler. Previously they fell through with no response and
            # returned the literal string "None" as the message content.
            raise VertexAIError(
                status_code=400,
                message=f"vertex_ai: model={model} is a Gemini/vision model and is not supported by the non-Gemini handler (vertex_ai_non_gemini.py); it is served by the Gemini handler (vertex_httpx.py).",
            )

        if model in litellm.vertex_chat_models:
            llm_model = _vertex_llm_model_object or ChatModel.from_pretrained(model)
            mode = "chat"
            request_str += f"llm_model = ChatModel.from_pretrained({model})\n"
        elif model in litellm.vertex_text_models:
            llm_model = _vertex_llm_model_object or TextGenerationModel.from_pretrained(
                model
            )
            mode = "text"
            request_str += f"llm_model = TextGenerationModel.from_pretrained({model})\n"
        elif model in litellm.vertex_code_text_models:
            llm_model = _vertex_llm_model_object or CodeGenerationModel.from_pretrained(
                model
            )
            mode = "text"
            request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n"
            # Code text models are called without streaming; the full response
            # is wrapped in a ModelResponseIterator at the end.
            fake_stream = True
        elif model in litellm.vertex_code_chat_models:  # vertex_code_llm_models
            llm_model = _vertex_llm_model_object or CodeChatModel.from_pretrained(model)
            mode = "chat"
            request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"
        elif model == "private":
            mode = "private"
            model = optional_params.pop("model_id", None)
            # private endpoint requires a dict instead of JSON
            instances = [optional_params.copy()]
            instances[0]["prompt"] = prompt
            llm_model = aiplatform.PrivateEndpoint(
                endpoint_name=model,
                project=vertex_project,
                location=vertex_location,
            )
            request_str += f"llm_model = aiplatform.PrivateEndpoint(endpoint_name={model}, project={vertex_project}, location={vertex_location})\n"
        else:  # assume vertex model garden on public endpoint
            mode = "custom"

            instances = [optional_params.copy()]
            instances[0]["prompt"] = prompt
            instances = [
                json_format.ParseDict(instance_dict, Value())
                for instance_dict in instances
            ]
            # Will determine the API used based on async parameter
            llm_model = None

        # NOTE: async prediction and streaming under "private" mode isn't supported by aiplatform right now
        if acompletion is True:
            data = {
                "llm_model": llm_model,
                "mode": mode,
                "prompt": prompt,
                "logging_obj": logging_obj,
                "request_str": request_str,
                "model": model,
                "model_response": model_response,
                "encoding": encoding,
                "messages": messages,
                "print_verbose": print_verbose,
                "client_options": client_options,
                "instances": instances,
                "vertex_location": vertex_location,
                "vertex_project": vertex_project,
                "safety_settings": safety_settings,
                **optional_params,
            }
            if optional_params.get("stream", False) is True:
                # async streaming
                return async_streaming(**data)

            return async_completion(**data)

        completion_response = None

        stream = optional_params.pop(
            "stream", None
        )  # See note above on handling streaming for vertex ai
        if mode == "chat":
            chat = llm_model.start_chat()
            request_str += "chat = llm_model.start_chat()\n"

            if fake_stream is not True and stream is True:
                # NOTE: VertexAI does not accept stream=True as a param and raises an error,
                # we handle this by removing 'stream' from optional params and sending the request
                # after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format
                optional_params.pop(
                    "stream", None
                )  # vertex ai raises an error when passing stream in optional params

                request_str += (
                    f"chat.send_message_streaming({prompt}, **{optional_params})\n"
                )
                ## LOGGING
                logging_obj.pre_call(
                    input=prompt,
                    api_key=None,
                    additional_args={
                        "complete_input_dict": optional_params,
                        "request_str": request_str,
                    },
                )

                model_response = chat.send_message_streaming(prompt, **optional_params)

                return model_response

            request_str += f"chat.send_message({prompt}, **{optional_params}).text\n"
            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            completion_response = chat.send_message(prompt, **optional_params).text
        elif mode == "text":
            if fake_stream is not True and stream is True:
                request_str += (
                    f"llm_model.predict_streaming({prompt}, **{optional_params})\n"
                )
                ## LOGGING
                logging_obj.pre_call(
                    input=prompt,
                    api_key=None,
                    additional_args={
                        "complete_input_dict": optional_params,
                        "request_str": request_str,
                    },
                )
                model_response = llm_model.predict_streaming(prompt, **optional_params)

                return model_response

            request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            completion_response = llm_model.predict(prompt, **optional_params).text
        elif mode == "custom":
            """
            Vertex AI Model Garden
            """

            if vertex_project is None or vertex_location is None:
                raise ValueError(
                    "Vertex project and location are required for custom endpoint"
                )

            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            llm_model = aiplatform.gapic.PredictionServiceClient(
                client_options=client_options
            )
            request_str += f"llm_model = aiplatform.gapic.PredictionServiceClient(client_options={client_options})\n"
            endpoint_path = llm_model.endpoint_path(
                project=vertex_project, location=vertex_location, endpoint=model
            )
            request_str += (
                f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
            )
            response = llm_model.predict(
                endpoint=endpoint_path, instances=instances
            ).predictions

            completion_response = response[0]
            # Some Model Garden deployments echo the prompt before "\nOutput:\n";
            # keep only the generated part.
            if (
                isinstance(completion_response, str)
                and "\nOutput:\n" in completion_response
            ):
                completion_response = completion_response.split("\nOutput:\n", 1)[1]
            if stream is True:
                response = TextStreamer(completion_response)
                return response
        elif mode == "private":
            """
            Vertex AI Model Garden deployed on private endpoint
            """
            if instances is None:
                raise ValueError("instances are required for private endpoint")
            if llm_model is None:
                raise ValueError("Unable to pick client for private endpoint")
            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            request_str += f"llm_model.predict(instances={instances})\n"
            response = llm_model.predict(instances=instances).predictions

            completion_response = response[0]
            if (
                isinstance(completion_response, str)
                and "\nOutput:\n" in completion_response
            ):
                completion_response = completion_response.split("\nOutput:\n", 1)[1]
            if stream is True:
                response = TextStreamer(completion_response)
                return response

        ## LOGGING
        logging_obj.post_call(
            input=prompt, api_key=None, original_response=completion_response
        )

        ## RESPONSE OBJECT
        if isinstance(completion_response, litellm.Message):
            model_response.choices[0].message = completion_response  # type: ignore
        elif len(str(completion_response)) > 0:
            model_response.choices[0].message.content = str(completion_response)  # type: ignore
        model_response.created = int(time.time())
        model_response.model = model

        ## CALCULATING USAGE
        # The sync path never keeps a raw SDK response carrying usage metadata,
        # so usage is always estimated with the supplied tokenizer.
        prompt_tokens = len(encoding.encode(prompt))
        # `content` is left as None when the model returned an empty string;
        # count that as "" instead of passing None to the tokenizer.
        completion_content = (
            model_response["choices"][0]["message"].get("content") or ""
        )
        completion_tokens = len(encoding.encode(completion_content))
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        setattr(model_response, "usage", usage)

        if fake_stream is True and stream is True:
            return ModelResponseIterator(model_response)
        return model_response
    except Exception as e:
        if isinstance(e, VertexAIError):
            raise e
        raise litellm.APIConnectionError(
            message=str(e), llm_provider="vertex_ai", model=model
        ) from e
async def async_completion(  # noqa: PLR0915
    llm_model,
    mode: str,
    prompt: str,
    model: str,
    messages: list,
    model_response: ModelResponse,
    request_str: str,
    print_verbose: Callable,
    logging_obj,
    encoding,
    client_options=None,
    instances=None,
    vertex_project=None,
    vertex_location=None,
    safety_settings=None,
    **optional_params,
):
    """
    Async (non-streaming) completion for the non-Gemini Vertex AI models.

    Handles the modes chosen by ``completion``: "chat" / "text" (SDK model
    objects), "custom" (public Model Garden endpoint via
    ``PredictionServiceAsyncClient``) and "private" (private Model Garden
    endpoint via ``llm_model.predict_async``).

    Args:
        llm_model: SDK model / endpoint object (None for "custom" mode).
        mode: one of "chat", "text", "custom", "private".
        prompt: single prompt string built from the messages.
        model: model name, or endpoint id for Model Garden modes.
        model_response: response object that is filled in and returned.
        request_str: human-readable trace of the SDK calls, used for logging.
        logging_obj: litellm logging object.
        encoding: tokenizer used when no usage metadata is available.
        client_options, instances, vertex_project, vertex_location: Model
            Garden request details prepared by ``completion``.
        messages, print_verbose, safety_settings: accepted for interface
            compatibility; unused here.
        **optional_params: forwarded to the SDK call for chat/text modes.

    Returns:
        ``model_response`` with content, ``created``, ``model`` and ``usage`` set.

    Raises:
        VertexAIError: any failure (non-VertexAIError exceptions are wrapped
            with status 500).
    """
    try:
        # The sync path strips `stream` before calling the SDK, which rejects
        # it; do the same here in case it was forwarded (e.g. stream=False).
        optional_params.pop("stream", None)
        response_obj = None
        completion_response = None
        if mode == "chat":
            # chat-bison etc.
            chat = llm_model.start_chat()
            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            response_obj = await chat.send_message_async(prompt, **optional_params)
            completion_response = response_obj.text
        elif mode == "text":
            # gecko etc.
            request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            response_obj = await llm_model.predict_async(prompt, **optional_params)
            completion_response = response_obj.text
        elif mode == "custom":
            """
            Vertex AI Model Garden
            """
            from google.cloud import aiplatform  # type: ignore

            if vertex_project is None or vertex_location is None:
                raise ValueError(
                    "Vertex project and location are required for custom endpoint"
                )

            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            llm_model = aiplatform.gapic.PredictionServiceAsyncClient(
                client_options=client_options
            )
            request_str += f"llm_model = aiplatform.gapic.PredictionServiceAsyncClient(client_options={client_options})\n"
            endpoint_path = llm_model.endpoint_path(
                project=vertex_project, location=vertex_location, endpoint=model
            )
            request_str += (
                f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
            )
            response_obj = await llm_model.predict(
                endpoint=endpoint_path,
                instances=instances,
            )
            response = response_obj.predictions
            completion_response = response[0]
            # Some Model Garden deployments echo the prompt before "\nOutput:\n";
            # keep only the generated part.
            if (
                isinstance(completion_response, str)
                and "\nOutput:\n" in completion_response
            ):
                completion_response = completion_response.split("\nOutput:\n", 1)[1]

        elif mode == "private":
            # Same validation and logging as the sync private-endpoint path.
            if instances is None:
                raise ValueError("instances are required for private endpoint")
            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            request_str += f"llm_model.predict_async(instances={instances})\n"
            response_obj = await llm_model.predict_async(
                instances=instances,
            )

            response = response_obj.predictions
            completion_response = response[0]
            if (
                isinstance(completion_response, str)
                and "\nOutput:\n" in completion_response
            ):
                completion_response = completion_response.split("\nOutput:\n", 1)[1]
        else:
            # Without this, an unhandled mode (e.g. "vision") produced the
            # literal string "None" as the message content.
            raise ValueError(f"Unsupported mode for async completion: {mode!r}")

        ## LOGGING
        logging_obj.post_call(
            input=prompt, api_key=None, original_response=completion_response
        )

        ## RESPONSE OBJECT
        if isinstance(completion_response, litellm.Message):
            model_response.choices[0].message = completion_response  # type: ignore
        elif len(str(completion_response)) > 0:
            model_response.choices[0].message.content = str(  # type: ignore
                completion_response
            )
        model_response.created = int(time.time())
        model_response.model = model
        ## CALCULATING USAGE
        if model in litellm.vertex_language_models and response_obj is not None:
            model_response.choices[0].finish_reason = map_finish_reason(
                response_obj.candidates[0].finish_reason.name
            )
            usage = Usage(
                prompt_tokens=response_obj.usage_metadata.prompt_token_count,
                completion_tokens=response_obj.usage_metadata.candidates_token_count,
                total_tokens=response_obj.usage_metadata.total_token_count,
            )
        else:
            # Use usage metadata from the SDK response when present, otherwise
            # estimate with the supplied tokenizer.
            prompt_tokens, completion_tokens = 0, 0
            if response_obj is not None and (
                hasattr(response_obj, "usage_metadata")
                and hasattr(response_obj.usage_metadata, "prompt_token_count")
            ):
                prompt_tokens = response_obj.usage_metadata.prompt_token_count
                completion_tokens = response_obj.usage_metadata.candidates_token_count
            else:
                prompt_tokens = len(encoding.encode(prompt))
                # `content` is left as None when the model returned an empty
                # string; count that as "" instead of passing None to the tokenizer.
                completion_content = (
                    model_response["choices"][0]["message"].get("content") or ""
                )
                completion_tokens = len(encoding.encode(completion_content))

            # set usage
            usage = Usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            )
        setattr(model_response, "usage", usage)
        return model_response
    except VertexAIError:
        # Already carries a meaningful status code; don't re-wrap as a 500.
        raise
    except Exception as e:
        raise VertexAIError(status_code=500, message=str(e)) from e
async def async_streaming(  # noqa: PLR0915
    llm_model,
    mode: str,
    prompt: str,
    model: str,
    model_response: ModelResponse,
    messages: list,
    print_verbose: Callable,
    logging_obj,
    request_str: str,
    encoding=None,
    client_options=None,
    instances=None,
    vertex_project=None,
    vertex_location=None,
    safety_settings=None,
    **optional_params,
):
    """
    Async streaming for the non-Gemini Vertex AI models.

    "chat" / "text" modes use the SDK's async streaming methods. "custom" and
    "private" Model Garden endpoints are called once; when ``stream`` is truthy
    the resulting text is re-emitted word by word through ``TextStreamer``.

    Args:
        llm_model: SDK model / endpoint object (None for "custom" mode).
        mode: one of "chat", "text", "custom", "private".
        prompt: single prompt string built from the messages.
        model: model name, or endpoint id for Model Garden modes.
        logging_obj: litellm logging object.
        request_str: human-readable trace of the SDK calls, used for logging.
        client_options, instances, vertex_project, vertex_location: Model
            Garden request details prepared by ``completion``.
        model_response, messages, print_verbose, encoding, safety_settings:
            accepted for interface compatibility; unused here.
        **optional_params: forwarded to the SDK call for chat/text modes
            (``stream`` is removed first).

    Returns:
        A ``CustomStreamWrapper`` around the stream.

    Raises:
        ValueError: missing project/location ("custom"), missing instances
            ("private"), or no response produced (e.g. an unsupported mode).
    """
    response: Any = None
    if mode == "chat":
        chat = llm_model.start_chat()
        optional_params.pop(
            "stream", None
        )  # vertex ai raises an error when passing stream in optional params
        request_str += (
            f"chat.send_message_streaming_async({prompt}, **{optional_params})\n"
        )
        ## LOGGING
        logging_obj.pre_call(
            input=prompt,
            api_key=None,
            additional_args={
                "complete_input_dict": optional_params,
                "request_str": request_str,
            },
        )
        response = chat.send_message_streaming_async(prompt, **optional_params)

    elif mode == "text":
        optional_params.pop(
            "stream", None
        )  # See note above on handling streaming for vertex ai
        request_str += (
            f"llm_model.predict_streaming_async({prompt}, **{optional_params})\n"
        )
        ## LOGGING
        logging_obj.pre_call(
            input=prompt,
            api_key=None,
            additional_args={
                "complete_input_dict": optional_params,
                "request_str": request_str,
            },
        )
        response = llm_model.predict_streaming_async(prompt, **optional_params)
    elif mode == "custom":
        from google.cloud import aiplatform  # type: ignore

        if vertex_project is None or vertex_location is None:
            raise ValueError(
                "Vertex project and location are required for custom endpoint"
            )

        stream = optional_params.pop("stream", None)

        ## LOGGING
        logging_obj.pre_call(
            input=prompt,
            api_key=None,
            additional_args={
                "complete_input_dict": optional_params,
                "request_str": request_str,
            },
        )
        llm_model = aiplatform.gapic.PredictionServiceAsyncClient(
            client_options=client_options
        )
        request_str += f"llm_model = aiplatform.gapic.PredictionServiceAsyncClient(client_options={client_options})\n"
        endpoint_path = llm_model.endpoint_path(
            project=vertex_project, location=vertex_location, endpoint=model
        )
        request_str += (
            f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
        )
        response_obj = await llm_model.predict(
            endpoint=endpoint_path,
            instances=instances,
        )

        response = response_obj.predictions
        completion_response = response[0]
        # Some Model Garden deployments echo the prompt before "\nOutput:\n";
        # keep only the generated part.
        if (
            isinstance(completion_response, str)
            and "\nOutput:\n" in completion_response
        ):
            completion_response = completion_response.split("\nOutput:\n", 1)[1]
        if stream:
            # The endpoint returns the whole text at once; fake-stream it.
            response = TextStreamer(completion_response)

    elif mode == "private":
        if instances is None:
            raise ValueError("Instances are required for private endpoint")
        stream = optional_params.pop("stream", None)
        # `instances[0]` was copied from optional_params before `stream` was
        # popped, so strip it from the request payload too.
        _ = instances[0].pop("stream", None)
        request_str += f"llm_model.predict_async(instances={instances})\n"
        ## LOGGING
        # Every other branch (and the sync private path) logs before calling
        # the model; the private streaming path previously skipped this.
        logging_obj.pre_call(
            input=prompt,
            api_key=None,
            additional_args={
                "complete_input_dict": optional_params,
                "request_str": request_str,
            },
        )
        response_obj = await llm_model.predict_async(
            instances=instances,
        )
        response = response_obj.predictions
        completion_response = response[0]
        if (
            isinstance(completion_response, str)
            and "\nOutput:\n" in completion_response
        ):
            completion_response = completion_response.split("\nOutput:\n", 1)[1]
        if stream:
            response = TextStreamer(completion_response)

    if response is None:
        raise ValueError("Unable to generate response")

    logging_obj.post_call(input=prompt, api_key=None, original_response=response)

    streamwrapper = CustomStreamWrapper(
        completion_stream=response,
        model=model,
        custom_llm_provider="vertex_ai",
        logging_obj=logging_obj,
    )

    return streamwrapper