from typing import List, Optional, Union

from openai import OpenAI

import litellm
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,
    get_async_httpx_client,
)
from litellm.llms.openai.openai import OpenAIChatCompletion
from litellm.types.llms.azure_ai import ImageEmbeddingRequest
from litellm.types.utils import EmbeddingResponse
from litellm.utils import convert_to_model_response_object

from .cohere_transformation import AzureAICohereConfig


class AzureAIEmbedding(OpenAIChatCompletion):
    def _process_response(
        self,
        image_embedding_responses: Optional[List],
        text_embedding_responses: Optional[List],
        image_embeddings_idx: List[int],
        model_response: EmbeddingResponse,
        input: List,
    ):
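        """
        Merge the `/images/embeddings` and `/v1/embeddings` results back into the
        original input order (indices listed in `image_embeddings_idx` came from
        image inputs), attach them to `model_response.data`, and apply the Cohere
        response transformation before returning.
        """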
        combined_responses = []
        if (
            image_embedding_responses is not None
            and text_embedding_responses is not None
        ):
            # Combine and order the results
            text_idx = 0
            image_idx = 0

            for idx in range(len(input)):
                if idx in image_embeddings_idx:
                    combined_responses.append(image_embedding_responses[image_idx])
                    image_idx += 1
                else:
                    combined_responses.append(text_embedding_responses[text_idx])
                    text_idx += 1

            model_response.data = combined_responses
        elif image_embedding_responses is not None:
            model_response.data = image_embedding_responses
        elif text_embedding_responses is not None:
            model_response.data = text_embedding_responses

        response = AzureAICohereConfig()._transform_response(response=model_response)  # type: ignore

        return response

    async def async_image_embedding(
        self,
        model: str,
        data: ImageEmbeddingRequest,
        timeout: float,
        logging_obj,
        model_response: litellm.EmbeddingResponse,
        optional_params: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    ) -> EmbeddingResponse:
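        """
        Async call to the Azure AI `/images/embeddings` route. The raw JSON response
        is converted into an `EmbeddingResponse` via `convert_to_model_response_object`.
        """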
        if client is None or not isinstance(client, AsyncHTTPHandler):
            client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.AZURE_AI,
                params={"timeout": timeout},
            )

        url = "{}/images/embeddings".format(api_base)
        response = await client.post(
            url=url,
            json=data,  # type: ignore
            headers={"Authorization": "Bearer {}".format(api_key)},
        )

        embedding_response = response.json()
        embedding_headers = dict(response.headers)
        returned_response: EmbeddingResponse = convert_to_model_response_object(  # type: ignore
            response_object=embedding_response,
            model_response_object=model_response,
            response_type="embedding",
            stream=False,
            _response_headers=embedding_headers,
        )
        return returned_response

    def image_embedding(
        self,
        model: str,
        data: ImageEmbeddingRequest,
        timeout: float,
        logging_obj,
        model_response: EmbeddingResponse,
        optional_params: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    ):
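        """
        Sync call to the Azure AI `/images/embeddings` route. Requires `api_base`
        and `api_key` to be set, either as params or via the AZURE_AI_* env vars.
        """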
        if api_base is None:
            raise ValueError(
                "api_base is None. Please set AZURE_AI_API_BASE or dynamically via `api_base` param, to make the request."
            )
        if api_key is None:
            raise ValueError(
                "api_key is None. Please set AZURE_AI_API_KEY or dynamically via `api_key` param, to make the request."
            )

        if client is None or not isinstance(client, HTTPHandler):
            client = HTTPHandler(timeout=timeout, concurrent_limit=1)

        url = "{}/images/embeddings".format(api_base)
        response = client.post(
            url=url,
            json=data,  # type: ignore
            headers={"Authorization": "Bearer {}".format(api_key)},
        )

        embedding_response = response.json()
        embedding_headers = dict(response.headers)
        returned_response: EmbeddingResponse = convert_to_model_response_object(  # type: ignore
            response_object=embedding_response,
            model_response_object=model_response,
            response_type="embedding",
            stream=False,
            _response_headers=embedding_headers,
        )
        return returned_response

    async def async_embedding(
        self,
        model: str,
        input: List,
        timeout: float,
        logging_obj,
        model_response: litellm.EmbeddingResponse,
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        client=None,
    ) -> EmbeddingResponse:
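        """
        Async version of `embedding`: split the input into image and text requests,
        send images to `/images/embeddings` and text to `/v1/embeddings`, then merge
        the results back into input order via `_process_response`.
        """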
        (
            image_embeddings_request,
            v1_embeddings_request,
            image_embeddings_idx,
        ) = AzureAICohereConfig()._transform_request(
            input=input, optional_params=optional_params, model=model
        )

        image_embedding_responses: Optional[List] = None
        text_embedding_responses: Optional[List] = None

        if image_embeddings_request["input"]:
            image_response = await self.async_image_embedding(
                model=model,
                data=image_embeddings_request,
                timeout=timeout,
                logging_obj=logging_obj,
                model_response=model_response,
                optional_params=optional_params,
                api_key=api_key,
                api_base=api_base,
                client=client,
            )

            image_embedding_responses = image_response.data
            if image_embedding_responses is None:
                raise Exception("/images/embeddings route returned None Embeddings.")

        if v1_embeddings_request["input"]:
            response: EmbeddingResponse = await super().embedding(  # type: ignore
                model=model,
                input=input,
                timeout=timeout,
                logging_obj=logging_obj,
                model_response=model_response,
                optional_params=optional_params,
                api_key=api_key,
                api_base=api_base,
                client=client,
                aembedding=True,
            )

            text_embedding_responses = response.data
            if text_embedding_responses is None:
                raise Exception("/v1/embeddings route returned None Embeddings.")

        return self._process_response(
            image_embedding_responses=image_embedding_responses,
            text_embedding_responses=text_embedding_responses,
            image_embeddings_idx=image_embeddings_idx,
            model_response=model_response,
            input=input,
        )

    def embedding(
        self,
        model: str,
        input: List,
        timeout: float,
        logging_obj,
        model_response: EmbeddingResponse,
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        client=None,
        aembedding=None,
        max_retries: Optional[int] = None,
    ) -> EmbeddingResponse:
""" | |
- Separate image url from text | |
-> route image url call to `/image/embeddings` | |
-> route text call to `/v1/embeddings` (OpenAI route) | |
assemble result in-order, and return | |
""" | |
        if aembedding is True:
            return self.async_embedding(  # type: ignore
                model,
                input,
                timeout,
                logging_obj,
                model_response,
                optional_params,
                api_key,
                api_base,
                client,
            )

        (
            image_embeddings_request,
            v1_embeddings_request,
            image_embeddings_idx,
        ) = AzureAICohereConfig()._transform_request(
            input=input, optional_params=optional_params, model=model
        )

        image_embedding_responses: Optional[List] = None
        text_embedding_responses: Optional[List] = None

        if image_embeddings_request["input"]:
            image_response = self.image_embedding(
                model=model,
                data=image_embeddings_request,
                timeout=timeout,
                logging_obj=logging_obj,
                model_response=model_response,
                optional_params=optional_params,
                api_key=api_key,
                api_base=api_base,
                client=client,
            )

            image_embedding_responses = image_response.data
            if image_embedding_responses is None:
                raise Exception("/images/embeddings route returned None Embeddings.")

        if v1_embeddings_request["input"]:
            response: EmbeddingResponse = super().embedding(  # type: ignore
                model,
                input,
                timeout,
                logging_obj,
                model_response,
                optional_params,
                api_key,
                api_base,
                client=(
                    client
                    if client is not None and isinstance(client, OpenAI)
                    else None
                ),
                aembedding=aembedding,
            )

            text_embedding_responses = response.data
            if text_embedding_responses is None:
                raise Exception("/v1/embeddings route returned None Embeddings.")

        return self._process_response(
            image_embedding_responses=image_embedding_responses,
            text_embedding_responses=text_embedding_responses,
            image_embeddings_idx=image_embeddings_idx,
            model_response=model_response,
            input=input,
        )
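

# Minimal usage sketch (illustrative only): assuming AZURE_AI_API_KEY /
# AZURE_AI_API_BASE are set, a mixed text + image request is routed through this
# handler roughly as follows via the standard `litellm.embedding` entrypoint.
# The deployment name and the base64 data URI below are hypothetical placeholders.
#
#   import litellm
#
#   response = litellm.embedding(
#       model="azure_ai/Cohere-embed-v3-multilingual",  # hypothetical deployment name
#       input=[
#           "hello world",                           # -> /v1/embeddings
#           "data:image/png;base64,iVBORw0KGgo...",  # -> /images/embeddings
#       ],
#   )
#   print(response.data)  # embeddings returned in the original input order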