""" | |
Azure Batches API Handler | |
""" | |
from typing import Any, Coroutine, Optional, Union, cast | |
import httpx | |
from litellm.llms.azure.azure import AsyncAzureOpenAI, AzureOpenAI | |
from litellm.types.llms.openai import ( | |
Batch, | |
CancelBatchRequest, | |
CreateBatchRequest, | |
RetrieveBatchRequest, | |
) | |
from litellm.types.utils import LiteLLMBatch | |
from ..common_utils import BaseAzureLLM | |


class AzureBatchesAPI(BaseAzureLLM):
    """
    Azure methods to support batches:
    - create_batch()
    - retrieve_batch()
    - cancel_batch()
    - list_batches()
    """

    def __init__(self) -> None:
        super().__init__()
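
    # Each public method below follows the same pattern: resolve an
    # AzureOpenAI / AsyncAzureOpenAI client via get_azure_openai_client(),
    # validate it, then either return a coroutine from the matching async
    # helper (when _is_async=True) or call the Azure SDK synchronously.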

    async def acreate_batch(
        self,
        create_batch_data: CreateBatchRequest,
        azure_client: AsyncAzureOpenAI,
    ) -> LiteLLMBatch:
        response = await azure_client.batches.create(**create_batch_data)
        return LiteLLMBatch(**response.model_dump())

    def create_batch(
        self,
        _is_async: bool,
        create_batch_data: CreateBatchRequest,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
        litellm_params: Optional[dict] = None,
    ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
        azure_client: Optional[
            Union[AzureOpenAI, AsyncAzureOpenAI]
        ] = self.get_azure_openai_client(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            client=client,
            _is_async=_is_async,
            litellm_params=litellm_params or {},
        )
        if azure_client is None:
            raise ValueError(
                "Azure OpenAI client is not initialized. Make sure api_key is passed or AZURE_API_KEY is set in the environment."
            )
        if _is_async is True:
            if not isinstance(azure_client, AsyncAzureOpenAI):
                raise ValueError(
                    "Azure OpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
                )
            return self.acreate_batch(  # type: ignore
                create_batch_data=create_batch_data, azure_client=azure_client
            )
        response = cast(AzureOpenAI, azure_client).batches.create(**create_batch_data)
        return LiteLLMBatch(**response.model_dump())
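
    # Illustrative async dispatch (a sketch; the client construction and all
    # literal values below are hypothetical placeholders):
    #
    #   client = AsyncAzureOpenAI(
    #       api_key="...",
    #       azure_endpoint="https://my-resource.openai.azure.com",
    #       api_version="2024-05-01-preview",
    #   )
    #   coro = AzureBatchesAPI().acreate_batch(
    #       create_batch_data=CreateBatchRequest(
    #           completion_window="24h",
    #           endpoint="/v1/chat/completions",
    #           input_file_id="file-abc123",
    #       ),
    #       azure_client=client,
    #   )
    #   batch = await coro  # LiteLLMBatch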

    async def aretrieve_batch(
        self,
        retrieve_batch_data: RetrieveBatchRequest,
        client: AsyncAzureOpenAI,
    ) -> LiteLLMBatch:
        response = await client.batches.retrieve(**retrieve_batch_data)
        return LiteLLMBatch(**response.model_dump())

    def retrieve_batch(
        self,
        _is_async: bool,
        retrieve_batch_data: RetrieveBatchRequest,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
        litellm_params: Optional[dict] = None,
    ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
        azure_client: Optional[
            Union[AzureOpenAI, AsyncAzureOpenAI]
        ] = self.get_azure_openai_client(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            client=client,
            _is_async=_is_async,
            litellm_params=litellm_params or {},
        )
        if azure_client is None:
            raise ValueError(
                "Azure OpenAI client is not initialized. Make sure api_key is passed or AZURE_API_KEY is set in the environment."
            )
        if _is_async is True:
            if not isinstance(azure_client, AsyncAzureOpenAI):
                raise ValueError(
                    "Azure OpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
                )
            return self.aretrieve_batch(  # type: ignore
                retrieve_batch_data=retrieve_batch_data, client=azure_client
            )
        response = cast(AzureOpenAI, azure_client).batches.retrieve(
            **retrieve_batch_data
        )
        return LiteLLMBatch(**response.model_dump())
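
    # Illustrative retrieve call (a sketch; the batch id is a hypothetical
    # placeholder, and RetrieveBatchRequest is assumed to carry batch_id):
    #
    #   batch = AzureBatchesAPI().retrieve_batch(
    #       _is_async=False,
    #       retrieve_batch_data=RetrieveBatchRequest(batch_id="batch_abc123"),
    #       api_key=None,  # falls back to the environment
    #       api_base=None,
    #       api_version="2024-05-01-preview",
    #       timeout=600.0,
    #       max_retries=2,
    #   )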

    async def acancel_batch(
        self,
        cancel_batch_data: CancelBatchRequest,
        client: AsyncAzureOpenAI,
    ) -> Batch:
        response = await client.batches.cancel(**cancel_batch_data)
        return response

    def cancel_batch(
        self,
        _is_async: bool,
        cancel_batch_data: CancelBatchRequest,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
        litellm_params: Optional[dict] = None,
    ):
        azure_client: Optional[
            Union[AzureOpenAI, AsyncAzureOpenAI]
        ] = self.get_azure_openai_client(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            client=client,
            _is_async=_is_async,
            litellm_params=litellm_params or {},
        )
        if azure_client is None:
            raise ValueError(
                "Azure OpenAI client is not initialized. Make sure api_key is passed or AZURE_API_KEY is set in the environment."
            )
        # Dispatch to the async helper when an async client was requested;
        # calling the SDK directly here would return an un-awaited coroutine.
        if _is_async is True:
            if not isinstance(azure_client, AsyncAzureOpenAI):
                raise ValueError(
                    "Azure OpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
                )
            return self.acancel_batch(  # type: ignore
                cancel_batch_data=cancel_batch_data, client=azure_client
            )
        response = cast(AzureOpenAI, azure_client).batches.cancel(**cancel_batch_data)
        return response

    async def alist_batches(
        self,
        client: AsyncAzureOpenAI,
        after: Optional[str] = None,
        limit: Optional[int] = None,
    ):
        response = await client.batches.list(after=after, limit=limit)  # type: ignore
        return response
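
    # `after` is a pagination cursor: pass the id of the last batch from the
    # previous page to fetch the next one. Illustrative loop (a sketch;
    # `handler` and `client` are assumed to exist):
    #
    #   page = await handler.alist_batches(client=client, limit=100)
    #   while page.data:
    #       ...  # process page.data
    #       page = await handler.alist_batches(
    #           client=client, after=page.data[-1].id, limit=100
    #       )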

    def list_batches(
        self,
        _is_async: bool,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        after: Optional[str] = None,
        limit: Optional[int] = None,
        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
        litellm_params: Optional[dict] = None,
    ):
        azure_client: Optional[
            Union[AzureOpenAI, AsyncAzureOpenAI]
        ] = self.get_azure_openai_client(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            client=client,
            _is_async=_is_async,
            litellm_params=litellm_params or {},
        )
        if azure_client is None:
            raise ValueError(
                "Azure OpenAI client is not initialized. Make sure api_key is passed or AZURE_API_KEY is set in the environment."
            )
        if _is_async is True:
            if not isinstance(azure_client, AsyncAzureOpenAI):
                raise ValueError(
                    "Azure OpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
                )
            return self.alist_batches(  # type: ignore
                client=azure_client, after=after, limit=limit
            )
        response = cast(AzureOpenAI, azure_client).batches.list(
            after=after, limit=limit  # type: ignore
        )
        return response
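

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the library API). Assumes
# AZURE_API_KEY / AZURE_API_BASE are set in the environment; the file id and
# api_version below are hypothetical placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import os

    handler = AzureBatchesAPI()
    batch = handler.create_batch(
        _is_async=False,
        create_batch_data=CreateBatchRequest(
            completion_window="24h",
            endpoint="/v1/chat/completions",
            input_file_id="file-abc123",  # hypothetical uploaded-file id
        ),
        api_key=os.environ.get("AZURE_API_KEY"),
        api_base=os.environ.get("AZURE_API_BASE"),
        api_version="2024-05-01-preview",  # assumed API version
        timeout=600.0,
        max_retries=2,
    )
    print(batch.id, batch.status)  # type: ignore[union-attr]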