|
import json |
|
from typing import Any, Dict, List, Optional |
|
|
|
import httpx |
|
from openai.types.image import Image |
|
|
|
import litellm |
|
from litellm.llms.custom_httpx.http_handler import ( |
|
AsyncHTTPHandler, |
|
HTTPHandler, |
|
get_async_httpx_client, |
|
) |
|
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM |
|
from litellm.types.utils import ImageResponse |
|
|
|
|
|
class VertexImageGeneration(VertexLLM):
    """Image generation for Vertex AI publisher models via the ``:predict`` REST endpoint.

    Each request POSTs JSON shaped like
    ``{"instances": [{"prompt": ...}], "parameters": {...}}`` to
    ``https://{location}-aiplatform.googleapis.com/v1/projects/{project}/locations/{location}/publishers/google/models/{model}:predict``
    and the response is expected to look like
    ``{"predictions": [{"bytesBase64Encoded": ...}, ...]}``.
    """

    def process_image_generation_response(
        self,
        json_response: Dict[str, Any],
        model_response: ImageResponse,
        model: Optional[str] = None,
    ) -> ImageResponse:
        """Copy base64 image payloads from a Vertex ``:predict`` response into ``model_response``.

        Args:
            json_response: Decoded JSON body of the predict call.
            model_response: Response object to populate; its ``data`` is overwritten.
            model: Model name, used only for error reporting.

        Returns:
            ``model_response`` with ``data`` set to one ``Image`` per prediction.

        Raises:
            litellm.InternalServerError: If the body has no ``predictions`` key.
            KeyError: If a prediction lacks ``bytesBase64Encoded``.
        """
        if "predictions" not in json_response:
            raise litellm.InternalServerError(
                message=f"image generation response does not contain 'predictions', got {json_response}",
                llm_provider="vertex_ai",
                model=model,
            )

        response_data: List[Image] = [
            Image(b64_json=prediction["bytesBase64Encoded"])
            for prediction in json_response["predictions"]
        ]
        model_response.data = response_data
        return model_response

    def image_generation(
        self,
        prompt: str,
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        vertex_credentials: Optional[str],
        model_response: ImageResponse,
        logging_obj: Any,
        model: Optional[str] = "imagegeneration",
        client: Optional[Any] = None,
        optional_params: Optional[dict] = None,
        timeout: Optional[int] = None,
        aimg_generation=False,
    ) -> ImageResponse:
        """Generate images synchronously, or delegate to ``aimage_generation``.

        Args:
            prompt: Text prompt sent as the single instance.
            vertex_project: GCP project id used in the URL and for auth.
            vertex_location: Region used in the URL host and path.
            vertex_credentials: Credentials forwarded to ``_ensure_access_token``.
            model_response: Response object to populate.
            logging_obj: Logger exposing ``pre_call``.
            model: Publisher model name placed in the URL.
            client: Optional pre-built sync HTTP handler; one is created if ``None``.
            optional_params: Request ``parameters``; defaults to ``{"sampleCount": 1}``
                when ``None`` or empty.
            timeout: Seconds (int/float) or an ``httpx.Timeout``; ``None`` uses
                600s total with a 5s connect timeout.
            aimg_generation: When ``True``, return the (un-awaited) coroutine from
                ``aimage_generation`` instead of performing the request here.

        Returns:
            The populated ``model_response`` (or a coroutine when ``aimg_generation``).

        Raises:
            Exception: On a non-200 HTTP status.
        """
        if aimg_generation is True:
            # Caller is responsible for awaiting the returned coroutine.
            return self.aimage_generation(  # type: ignore
                prompt=prompt,
                vertex_project=vertex_project,
                vertex_location=vertex_location,
                vertex_credentials=vertex_credentials,
                model=model,
                client=client,
                optional_params=optional_params,
                timeout=timeout,
                logging_obj=logging_obj,
                model_response=model_response,
            )

        if client is None:
            sync_handler: HTTPHandler = HTTPHandler(
                **self._build_image_generation_http_params(timeout)
            )
        else:
            sync_handler = client

        url, headers, request_data = self._prepare_image_generation_request(
            prompt=prompt,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            model=model,
            optional_params=optional_params,
            logging_obj=logging_obj,
        )

        response = sync_handler.post(
            url=url,
            headers=headers,
            data=json.dumps(request_data),
        )
        return self._parse_image_generation_http_response(
            response, model_response, model
        )

    async def aimage_generation(
        self,
        prompt: str,
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        vertex_credentials: Optional[str],
        model_response: litellm.ImageResponse,
        logging_obj: Any,
        model: Optional[str] = "imagegeneration",
        client: Optional[AsyncHTTPHandler] = None,
        optional_params: Optional[dict] = None,
        timeout: Optional[int] = None,
    ) -> ImageResponse:
        """Async counterpart of ``image_generation``; same arguments and result.

        When ``client`` is ``None`` a handler is obtained from
        ``get_async_httpx_client`` using the same timeout translation as the
        sync path. The chosen handler is also stored on ``self.async_handler``,
        as the original implementation did.

        Raises:
            Exception: On a non-200 HTTP status.
        """
        if client is None:
            async_handler = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.VERTEX_AI,
                params=self._build_image_generation_http_params(timeout),
            )
        else:
            async_handler = client
        # Preserved side effect: other code may read this attribute (not verified here).
        self.async_handler = async_handler

        url, headers, request_data = self._prepare_image_generation_request(
            prompt=prompt,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            model=model,
            optional_params=optional_params,
            logging_obj=logging_obj,
        )

        response = await async_handler.post(
            url=url,
            headers=headers,
            data=json.dumps(request_data),
        )
        return self._parse_image_generation_http_response(
            response, model_response, model
        )

    def is_image_generation_response(self, json_response: Dict[str, Any]) -> bool:
        """Return True if the first prediction carries a ``bytesBase64Encoded`` payload.

        Returns False when ``predictions`` is missing, ``None`` or empty rather
        than raising.
        """
        predictions = json_response.get("predictions")
        if not predictions:
            return False
        return "bytesBase64Encoded" in predictions[0]

    @staticmethod
    def _build_image_generation_http_params(timeout: Any) -> Dict[str, Any]:
        """Translate a caller ``timeout`` into keyword arguments for an HTTP handler.

        - ``None``            -> ``httpx.Timeout(timeout=600.0, connect=5.0)``
        - ``httpx.Timeout``   -> passed through unchanged
        - ``int`` / ``float`` -> ``httpx.Timeout(timeout)``
        - anything else       -> no ``timeout`` key (the handler picks its own)
        """
        if timeout is None:
            return {"timeout": httpx.Timeout(timeout=600.0, connect=5.0)}
        if isinstance(timeout, httpx.Timeout):
            return {"timeout": timeout}
        if isinstance(timeout, (int, float)):
            return {"timeout": httpx.Timeout(timeout)}
        return {}

    def _prepare_image_generation_request(
        self,
        prompt: str,
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        vertex_credentials: Optional[str],
        model: Optional[str],
        optional_params: Optional[dict],
        logging_obj: Any,
    ) -> tuple:
        """Build the predict URL, auth headers and JSON body, and log the request once.

        Returns:
            ``(url, headers, request_data)``.
        """
        url = (
            f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/"
            f"{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
        )

        auth_header, _ = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )
        # Fall back to a single sample when no (or empty) parameters were given.
        optional_params = optional_params or {"sampleCount": 1}

        request_data = {
            "instances": [{"prompt": prompt}],
            "parameters": optional_params,
        }
        headers = {
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {auth_header}",
        }

        # Only a 10-character prefix of the token appears in the logged curl command.
        request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
        logging_obj.pre_call(
            input=prompt,
            api_key=None,
            additional_args={
                "complete_input_dict": optional_params,
                "request_str": request_str,
            },
        )
        return url, headers, request_data

    def _parse_image_generation_http_response(
        self,
        response: httpx.Response,
        model_response: ImageResponse,
        model: Optional[str],
    ) -> ImageResponse:
        """Reject non-200 responses, then decode JSON into ``model_response``.

        NOTE(review): the HTTP handler may already raise on error statuses before
        this check is reached; that behavior is not visible here.

        Raises:
            Exception: On a non-200 HTTP status.
        """
        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")
        return self.process_image_generation_response(
            response.json(), model_response, model
        )
|
|