Raju2024's picture
Upload 1072 files
e3278e4 verified
# What is this?
## API Handler for calling Vertex AI Partner Models
from enum import Enum
from typing import Callable, Optional, Union
import httpx # type: ignore
import litellm
from litellm import LlmProviders
from litellm.utils import ModelResponse
from ..vertex_llm_base import VertexBase
class VertexPartnerProvider(str, Enum):
    """Partner-model families that this module can route to on Vertex AI.

    Mixes in ``str`` so each member compares equal to its raw value
    (e.g. ``VertexPartnerProvider.claude == "claude"``).
    """

    mistralai = "mistralai"  # Mistral models; Codestral is routed here too
    llama = "llama"  # Llama models (served via an OpenAI-compatible endpoint)
    ai21 = "ai21"  # AI21 models such as Jamba
    claude = "claude"  # Claude models, published under "anthropic" on Vertex
class VertexAIError(Exception):
    """Error raised by the Vertex AI partner-model handler.

    Carries an HTTP-style ``status_code`` together with placeholder ``httpx``
    request/response objects, so it can be inspected like an HTTP error even
    though no real HTTP exchange is attached.

    Args:
        status_code: Status code to report (this module uses 400 for setup /
            unsupported-input errors and 500 for unexpected failures).
        message: Human-readable error description; also the exception's str().
    """

    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        # Synthetic request used only to build ``self.response``. The URL must
        # not carry leading whitespace, or it is no longer an absolute URL.
        self.request = httpx.Request(
            method="POST", url="https://cloud.google.com/vertex-ai/"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(self.message)
def create_vertex_url(
    vertex_location: str,
    vertex_project: str,
    partner: VertexPartnerProvider,
    stream: Optional[bool],
    model: str,
    api_base: Optional[str] = None,
) -> str:
    """Return the full Vertex AI endpoint URL for a partner-model request.

    Args:
        vertex_location: Vertex AI region; used both as the regional host prefix
            and in the ``locations/`` path segment.
        vertex_project: Google Cloud project id.
        partner: Partner family. A ``VertexPartnerProvider`` member or a plain
            string equal to one of its values.
        stream: When truthy, target ``:streamRawPredict`` instead of
            ``:rawPredict``. Ignored for llama, which always uses the
            OpenAI-compatible ``chat/completions`` endpoint.
        model: Model name, inserted verbatim into the publisher model path.
        api_base: Accepted for interface compatibility; currently unused.

    Returns:
        The endpoint URL.

    Raises:
        ValueError: If ``partner`` is not a supported partner provider.
    """
    # Members of the str-enum expose ``.value``; plain strings are used as-is.
    partner_name = getattr(partner, "value", partner)
    host = f"https://{vertex_location}-aiplatform.googleapis.com"
    location_path = f"projects/{vertex_project}/locations/{vertex_location}"

    if partner_name == "llama":
        # Llama is served through the OpenAI-compatible chat endpoint,
        # for both streaming and non-streaming requests.
        return f"{host}/v1beta1/{location_path}/endpoints/openapi/chat/completions"

    # partner -> (API version, publisher segment of the model resource path)
    publisher_routes = {
        "mistralai": ("v1", "mistralai"),
        "ai21": ("v1beta1", "ai21"),
        "claude": ("v1", "anthropic"),
    }
    if partner_name not in publisher_routes:
        raise ValueError(f"Unsupported Vertex AI partner provider: {partner!r}")

    api_version, publisher = publisher_routes[partner_name]
    method = "streamRawPredict" if stream else "rawPredict"
    return (
        f"{host}/{api_version}/{location_path}"
        f"/publishers/{publisher}/models/{model}:{method}"
    )
class VertexAIPartnerModels(VertexBase):
    """Completion handler for partner models served through Vertex AI.

    Obtains a Vertex access token, builds the partner-specific endpoint with
    ``create_vertex_url`` and delegates the request to a provider handler:
    Codestral FIM text completion, Anthropic chat (Claude), or an
    OpenAI-compatible chat handler for everything else.
    """

    def __init__(self) -> None:
        # Deliberately a no-op: ``VertexBase.__init__`` is not invoked.
        # NOTE(review): confirm no base-class state is needed by
        # ``_check_custom_proxy``, which is called on ``self`` below.
        pass

    @staticmethod
    def _get_partner_provider(model: str) -> VertexPartnerProvider:
        """Infer the partner family from substrings of ``model``.

        Checks run in a fixed order (llama, mistral/codestral, jamba, claude)
        and the first match wins.

        Raises:
            VertexAIError: (400) if ``model`` matches no supported partner.
        """
        if "llama" in model:
            return VertexPartnerProvider.llama
        if "mistral" in model or "codestral" in model:
            return VertexPartnerProvider.mistralai
        if "jamba" in model:
            return VertexPartnerProvider.ai21
        if "claude" in model:
            return VertexPartnerProvider.claude
        raise VertexAIError(
            status_code=400,
            message=(
                f"Unsupported Vertex AI partner model: {model!r}. Expected a "
                "llama, mistral/codestral, jamba or claude model."
            ),
        )

    def completion(
        self,
        model: str,
        messages: list,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        logging_obj,
        api_base: Optional[str],
        optional_params: dict,
        custom_prompt_dict: dict,
        headers: Optional[dict],
        timeout: Union[float, httpx.Timeout],
        litellm_params: dict,
        vertex_project=None,
        vertex_location=None,
        vertex_credentials=None,
        logger_fn=None,
        acompletion: bool = False,
        client=None,
    ):
        """Run a completion against a Vertex AI partner model.

        The partner family is inferred from ``model`` and the request is
        delegated to:

        * ``CodestralTextCompletion`` when ``model`` contains "codestral" and
          ``litellm_params["text_completion"] is True``;
        * ``AnthropicChatCompletion`` when ``model`` contains "claude";
        * ``OpenAILikeChatHandler`` otherwise.

        Args:
            model: Partner model name. An ``@<version>`` suffix is kept when
                building the default URL but stripped before delegating.
            messages: Chat messages forwarded to the delegated handler.
            model_response: Response object forwarded to the chat handlers (the
                Codestral text path builds its own ``TextCompletionResponse``).
            api_base: Optional custom base URL, resolved via
                ``_check_custom_proxy``.
            optional_params: Request parameters, mutated in place: ``"stream"``
                is normalised to a bool; the Codestral text path sets
                ``"model"``; the Claude path adds ``"anthropic_version"`` and
                ``"is_vertex_request"``.
            headers: Extra request headers. For Claude, a copy with an
                ``Authorization: Bearer <token>`` header is sent; the caller's
                dict is not modified.
            vertex_project: Project id; falls back to the project returned
                together with the access token.
            vertex_location: Region; defaults to ``"us-central1"``.
            vertex_credentials: Credentials used to obtain the access token.
            The remaining arguments are passed through to the delegated handler.

        Returns:
            Whatever the delegated handler returns.

        Raises:
            VertexAIError: 400 if ``vertexai`` cannot be imported, lacks
                ``preview``, or the model family is unsupported. Exceptions that
                carry a ``status_code`` propagate unchanged; any other exception
                is wrapped as a 500.
        """
        try:
            import vertexai
            from litellm.llms.anthropic.chat import AnthropicChatCompletion
            from litellm.llms.codestral.completion.handler import (
                CodestralTextCompletion,
            )
            from litellm.llms.openai_like.chat.handler import OpenAILikeChatHandler
            from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
                VertexLLM,
            )
        except Exception as e:
            raise VertexAIError(
                status_code=400,
                message=f"""vertexai import failed please run `pip install -U "google-cloud-aiplatform>=1.38"`. Got error: {e}""",
            ) from e

        # Only ``vertexai.preview`` is required here. A single ``hasattr`` avoids
        # dereferencing ``vertexai.preview`` when it is missing, which would
        # raise a bare AttributeError instead of this actionable error.
        if not hasattr(vertexai, "preview"):
            raise VertexAIError(
                status_code=400,
                message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
            )

        try:
            vertex_httpx_logic = VertexLLM()
            access_token, project_id = vertex_httpx_logic._ensure_access_token(
                credentials=vertex_credentials,
                project_id=vertex_project,
                custom_llm_provider="vertex_ai",
            )

            openai_like_chat_completions = OpenAILikeChatHandler()
            codestral_fim_completions = CodestralTextCompletion()
            anthropic_chat_completions = AnthropicChatCompletion()

            ## CONSTRUCT API BASE
            # Normalise so downstream handlers always see a real bool.
            stream: bool = optional_params.get("stream", False) or False
            optional_params["stream"] = stream

            partner = self._get_partner_provider(model)

            default_api_base = create_vertex_url(
                vertex_location=vertex_location or "us-central1",
                vertex_project=vertex_project or project_id,
                partner=partner,
                stream=stream,
                model=model,
            )

            # NOTE(review): the URL always contains the scheme's ":", so this
            # condition is always true; for llama URLs (no ":method" suffix)
            # ``endpoint`` ends up as the scheme-less URL rather than "".
            # Confirm what ``_check_custom_proxy`` expects before changing it.
            if len(default_api_base.split(":")) > 1:
                endpoint = default_api_base.split(":")[-1]
            else:
                endpoint = ""

            _, api_base = self._check_custom_proxy(
                api_base=api_base,
                custom_llm_provider="vertex_ai",
                gemini_api_key=None,
                endpoint=endpoint,
                stream=stream,
                auth_header=None,
                url=default_api_base,
            )

            # Drop the "@<version>" suffix only after the URL has been built.
            model = model.split("@")[0]

            if "codestral" in model and litellm_params.get("text_completion") is True:
                optional_params["model"] = model
                text_completion_model_response = litellm.TextCompletionResponse(
                    stream=stream
                )
                return codestral_fim_completions.completion(
                    model=model,
                    messages=messages,
                    api_base=api_base,
                    api_key=access_token,
                    custom_prompt_dict=custom_prompt_dict,
                    model_response=text_completion_model_response,
                    print_verbose=print_verbose,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    acompletion=acompletion,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    timeout=timeout,
                    encoding=encoding,
                )
            elif "claude" in model:
                # Build a new dict so the bearer token is not written into the
                # caller-owned headers mapping.
                headers = {
                    **(headers or {}),
                    "Authorization": "Bearer {}".format(access_token),
                }
                optional_params.update(
                    {
                        "anthropic_version": "vertex-2023-10-16",
                        "is_vertex_request": True,
                    }
                )
                return anthropic_chat_completions.completion(
                    model=model,
                    messages=messages,
                    api_base=api_base,
                    acompletion=acompletion,
                    custom_prompt_dict=litellm.custom_prompt_dict,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    encoding=encoding,  # for calculating input/output tokens
                    api_key=access_token,
                    logging_obj=logging_obj,
                    headers=headers,
                    timeout=timeout,
                    client=client,
                    custom_llm_provider=LlmProviders.VERTEX_AI.value,
                )

            return openai_like_chat_completions.completion(
                model=model,
                messages=messages,
                api_base=api_base,
                api_key=access_token,
                custom_prompt_dict=custom_prompt_dict,
                model_response=model_response,
                print_verbose=print_verbose,
                logging_obj=logging_obj,
                optional_params=optional_params,
                acompletion=acompletion,
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                client=client,
                timeout=timeout,
                encoding=encoding,
                custom_llm_provider="vertex_ai",
                custom_endpoint=True,
            )
        except Exception as e:
            # Errors that already carry an HTTP status are propagated as-is
            # (bare ``raise`` keeps the original traceback).
            if hasattr(e, "status_code"):
                raise
            raise VertexAIError(status_code=500, message=str(e)) from e