Spaces:
Sleeping
Sleeping
from typing import Any, Coroutine, Optional, Union | |
import httpx | |
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI | |
from openai.types.fine_tuning import FineTuningJob | |
from litellm._logging import verbose_logger | |
class OpenAIFineTuningAPI:
    """
    OpenAI methods to support fine-tuning jobs.

    Each public sync method (``create_fine_tuning_job``,
    ``cancel_fine_tuning_job``, ``list_fine_tuning_jobs`` and
    ``retrieve_fine_tuning_job``) resolves a client through
    ``get_openai_client`` and then either calls the SDK directly or, when
    ``_is_async`` is True, returns the un-awaited coroutine produced by its
    ``a``-prefixed counterpart.
    """

    def __init__(self) -> None:
        super().__init__()

    @staticmethod
    def _build_client_kwargs(
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
    ) -> dict:
        """
        Build the keyword arguments used to construct ``OpenAI`` / ``AsyncOpenAI``.

        ``api_base`` is emitted under the key ``base_url``. Entries whose value
        is ``None`` are omitted so the SDK falls back to its own defaults.
        """
        candidates = {
            "api_key": api_key,
            "base_url": api_base,
            "timeout": timeout,
            "max_retries": max_retries,
            "organization": organization,
        }
        # `is not None` (not truthiness) so that e.g. max_retries=0 is kept.
        return {key: value for key, value in candidates.items() if value is not None}

    def get_openai_client(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[
            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
        ] = None,
        _is_async: bool = False,
        api_version: Optional[str] = None,
        litellm_params: Optional[dict] = None,
    ) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]]:
        """
        Return the client to use for a fine-tuning call.

        Args:
            api_key: Forwarded to the client constructor when not None.
            api_base: Forwarded as ``base_url`` when not None.
            timeout: Forwarded to the client constructor when not None.
            max_retries: Forwarded to the client constructor when not None.
            organization: Forwarded to the client constructor when not None.
            client: Pre-built client. When given it is returned unchanged and
                no new client is constructed.
            _is_async: When True a new ``AsyncOpenAI`` is built, otherwise a
                new ``OpenAI``. Ignored when ``client`` is given.
            api_version: Accepted because the call sites in this class always
                pass it, but not forwarded: the ``OpenAI`` / ``AsyncOpenAI``
                constructors have no such parameter and would raise
                ``TypeError``.
            litellm_params: Accepted for signature compatibility; not
                forwarded, for the same reason as ``api_version``.

        Returns:
            ``client`` if it was supplied, otherwise a newly constructed
            ``OpenAI`` or ``AsyncOpenAI`` instance.
        """
        if client is not None:
            return client

        client_kwargs = self._build_client_kwargs(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
        )
        if _is_async is True:
            return AsyncOpenAI(**client_kwargs)
        return OpenAI(**client_kwargs)  # type: ignore

    def _get_validated_client(
        self,
        *,
        _is_async: bool,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]],
    ) -> Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]:
        """
        Resolve a client via ``get_openai_client`` and validate it.

        Raises:
            ValueError: If ``get_openai_client`` returned None, or if
                ``_is_async`` is True and the resolved client is not an
                ``AsyncOpenAI`` / ``AsyncAzureOpenAI`` instance.
        """
        openai_client = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
            _is_async=_is_async,
            api_version=api_version,
        )
        if openai_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )
        if _is_async is True and not isinstance(
            openai_client, (AsyncOpenAI, AsyncAzureOpenAI)
        ):
            raise ValueError(
                "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
            )
        return openai_client

    async def acreate_fine_tuning_job(
        self,
        create_fine_tuning_job_data: dict,
        openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
    ) -> FineTuningJob:
        """Await ``fine_tuning.jobs.create`` with the given data unpacked as kwargs."""
        response = await openai_client.fine_tuning.jobs.create(
            **create_fine_tuning_job_data
        )
        return response

    def create_fine_tuning_job(
        self,
        _is_async: bool,
        create_fine_tuning_job_data: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[
            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
        ] = None,
    ) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]:
        """
        Create a fine-tuning job.

        Args:
            _is_async: When True, return the coroutine from
                ``acreate_fine_tuning_job`` instead of calling the SDK here.
            create_fine_tuning_job_data: Unpacked as keyword arguments into
                ``fine_tuning.jobs.create``.
            api_key, api_base, api_version, timeout, max_retries, organization,
            client: Passed to ``get_openai_client`` to resolve the client.

        Returns:
            The SDK response in sync mode, or an un-awaited coroutine in
            async mode.

        Raises:
            ValueError: If no client could be resolved, or ``_is_async`` is
                True and the client is not an async client.
        """
        openai_client = self._get_validated_client(
            _is_async=_is_async,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        if _is_async is True:
            return self.acreate_fine_tuning_job(  # type: ignore
                create_fine_tuning_job_data=create_fine_tuning_job_data,
                openai_client=openai_client,
            )
        verbose_logger.debug(
            "creating fine tuning job, args= %s", create_fine_tuning_job_data
        )
        response = openai_client.fine_tuning.jobs.create(**create_fine_tuning_job_data)
        return response

    async def acancel_fine_tuning_job(
        self,
        fine_tuning_job_id: str,
        openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
    ) -> FineTuningJob:
        """Await ``fine_tuning.jobs.cancel`` for the given job id."""
        response = await openai_client.fine_tuning.jobs.cancel(
            fine_tuning_job_id=fine_tuning_job_id
        )
        return response

    def cancel_fine_tuning_job(
        self,
        _is_async: bool,
        fine_tuning_job_id: str,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[
            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
        ] = None,
    ) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]:
        """
        Cancel a fine-tuning job.

        Args:
            _is_async: When True, return the coroutine from
                ``acancel_fine_tuning_job`` instead of calling the SDK here.
            fine_tuning_job_id: Passed to ``fine_tuning.jobs.cancel``.
            api_key, api_base, api_version, timeout, max_retries, organization,
            client: Passed to ``get_openai_client`` to resolve the client.

        Returns:
            The SDK response in sync mode, or an un-awaited coroutine in
            async mode.

        Raises:
            ValueError: If no client could be resolved, or ``_is_async`` is
                True and the client is not an async client.
        """
        openai_client = self._get_validated_client(
            _is_async=_is_async,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        if _is_async is True:
            return self.acancel_fine_tuning_job(  # type: ignore
                fine_tuning_job_id=fine_tuning_job_id,
                openai_client=openai_client,
            )
        verbose_logger.debug("canceling fine tuning job, args= %s", fine_tuning_job_id)
        response = openai_client.fine_tuning.jobs.cancel(
            fine_tuning_job_id=fine_tuning_job_id
        )
        return response

    async def alist_fine_tuning_jobs(
        self,
        openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
        after: Optional[str] = None,
        limit: Optional[int] = None,
    ):
        """Await ``fine_tuning.jobs.list`` with the given ``after`` / ``limit``."""
        response = await openai_client.fine_tuning.jobs.list(after=after, limit=limit)  # type: ignore
        return response

    def list_fine_tuning_jobs(
        self,
        _is_async: bool,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[
            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
        ] = None,
        after: Optional[str] = None,
        limit: Optional[int] = None,
    ):
        """
        List fine-tuning jobs.

        Args:
            _is_async: When True, return the coroutine from
                ``alist_fine_tuning_jobs`` instead of calling the SDK here.
            api_key, api_base, api_version, timeout, max_retries, organization,
            client: Passed to ``get_openai_client`` to resolve the client.
            after: Passed through to ``fine_tuning.jobs.list``.
            limit: Passed through to ``fine_tuning.jobs.list``.

        Returns:
            The SDK response in sync mode, or an un-awaited coroutine in
            async mode.

        Raises:
            ValueError: If no client could be resolved, or ``_is_async`` is
                True and the client is not an async client.
        """
        openai_client = self._get_validated_client(
            _is_async=_is_async,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        if _is_async is True:
            return self.alist_fine_tuning_jobs(  # type: ignore
                after=after,
                limit=limit,
                openai_client=openai_client,
            )
        verbose_logger.debug("list fine tuning job, after= %s, limit= %s", after, limit)
        response = openai_client.fine_tuning.jobs.list(after=after, limit=limit)  # type: ignore
        return response

    async def aretrieve_fine_tuning_job(
        self,
        fine_tuning_job_id: str,
        openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
    ) -> FineTuningJob:
        """Await ``fine_tuning.jobs.retrieve`` for the given job id."""
        response = await openai_client.fine_tuning.jobs.retrieve(
            fine_tuning_job_id=fine_tuning_job_id
        )
        return response

    def retrieve_fine_tuning_job(
        self,
        _is_async: bool,
        fine_tuning_job_id: str,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[
            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
        ] = None,
    ) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]:
        """
        Retrieve a fine-tuning job.

        Args:
            _is_async: When True, return the coroutine from
                ``aretrieve_fine_tuning_job`` instead of calling the SDK here.
            fine_tuning_job_id: Passed to ``fine_tuning.jobs.retrieve``.
            api_key, api_base, api_version, timeout, max_retries, organization,
            client: Passed to ``get_openai_client`` to resolve the client.

        Returns:
            The SDK response in sync mode, or an un-awaited coroutine in
            async mode.

        Raises:
            ValueError: If no client could be resolved, or ``_is_async`` is
                True and the client is not an async client.
        """
        openai_client = self._get_validated_client(
            _is_async=_is_async,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        if _is_async is True:
            return self.aretrieve_fine_tuning_job(  # type: ignore
                fine_tuning_job_id=fine_tuning_job_id,
                openai_client=openai_client,
            )
        verbose_logger.debug("retrieving fine tuning job, id= %s", fine_tuning_job_id)
        response = openai_client.fine_tuning.jobs.retrieve(
            fine_tuning_job_id=fine_tuning_job_id
        )
        return response