|
""" |
|
Azure Batches API Handler |
|
""" |
|
|
|
from typing import Any, Coroutine, Optional, Union |
|
|
|
import httpx |
|
|
|
import litellm |
|
from litellm.llms.azure.azure import AsyncAzureOpenAI, AzureOpenAI |
|
from litellm.types.llms.openai import ( |
|
Batch, |
|
CancelBatchRequest, |
|
CreateBatchRequest, |
|
RetrieveBatchRequest, |
|
) |
|
|
|
|
|
class AzureBatchesAPI:
    """
    Azure methods to support for batches

    - create_batch()
    - retrieve_batch()
    - cancel_batch()
    - list_batches()

    Every public method takes ``_is_async``. When it is True the method returns
    the (un-awaited) coroutine of its ``a*`` counterpart and the resolved client
    must be an ``AsyncAzureOpenAI``; otherwise the request is performed
    synchronously and the SDK response is returned directly.
    """

    def __init__(self) -> None:
        super().__init__()

    def get_azure_openai_client(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        api_version: Optional[str] = None,
        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
        _is_async: bool = False,
    ) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
        """
        Return ``client`` if one was supplied, otherwise construct a new one.

        Args:
            api_key: Forwarded as ``api_key`` when not None.
            api_base: Forwarded as ``azure_endpoint`` when not None.
            timeout: Forwarded as ``timeout`` when not None.
            max_retries: Forwarded as ``max_retries`` when not None.
            api_version: Forwarded as ``api_version``; falls back to
                ``litellm.AZURE_DEFAULT_API_VERSION`` when None.
            client: Pre-built client returned unchanged (not validated here).
            _is_async: Construct ``AsyncAzureOpenAI`` instead of ``AzureOpenAI``.

        Returns:
            The supplied client, or the newly constructed one.
        """
        if client is not None:
            return client

        # Built explicitly instead of from ``locals()`` so that a local
        # variable added to this method can never leak into the constructor
        # kwargs. None values are omitted so the constructor's own defaults
        # apply for them.
        candidate_kwargs = {
            "api_key": api_key,
            "azure_endpoint": api_base,
            "timeout": timeout,
            "max_retries": max_retries,
            "api_version": api_version,
        }
        data = {
            key: value for key, value in candidate_kwargs.items() if value is not None
        }
        data.setdefault("api_version", litellm.AZURE_DEFAULT_API_VERSION)

        if _is_async is True:
            return AsyncAzureOpenAI(**data)
        return AzureOpenAI(**data)

    def _get_client_or_raise(
        self,
        *,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]],
        _is_async: bool,
    ) -> Union[AzureOpenAI, AsyncAzureOpenAI]:
        """Resolve a client via ``get_azure_openai_client``; raise if it is None."""
        # Goes through ``self.get_azure_openai_client`` so subclass overrides
        # of the public factory are still honoured.
        azure_client = self.get_azure_openai_client(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            _is_async=_is_async,
        )
        if azure_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )
        return azure_client

    @staticmethod
    def _require_async_client(
        azure_client: Union[AzureOpenAI, AsyncAzureOpenAI],
    ) -> AsyncAzureOpenAI:
        """Return ``azure_client`` unchanged; raise if it is not an ``AsyncAzureOpenAI``."""
        if not isinstance(azure_client, AsyncAzureOpenAI):
            raise ValueError(
                "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
            )
        return azure_client

    async def acreate_batch(
        self,
        create_batch_data: CreateBatchRequest,
        azure_client: AsyncAzureOpenAI,
    ) -> Batch:
        """Await ``azure_client.batches.create`` with ``create_batch_data`` as kwargs."""
        response = await azure_client.batches.create(**create_batch_data)
        return response

    def create_batch(
        self,
        _is_async: bool,
        create_batch_data: CreateBatchRequest,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
    ) -> Union[Batch, Coroutine[Any, Any, Batch]]:
        """
        Create a batch.

        Args:
            _is_async: When True, return the ``acreate_batch`` coroutine
                instead of performing the request now.
            create_batch_data: Kwargs forwarded to ``batches.create``.
            api_key, api_base, api_version, timeout, max_retries, client:
                Used to resolve the client (see ``get_azure_openai_client``).

        Returns:
            The ``Batch``, or a coroutine resolving to it when ``_is_async``.

        Raises:
            ValueError: If no client is resolved, or ``_is_async`` is True and
                the resolved client is not an ``AsyncAzureOpenAI``.
        """
        azure_client = self._get_client_or_raise(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            _is_async=_is_async,
        )

        if _is_async is True:
            async_client = self._require_async_client(azure_client)
            return self.acreate_batch(
                create_batch_data=create_batch_data, azure_client=async_client
            )

        response = azure_client.batches.create(**create_batch_data)
        return response

    async def aretrieve_batch(
        self,
        retrieve_batch_data: RetrieveBatchRequest,
        client: AsyncAzureOpenAI,
    ) -> Batch:
        """Await ``client.batches.retrieve`` with ``retrieve_batch_data`` as kwargs."""
        response = await client.batches.retrieve(**retrieve_batch_data)
        return response

    def retrieve_batch(
        self,
        _is_async: bool,
        retrieve_batch_data: RetrieveBatchRequest,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
    ) -> Union[Batch, Coroutine[Any, Any, Batch]]:
        """
        Retrieve a batch.

        Args:
            _is_async: When True, return the ``aretrieve_batch`` coroutine
                instead of performing the request now.
            retrieve_batch_data: Kwargs forwarded to ``batches.retrieve``.
            api_key, api_base, api_version, timeout, max_retries, client:
                Used to resolve the client (see ``get_azure_openai_client``).

        Returns:
            The ``Batch``, or a coroutine resolving to it when ``_is_async``.

        Raises:
            ValueError: If no client is resolved, or ``_is_async`` is True and
                the resolved client is not an ``AsyncAzureOpenAI``.
        """
        azure_client = self._get_client_or_raise(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            _is_async=_is_async,
        )

        if _is_async is True:
            async_client = self._require_async_client(azure_client)
            return self.aretrieve_batch(
                retrieve_batch_data=retrieve_batch_data, client=async_client
            )

        response = azure_client.batches.retrieve(**retrieve_batch_data)
        return response

    async def acancel_batch(
        self,
        cancel_batch_data: CancelBatchRequest,
        client: AsyncAzureOpenAI,
    ) -> Batch:
        """Await ``client.batches.cancel`` with ``cancel_batch_data`` as kwargs."""
        response = await client.batches.cancel(**cancel_batch_data)
        return response

    def cancel_batch(
        self,
        _is_async: bool,
        cancel_batch_data: CancelBatchRequest,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
    ) -> Union[Batch, Coroutine[Any, Any, Batch]]:
        """
        Cancel a batch.

        Args:
            _is_async: When True, return the ``acancel_batch`` coroutine
                instead of performing the request now.
            cancel_batch_data: Kwargs forwarded to ``batches.cancel``.
            api_key, api_base, api_version, timeout, max_retries, client:
                Used to resolve the client (see ``get_azure_openai_client``).

        Returns:
            The ``Batch``, or a coroutine resolving to it when ``_is_async``.

        Raises:
            ValueError: If no client is resolved, or ``_is_async`` is True and
                the resolved client is not an ``AsyncAzureOpenAI``.
        """
        azure_client = self._get_client_or_raise(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            _is_async=_is_async,
        )

        # Async path mirrors the sibling methods: previously a sync client
        # passed with _is_async=True returned a plain Batch that the caller
        # would then try to await, and acancel_batch was never used.
        if _is_async is True:
            async_client = self._require_async_client(azure_client)
            return self.acancel_batch(
                cancel_batch_data=cancel_batch_data, client=async_client
            )

        response = azure_client.batches.cancel(**cancel_batch_data)
        return response

    async def alist_batches(
        self,
        client: AsyncAzureOpenAI,
        after: Optional[str] = None,
        limit: Optional[int] = None,
    ):
        """Await ``client.batches.list(after=after, limit=limit)``."""
        response = await client.batches.list(after=after, limit=limit)
        return response

    def list_batches(
        self,
        _is_async: bool,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        after: Optional[str] = None,
        limit: Optional[int] = None,
        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
    ):
        """
        List batches.

        Args:
            _is_async: When True, return the ``alist_batches`` coroutine
                instead of performing the request now.
            api_key, api_base, api_version, timeout, max_retries, client:
                Used to resolve the client (see ``get_azure_openai_client``).
            after: Forwarded to ``batches.list``.
            limit: Forwarded to ``batches.list``.

        Returns:
            Whatever ``client.batches.list`` returns, or a coroutine of it
            when ``_is_async`` is True.

        Raises:
            ValueError: If no client is resolved, or ``_is_async`` is True and
                the resolved client is not an ``AsyncAzureOpenAI``.
        """
        azure_client = self._get_client_or_raise(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            _is_async=_is_async,
        )

        if _is_async is True:
            async_client = self._require_async_client(azure_client)
            return self.alist_batches(client=async_client, after=after, limit=limit)

        response = azure_client.batches.list(after=after, limit=limit)
        return response
|
|