import json
import os
import time
from typing import Any, Callable, Optional, cast

import httpx

import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.bedrock.common_utils import ModelResponseIterator
from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
from litellm.types.llms.vertex_ai import *
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage


class VertexAIError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
            method="POST", url="https://cloud.google.com/vertex-ai/"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(self.message)


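# Note (illustrative only): the handlers below raise VertexAIError to surface
# setup/SDK failures with an HTTP-style status code, e.g.
#
#   raise VertexAIError(status_code=400, message="vertexai import failed ...")
#
# The attached httpx.Request / httpx.Response are there so upstream error
# handling can treat it like any other provider HTTP error.

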
class TextStreamer:
    """
    Fake streaming iterator for Vertex AI Model Garden calls
    """

    def __init__(self, text):
        self.text = text.split()
        self.index = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.index < len(self.text):
            result = self.text[self.index]
            self.index += 1
            return result
        else:
            raise StopIteration

    def __aiter__(self):
        return self

    async def __anext__(self):
        if self.index < len(self.text):
            result = self.text[self.index]
            self.index += 1
            return result
        else:
            raise StopAsyncIteration


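# Minimal usage sketch for TextStreamer (illustrative only, not executed at
# import time):
#
#   for chunk in TextStreamer("the quick brown fox"):        # sync: "the", "quick", ...
#       print(chunk)
#
#   async for chunk in TextStreamer("hello world"):          # async protocol
#       print(chunk)
#
# The "custom" and "private" branches below wrap a fully-materialized Model
# Garden response in TextStreamer so callers that requested stream=True still
# receive an iterator.

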
def _get_client_cache_key(
    model: str, vertex_project: Optional[str], vertex_location: Optional[str]
):
    _cache_key = f"{model}-{vertex_project}-{vertex_location}"
    return _cache_key


def _get_client_from_cache(client_cache_key: str):
    return litellm.in_memory_llm_clients_cache.get_cache(client_cache_key)


def _set_client_in_cache(client_cache_key: str, vertex_llm_model: Any):
    litellm.in_memory_llm_clients_cache.set_cache(
        key=client_cache_key,
        value=vertex_llm_model,
        ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
    )


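# Client-caching flow (illustrative sketch of how the helpers above fit
# together; in this module only `_get_client_from_cache` is called from
# `completion()`, `_set_client_in_cache` is the companion writer for call
# sites that populate the cache):
#
#   _cache_key = _get_client_cache_key(
#       model=model, vertex_project=vertex_project, vertex_location=vertex_location
#   )
#   client = _get_client_from_cache(client_cache_key=_cache_key)
#   if client is None:
#       client = ...build the Vertex SDK client / run vertexai.init(...)...
#       _set_client_in_cache(client_cache_key=_cache_key, vertex_llm_model=client)
#
# Cached entries expire after `_DEFAULT_TTL_FOR_HTTPX_CLIENTS` (the shared
# httpx-client TTL).

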
def completion(
    model: str,
    messages: list,
    model_response: ModelResponse,
    print_verbose: Callable,
    encoding,
    logging_obj,
    optional_params: dict,
    vertex_project=None,
    vertex_location=None,
    vertex_credentials=None,
    litellm_params=None,
    logger_fn=None,
    acompletion: bool = False,
):
    """
    NON-GEMINI/ANTHROPIC CALLS.

    This is the handler for older PaLM models and Vertex AI Model Garden.

    For Vertex AI Anthropic: `vertex_anthropic.py`
    For Gemini: `vertex_httpx.py`
    """
    try:
        import vertexai
    except Exception:
        raise VertexAIError(
            status_code=400,
            message="vertexai import failed. Please run `pip install google-cloud-aiplatform`. This is required for the 'vertex_ai/' route on LiteLLM.",
        )

    if not (
        hasattr(vertexai, "preview") and hasattr(vertexai.preview, "language_models")
    ):
        raise VertexAIError(
            status_code=400,
            message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
        )
    try:
        import google.auth
        from google.cloud import aiplatform
        from google.cloud.aiplatform_v1beta1.types import (
            content as gapic_content_types,
        )
        from google.protobuf import json_format
        from google.protobuf.struct_pb2 import Value
        from vertexai.language_models import CodeGenerationModel, TextGenerationModel
        from vertexai.preview.generative_models import GenerativeModel
        from vertexai.preview.language_models import ChatModel, CodeChatModel

        print_verbose(
            f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
        )

        _cache_key = _get_client_cache_key(
            model=model, vertex_project=vertex_project, vertex_location=vertex_location
        )
        _vertex_llm_model_object = _get_client_from_cache(client_cache_key=_cache_key)

        if _vertex_llm_model_object is None:
            from google.auth.credentials import Credentials

            if vertex_credentials is not None and isinstance(vertex_credentials, str):
                import google.oauth2.service_account

                json_obj = json.loads(vertex_credentials)

                creds = (
                    google.oauth2.service_account.Credentials.from_service_account_info(
                        json_obj,
                        scopes=["https://www.googleapis.com/auth/cloud-platform"],
                    )
                )
            else:
                creds, _ = google.auth.default(quota_project_id=vertex_project)
            print_verbose(
                f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
            )
            vertexai.init(
                project=vertex_project,
                location=vertex_location,
                credentials=cast(Credentials, creds),
            )

        config = litellm.VertexAIConfig.get_config()
        for k, v in config.items():
            if k not in optional_params:
                optional_params[k] = v

        safety_settings = None
        if "safety_settings" in optional_params:
            safety_settings = optional_params.pop("safety_settings")
            if not isinstance(safety_settings, list):
                raise ValueError("safety_settings must be a list")
            if len(safety_settings) > 0 and not isinstance(safety_settings[0], dict):
                raise ValueError("safety_settings must be a list of dicts")
            safety_settings = [
                gapic_content_types.SafetySetting(x) for x in safety_settings
            ]

        prompt = " ".join(
            [
                message.get("content")
                for message in messages
                if isinstance(message.get("content", None), str)
            ]
        )

        mode = ""

        request_str = ""
        response_obj = None
        instances = None
        client_options = {
            "api_endpoint": f"{vertex_location}-aiplatform.googleapis.com"
        }
        fake_stream = False
        if (
            model in litellm.vertex_language_models
            or model in litellm.vertex_vision_models
        ):
            llm_model: Any = _vertex_llm_model_object or GenerativeModel(model)
            mode = "vision"
            request_str += f"llm_model = GenerativeModel({model})\n"
        elif model in litellm.vertex_chat_models:
            llm_model = _vertex_llm_model_object or ChatModel.from_pretrained(model)
            mode = "chat"
            request_str += f"llm_model = ChatModel.from_pretrained({model})\n"
        elif model in litellm.vertex_text_models:
            llm_model = _vertex_llm_model_object or TextGenerationModel.from_pretrained(
                model
            )
            mode = "text"
            request_str += f"llm_model = TextGenerationModel.from_pretrained({model})\n"
        elif model in litellm.vertex_code_text_models:
            llm_model = _vertex_llm_model_object or CodeGenerationModel.from_pretrained(
                model
            )
            mode = "text"
            request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n"
            fake_stream = True
        elif model in litellm.vertex_code_chat_models:
            llm_model = _vertex_llm_model_object or CodeChatModel.from_pretrained(model)
            mode = "chat"
            request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"
        elif model == "private":
            mode = "private"
            model = optional_params.pop("model_id", None)

            instances = [optional_params.copy()]
            instances[0]["prompt"] = prompt
            llm_model = aiplatform.PrivateEndpoint(
                endpoint_name=model,
                project=vertex_project,
                location=vertex_location,
            )
            request_str += f"llm_model = aiplatform.PrivateEndpoint(endpoint_name={model}, project={vertex_project}, location={vertex_location})\n"
        else:
            mode = "custom"

            instances = [optional_params.copy()]
            instances[0]["prompt"] = prompt
            instances = [
                json_format.ParseDict(instance_dict, Value())
                for instance_dict in instances
            ]

            llm_model = None

        if acompletion is True:
            data = {
                "llm_model": llm_model,
                "mode": mode,
                "prompt": prompt,
                "logging_obj": logging_obj,
                "request_str": request_str,
                "model": model,
                "model_response": model_response,
                "encoding": encoding,
                "messages": messages,
                "print_verbose": print_verbose,
                "client_options": client_options,
                "instances": instances,
                "vertex_location": vertex_location,
                "vertex_project": vertex_project,
                "safety_settings": safety_settings,
                **optional_params,
            }
            if optional_params.get("stream", False) is True:
                return async_streaming(**data)

            return async_completion(**data)

        completion_response = None

        stream = optional_params.pop("stream", None)
        if mode == "chat":
            chat = llm_model.start_chat()
            request_str += "chat = llm_model.start_chat()\n"

            if fake_stream is not True and stream is True:
                optional_params.pop("stream", None)

                request_str += (
                    f"chat.send_message_streaming({prompt}, **{optional_params})\n"
                )

                logging_obj.pre_call(
                    input=prompt,
                    api_key=None,
                    additional_args={
                        "complete_input_dict": optional_params,
                        "request_str": request_str,
                    },
                )

                model_response = chat.send_message_streaming(prompt, **optional_params)

                return model_response

            request_str += f"chat.send_message({prompt}, **{optional_params}).text\n"

            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            completion_response = chat.send_message(prompt, **optional_params).text
        elif mode == "text":
            if fake_stream is not True and stream is True:
                request_str += (
                    f"llm_model.predict_streaming({prompt}, **{optional_params})\n"
                )

                logging_obj.pre_call(
                    input=prompt,
                    api_key=None,
                    additional_args={
                        "complete_input_dict": optional_params,
                        "request_str": request_str,
                    },
                )
                model_response = llm_model.predict_streaming(prompt, **optional_params)

                return model_response

            request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"

            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            completion_response = llm_model.predict(prompt, **optional_params).text
        elif mode == "custom":
            """
            Vertex AI Model Garden
            """

            if vertex_project is None or vertex_location is None:
                raise ValueError(
                    "Vertex project and location are required for custom endpoint"
                )

            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            llm_model = aiplatform.gapic.PredictionServiceClient(
                client_options=client_options
            )
            request_str += f"llm_model = aiplatform.gapic.PredictionServiceClient(client_options={client_options})\n"
            endpoint_path = llm_model.endpoint_path(
                project=vertex_project, location=vertex_location, endpoint=model
            )
            request_str += (
                f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
            )
            response = llm_model.predict(
                endpoint=endpoint_path, instances=instances
            ).predictions

            completion_response = response[0]
            if (
                isinstance(completion_response, str)
                and "\nOutput:\n" in completion_response
            ):
                completion_response = completion_response.split("\nOutput:\n", 1)[1]
            if stream is True:
                response = TextStreamer(completion_response)
                return response
        elif mode == "private":
            """
            Vertex AI Model Garden deployed on private endpoint
            """
            if instances is None:
                raise ValueError("instances are required for private endpoint")
            if llm_model is None:
                raise ValueError("Unable to pick client for private endpoint")

            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            request_str += f"llm_model.predict(instances={instances})\n"
            response = llm_model.predict(instances=instances).predictions

            completion_response = response[0]
            if (
                isinstance(completion_response, str)
                and "\nOutput:\n" in completion_response
            ):
                completion_response = completion_response.split("\nOutput:\n", 1)[1]
            if stream is True:
                response = TextStreamer(completion_response)
                return response

        logging_obj.post_call(
            input=prompt, api_key=None, original_response=completion_response
        )

        if isinstance(completion_response, litellm.Message):
            model_response.choices[0].message = completion_response
        elif len(str(completion_response)) > 0:
            model_response.choices[0].message.content = str(completion_response)
        model_response.created = int(time.time())
        model_response.model = model

        if model in litellm.vertex_language_models and response_obj is not None:
            model_response.choices[0].finish_reason = map_finish_reason(
                response_obj.candidates[0].finish_reason.name
            )
            usage = Usage(
                prompt_tokens=response_obj.usage_metadata.prompt_token_count,
                completion_tokens=response_obj.usage_metadata.candidates_token_count,
                total_tokens=response_obj.usage_metadata.total_token_count,
            )
        else:
            prompt_tokens, completion_tokens = 0, 0
            if response_obj is not None:
                if hasattr(response_obj, "usage_metadata") and hasattr(
                    response_obj.usage_metadata, "prompt_token_count"
                ):
                    prompt_tokens = response_obj.usage_metadata.prompt_token_count
                    completion_tokens = (
                        response_obj.usage_metadata.candidates_token_count
                    )
            else:
                prompt_tokens = len(encoding.encode(prompt))
                completion_tokens = len(
                    encoding.encode(
                        model_response["choices"][0]["message"].get("content", "")
                    )
                )

            usage = Usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            )
        setattr(model_response, "usage", usage)

        if fake_stream is True and stream is True:
            return ModelResponseIterator(model_response)
        return model_response
    except Exception as e:
        if isinstance(e, VertexAIError):
            raise e
        raise litellm.APIConnectionError(
            message=str(e), llm_provider="vertex_ai", model=model
        )


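# Illustrative call sketch for completion() (commented out; the argument values
# are placeholders, and `logging_obj`/`encoding` are whatever litellm.main
# normally passes in):
#
#   response = completion(
#       model="my-model-garden-endpoint-id",   # unknown model name -> mode == "custom"
#       messages=[{"role": "user", "content": "hello"}],
#       model_response=litellm.ModelResponse(),
#       print_verbose=print,
#       encoding=encoding,                     # token counter used for usage fallback
#       logging_obj=logging_obj,
#       optional_params={"temperature": 0.2},
#       vertex_project="my-gcp-project",
#       vertex_location="us-central1",
#   )
#
# With acompletion=True the same keyword set is forwarded to async_completion()
# (or async_streaming() when optional_params["stream"] is True).

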
async def async_completion(
    llm_model,
    mode: str,
    prompt: str,
    model: str,
    messages: list,
    model_response: ModelResponse,
    request_str: str,
    print_verbose: Callable,
    logging_obj,
    encoding,
    client_options=None,
    instances=None,
    vertex_project=None,
    vertex_location=None,
    safety_settings=None,
    **optional_params,
):
    """
    Async (`acompletion`) counterpart of `completion` for the non-Gemini routes
    handled by this module (older PaLM models and Vertex AI Model Garden).
    """
    try:
        response_obj = None
        completion_response = None
        if mode == "chat":
            chat = llm_model.start_chat()

            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            response_obj = await chat.send_message_async(prompt, **optional_params)
            completion_response = response_obj.text
        elif mode == "text":
            request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"

            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            response_obj = await llm_model.predict_async(prompt, **optional_params)
            completion_response = response_obj.text
        elif mode == "custom":
            """
            Vertex AI Model Garden
            """
            from google.cloud import aiplatform

            if vertex_project is None or vertex_location is None:
                raise ValueError(
                    "Vertex project and location are required for custom endpoint"
                )

            logging_obj.pre_call(
                input=prompt,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )

            llm_model = aiplatform.gapic.PredictionServiceAsyncClient(
                client_options=client_options
            )
            request_str += f"llm_model = aiplatform.gapic.PredictionServiceAsyncClient(client_options={client_options})\n"
            endpoint_path = llm_model.endpoint_path(
                project=vertex_project, location=vertex_location, endpoint=model
            )
            request_str += (
                f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
            )
            response_obj = await llm_model.predict(
                endpoint=endpoint_path,
                instances=instances,
            )
            response = response_obj.predictions
            completion_response = response[0]
            if (
                isinstance(completion_response, str)
                and "\nOutput:\n" in completion_response
            ):
                completion_response = completion_response.split("\nOutput:\n", 1)[1]

        elif mode == "private":
            request_str += f"llm_model.predict_async(instances={instances})\n"
            response_obj = await llm_model.predict_async(
                instances=instances,
            )

            response = response_obj.predictions
            completion_response = response[0]
            if (
                isinstance(completion_response, str)
                and "\nOutput:\n" in completion_response
            ):
                completion_response = completion_response.split("\nOutput:\n", 1)[1]

        logging_obj.post_call(
            input=prompt, api_key=None, original_response=completion_response
        )

        if isinstance(completion_response, litellm.Message):
            model_response.choices[0].message = completion_response
        elif len(str(completion_response)) > 0:
            model_response.choices[0].message.content = str(completion_response)
        model_response.created = int(time.time())
        model_response.model = model

        if model in litellm.vertex_language_models and response_obj is not None:
            model_response.choices[0].finish_reason = map_finish_reason(
                response_obj.candidates[0].finish_reason.name
            )
            usage = Usage(
                prompt_tokens=response_obj.usage_metadata.prompt_token_count,
                completion_tokens=response_obj.usage_metadata.candidates_token_count,
                total_tokens=response_obj.usage_metadata.total_token_count,
            )
        else:
            prompt_tokens, completion_tokens = 0, 0
            if response_obj is not None and (
                hasattr(response_obj, "usage_metadata")
                and hasattr(response_obj.usage_metadata, "prompt_token_count")
            ):
                prompt_tokens = response_obj.usage_metadata.prompt_token_count
                completion_tokens = response_obj.usage_metadata.candidates_token_count
            else:
                prompt_tokens = len(encoding.encode(prompt))
                completion_tokens = len(
                    encoding.encode(
                        model_response["choices"][0]["message"].get("content", "")
                    )
                )

            usage = Usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            )
        setattr(model_response, "usage", usage)
        return model_response
    except Exception as e:
        raise VertexAIError(status_code=500, message=str(e))


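# Illustrative async usage sketch (commented out; in practice litellm.main
# builds these arguments and simply awaits the coroutine returned by
# completion(acompletion=True, ...)):
#
#   model_response = await async_completion(
#       llm_model=llm_model,
#       mode="text",
#       prompt="hello",
#       model="text-bison",
#       messages=messages,
#       model_response=litellm.ModelResponse(),
#       request_str="",
#       print_verbose=print,
#       logging_obj=logging_obj,
#       encoding=encoding,
#   )

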
async def async_streaming(
    llm_model,
    mode: str,
    prompt: str,
    model: str,
    model_response: ModelResponse,
    messages: list,
    print_verbose: Callable,
    logging_obj,
    request_str: str,
    encoding=None,
    client_options=None,
    instances=None,
    vertex_project=None,
    vertex_location=None,
    safety_settings=None,
    **optional_params,
):
    """
    Async streaming counterpart of `completion` for the non-Gemini routes
    handled by this module.
    """
    response: Any = None
    if mode == "chat":
        chat = llm_model.start_chat()
        optional_params.pop("stream", None)
        request_str += (
            f"chat.send_message_streaming_async({prompt}, **{optional_params})\n"
        )

        logging_obj.pre_call(
            input=prompt,
            api_key=None,
            additional_args={
                "complete_input_dict": optional_params,
                "request_str": request_str,
            },
        )
        response = chat.send_message_streaming_async(prompt, **optional_params)

    elif mode == "text":
        optional_params.pop("stream", None)
        request_str += (
            f"llm_model.predict_streaming_async({prompt}, **{optional_params})\n"
        )

        logging_obj.pre_call(
            input=prompt,
            api_key=None,
            additional_args={
                "complete_input_dict": optional_params,
                "request_str": request_str,
            },
        )
        response = llm_model.predict_streaming_async(prompt, **optional_params)
    elif mode == "custom":
        from google.cloud import aiplatform

        if vertex_project is None or vertex_location is None:
            raise ValueError(
                "Vertex project and location are required for custom endpoint"
            )

        stream = optional_params.pop("stream", None)

        logging_obj.pre_call(
            input=prompt,
            api_key=None,
            additional_args={
                "complete_input_dict": optional_params,
                "request_str": request_str,
            },
        )
        llm_model = aiplatform.gapic.PredictionServiceAsyncClient(
            client_options=client_options
        )
        request_str += f"llm_model = aiplatform.gapic.PredictionServiceAsyncClient(client_options={client_options})\n"
        endpoint_path = llm_model.endpoint_path(
            project=vertex_project, location=vertex_location, endpoint=model
        )
        request_str += (
            f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
        )
        response_obj = await llm_model.predict(
            endpoint=endpoint_path,
            instances=instances,
        )

        response = response_obj.predictions
        completion_response = response[0]
        if (
            isinstance(completion_response, str)
            and "\nOutput:\n" in completion_response
        ):
            completion_response = completion_response.split("\nOutput:\n", 1)[1]
        if stream:
            response = TextStreamer(completion_response)

    elif mode == "private":
        if instances is None:
            raise ValueError("Instances are required for private endpoint")
        stream = optional_params.pop("stream", None)
        _ = instances[0].pop("stream", None)
        request_str += f"llm_model.predict_async(instances={instances})\n"
        response_obj = await llm_model.predict_async(
            instances=instances,
        )
        response = response_obj.predictions
        completion_response = response[0]
        if (
            isinstance(completion_response, str)
            and "\nOutput:\n" in completion_response
        ):
            completion_response = completion_response.split("\nOutput:\n", 1)[1]
        if stream:
            response = TextStreamer(completion_response)

    if response is None:
        raise ValueError("Unable to generate response")

    logging_obj.post_call(input=prompt, api_key=None, original_response=response)

    streamwrapper = CustomStreamWrapper(
        completion_stream=response,
        model=model,
        custom_llm_provider="vertex_ai",
        logging_obj=logging_obj,
    )

    return streamwrapper


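# Illustrative consumption sketch (commented out): the CustomStreamWrapper
# returned above is an async iterator of OpenAI-style chunks, e.g.
#
#   wrapper = await async_streaming(**data)   # `data` as built inside completion()
#   async for chunk in wrapper:
#       print(chunk.choices[0].delta.content or "", end="")
#
# For the "custom" and "private" modes the underlying stream is a TextStreamer
# over an already-complete response, so all chunks arrive immediately.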