from typing import Any, Coroutine, Dict, Iterable, Literal, Optional, Union

import httpx
from openai import AsyncAzureOpenAI, AzureOpenAI
from typing_extensions import overload

from ...types.llms.openai import (
    Assistant,
    AssistantEventHandler,
    AssistantStreamManager,
    AssistantToolParam,
    AsyncAssistantEventHandler,
    AsyncAssistantStreamManager,
    AsyncCursorPage,
    OpenAICreateThreadParamsMessage,
    OpenAIMessage,
    Run,
    SyncCursorPage,
    Thread,
)

from .common_utils import BaseAzureLLM


class AzureAssistantsAPI(BaseAzureLLM):
    def __init__(self) -> None:
        super().__init__()

    def get_azure_client(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[AzureOpenAI] = None,
        litellm_params: Optional[dict] = None,
    ) -> AzureOpenAI:
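        """
        Return the sync Azure OpenAI client to use for this request.

        If an existing `AzureOpenAI` client is passed in, it is reused as-is;
        otherwise a new client is built from the params resolved by
        `initialize_azure_sdk_client`.
        """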
        if client is None:
            azure_client_params = self.initialize_azure_sdk_client(
                litellm_params=litellm_params or {},
                api_key=api_key,
                api_base=api_base,
                model_name="",
                api_version=api_version,
                is_async=False,
            )
            azure_openai_client = AzureOpenAI(**azure_client_params)  # type: ignore
        else:
            azure_openai_client = client

        return azure_openai_client

    def async_get_azure_client(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[AsyncAzureOpenAI] = None,
        litellm_params: Optional[dict] = None,
    ) -> AsyncAzureOpenAI:
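        """
        Return the async Azure OpenAI client to use for this request.

        Despite the name this is a regular (non-coroutine) method; it only
        constructs or reuses an `AsyncAzureOpenAI` client.
        """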
        if client is None:
            azure_client_params = self.initialize_azure_sdk_client(
                litellm_params=litellm_params or {},
                api_key=api_key,
                api_base=api_base,
                model_name="",
                api_version=api_version,
                is_async=True,
            )
            azure_openai_client = AsyncAzureOpenAI(**azure_client_params)
            # azure_openai_client = AsyncAzureOpenAI(**data)  # type: ignore
        else:
            azure_openai_client = client

        return azure_openai_client

    ### ASSISTANTS ###

    async def async_get_assistants(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[AsyncAzureOpenAI],
        litellm_params: Optional[dict] = None,
    ) -> AsyncCursorPage[Assistant]:
        azure_openai_client = self.async_get_azure_client(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            litellm_params=litellm_params,
        )

        response = await azure_openai_client.beta.assistants.list()

        return response

    # fmt: off

    @overload
    def get_assistants(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[AsyncAzureOpenAI],
        aget_assistants: Literal[True],
    ) -> Coroutine[None, None, AsyncCursorPage[Assistant]]:
        ...

    @overload
    def get_assistants(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[AzureOpenAI],
        aget_assistants: Optional[Literal[False]],
    ) -> SyncCursorPage[Assistant]:
        ...

    # fmt: on

    def get_assistants(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client=None,
        aget_assistants=None,
        litellm_params: Optional[dict] = None,
    ):
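        """
        List assistants for the configured Azure OpenAI resource.

        If `aget_assistants=True`, returns a coroutine that resolves to an
        `AsyncCursorPage[Assistant]`; otherwise runs synchronously and returns
        a `SyncCursorPage[Assistant]`.
        """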
        if aget_assistants is not None and aget_assistants is True:
            return self.async_get_assistants(
                api_key=api_key,
                api_base=api_base,
                api_version=api_version,
                azure_ad_token=azure_ad_token,
                timeout=timeout,
                max_retries=max_retries,
                client=client,
                litellm_params=litellm_params,
            )
        azure_openai_client = self.get_azure_client(
            api_key=api_key,
            api_base=api_base,
            azure_ad_token=azure_ad_token,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            api_version=api_version,
            litellm_params=litellm_params,
        )

        response = azure_openai_client.beta.assistants.list()

        return response

    ### MESSAGES ###

    async def a_add_message(
        self,
        thread_id: str,
        message_data: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[AsyncAzureOpenAI] = None,
        litellm_params: Optional[dict] = None,
    ) -> OpenAIMessage:
        openai_client = self.async_get_azure_client(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            litellm_params=litellm_params,
        )

        thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create(  # type: ignore
            thread_id, **message_data  # type: ignore
        )

        response_obj: Optional[OpenAIMessage] = None
        if getattr(thread_message, "status", None) is None:
            thread_message.status = "completed"
            response_obj = OpenAIMessage(**thread_message.dict())
        else:
            response_obj = OpenAIMessage(**thread_message.dict())
        return response_obj

    # fmt: off

    @overload
    def add_message(
        self,
        thread_id: str,
        message_data: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[AsyncAzureOpenAI],
        a_add_message: Literal[True],
        litellm_params: Optional[dict] = None,
    ) -> Coroutine[None, None, OpenAIMessage]:
        ...

    @overload
    def add_message(
        self,
        thread_id: str,
        message_data: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[AzureOpenAI],
        a_add_message: Optional[Literal[False]],
        litellm_params: Optional[dict] = None,
    ) -> OpenAIMessage:
        ...

    # fmt: on

    def add_message(
        self,
        thread_id: str,
        message_data: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client=None,
        a_add_message: Optional[bool] = None,
        litellm_params: Optional[dict] = None,
    ):
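        """
        Add a message to an existing thread.

        If `a_add_message=True`, returns a coroutine that resolves to an
        `OpenAIMessage`; otherwise the message is created synchronously.
        Messages returned without a status are normalized to "completed".
        """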
        if a_add_message is not None and a_add_message is True:
            return self.a_add_message(
                thread_id=thread_id,
                message_data=message_data,
                api_key=api_key,
                api_base=api_base,
                api_version=api_version,
                azure_ad_token=azure_ad_token,
                timeout=timeout,
                max_retries=max_retries,
                client=client,
                litellm_params=litellm_params,
            )
        openai_client = self.get_azure_client(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            litellm_params=litellm_params,
        )

        thread_message: OpenAIMessage = openai_client.beta.threads.messages.create(  # type: ignore
            thread_id, **message_data  # type: ignore
        )

        response_obj: Optional[OpenAIMessage] = None
        if getattr(thread_message, "status", None) is None:
            thread_message.status = "completed"
            response_obj = OpenAIMessage(**thread_message.dict())
        else:
            response_obj = OpenAIMessage(**thread_message.dict())
        return response_obj

    async def async_get_messages(
        self,
        thread_id: str,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[AsyncAzureOpenAI] = None,
        litellm_params: Optional[dict] = None,
    ) -> AsyncCursorPage[OpenAIMessage]:
        openai_client = self.async_get_azure_client(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            litellm_params=litellm_params,
        )

        response = await openai_client.beta.threads.messages.list(thread_id=thread_id)

        return response

    # fmt: off

    @overload
    def get_messages(
        self,
        thread_id: str,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[AsyncAzureOpenAI],
        aget_messages: Literal[True],
        litellm_params: Optional[dict] = None,
    ) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]:
        ...

    @overload
    def get_messages(
        self,
        thread_id: str,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[AzureOpenAI],
        aget_messages: Optional[Literal[False]],
        litellm_params: Optional[dict] = None,
    ) -> SyncCursorPage[OpenAIMessage]:
        ...

    # fmt: on

    def get_messages(
        self,
        thread_id: str,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client=None,
        aget_messages=None,
        litellm_params: Optional[dict] = None,
    ):
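        """
        List the messages on a thread.

        If `aget_messages=True`, returns a coroutine resolving to an
        `AsyncCursorPage[OpenAIMessage]`; otherwise runs synchronously.
        """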
        if aget_messages is not None and aget_messages is True:
            return self.async_get_messages(
                thread_id=thread_id,
                api_key=api_key,
                api_base=api_base,
                api_version=api_version,
                azure_ad_token=azure_ad_token,
                timeout=timeout,
                max_retries=max_retries,
                client=client,
                litellm_params=litellm_params,
            )
        openai_client = self.get_azure_client(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            litellm_params=litellm_params,
        )

        response = openai_client.beta.threads.messages.list(thread_id=thread_id)

        return response

    ### THREADS ###

    async def async_create_thread(
        self,
        metadata: Optional[dict],
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[AsyncAzureOpenAI],
        messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
        litellm_params: Optional[dict] = None,
    ) -> Thread:
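        """
        Create a new thread, optionally seeded with initial messages and
        metadata, and return it as a `Thread` object.
        """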
        openai_client = self.async_get_azure_client(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            litellm_params=litellm_params,
        )

        data = {}
        if messages is not None:
            data["messages"] = messages  # type: ignore
        if metadata is not None:
            data["metadata"] = metadata  # type: ignore

        message_thread = await openai_client.beta.threads.create(**data)  # type: ignore

        return Thread(**message_thread.dict())

    # fmt: off

    @overload
    def create_thread(
        self,
        metadata: Optional[dict],
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
        client: Optional[AsyncAzureOpenAI],
        acreate_thread: Literal[True],
        litellm_params: Optional[dict] = None,
    ) -> Coroutine[None, None, Thread]:
        ...

    @overload
    def create_thread(
        self,
        metadata: Optional[dict],
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
        client: Optional[AzureOpenAI],
        acreate_thread: Optional[Literal[False]],
        litellm_params: Optional[dict] = None,
    ) -> Thread:
        ...

    # fmt: on

    def create_thread(
        self,
        metadata: Optional[dict],
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
        client=None,
        acreate_thread=None,
        litellm_params: Optional[dict] = None,
    ):
        """
        Create a thread, optionally seeded with initial messages and metadata.

        If `acreate_thread=True`, returns a coroutine resolving to a `Thread`;
        otherwise the thread is created synchronously.

        Here's an example (Azure connection params such as api_key / api_base
        omitted for brevity):
        ```
        # create thread
        message = {"role": "user", "content": "Hey, how's it going?"}
        azure_assistants_api.create_thread(messages=[message], ...)
        ```
        """
        if acreate_thread is not None and acreate_thread is True:
            return self.async_create_thread(
                metadata=metadata,
                api_key=api_key,
                api_base=api_base,
                api_version=api_version,
                azure_ad_token=azure_ad_token,
                timeout=timeout,
                max_retries=max_retries,
                client=client,
                messages=messages,
                litellm_params=litellm_params,
            )
        azure_openai_client = self.get_azure_client(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            litellm_params=litellm_params,
        )

        data = {}
        if messages is not None:
            data["messages"] = messages  # type: ignore
        if metadata is not None:
            data["metadata"] = metadata  # type: ignore

        message_thread = azure_openai_client.beta.threads.create(**data)  # type: ignore

        return Thread(**message_thread.dict())

    async def async_get_thread(
        self,
        thread_id: str,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[AsyncAzureOpenAI],
        litellm_params: Optional[dict] = None,
    ) -> Thread:
        openai_client = self.async_get_azure_client(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            litellm_params=litellm_params,
        )

        response = await openai_client.beta.threads.retrieve(thread_id=thread_id)

        return Thread(**response.dict())

    # fmt: off

    @overload
    def get_thread(
        self,
        thread_id: str,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[AsyncAzureOpenAI],
        aget_thread: Literal[True],
        litellm_params: Optional[dict] = None,
    ) -> Coroutine[None, None, Thread]:
        ...

    @overload
    def get_thread(
        self,
        thread_id: str,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[AzureOpenAI],
        aget_thread: Optional[Literal[False]],
        litellm_params: Optional[dict] = None,
    ) -> Thread:
        ...

    # fmt: on

    def get_thread(
        self,
        thread_id: str,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client=None,
        aget_thread=None,
        litellm_params: Optional[dict] = None,
    ):
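        """
        Retrieve a thread by id.

        If `aget_thread=True`, returns a coroutine resolving to a `Thread`;
        otherwise the thread is fetched synchronously.
        """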
        if aget_thread is not None and aget_thread is True:
            return self.async_get_thread(
                thread_id=thread_id,
                api_key=api_key,
                api_base=api_base,
                api_version=api_version,
                azure_ad_token=azure_ad_token,
                timeout=timeout,
                max_retries=max_retries,
                client=client,
                litellm_params=litellm_params,
            )
        openai_client = self.get_azure_client(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            litellm_params=litellm_params,
        )

        response = openai_client.beta.threads.retrieve(thread_id=thread_id)

        return Thread(**response.dict())

    # def delete_thread(self):
    #     pass

    ### RUNS ###

    async def arun_thread(
        self,
        thread_id: str,
        assistant_id: str,
        additional_instructions: Optional[str],
        instructions: Optional[str],
        metadata: Optional[Dict],
        model: Optional[str],
        stream: Optional[bool],
        tools: Optional[Iterable[AssistantToolParam]],
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[AsyncAzureOpenAI],
        litellm_params: Optional[dict] = None,
    ) -> Run:
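        """
        Run the assistant on a thread and poll until the run reaches a
        terminal state, returning the final `Run` object.
        """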
        openai_client = self.async_get_azure_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            client=client,
            litellm_params=litellm_params,
        )

        response = await openai_client.beta.threads.runs.create_and_poll(  # type: ignore
            thread_id=thread_id,
            assistant_id=assistant_id,
            additional_instructions=additional_instructions,
            instructions=instructions,
            metadata=metadata,  # type: ignore
            model=model,
            tools=tools,
        )

        return response

    def async_run_thread_stream(
        self,
        client: AsyncAzureOpenAI,
        thread_id: str,
        assistant_id: str,
        additional_instructions: Optional[str],
        instructions: Optional[str],
        metadata: Optional[Dict],
        model: Optional[str],
        tools: Optional[Iterable[AssistantToolParam]],
        event_handler: Optional[AssistantEventHandler],
        litellm_params: Optional[dict] = None,
    ) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
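        """
        Start a streaming run against the given async client and return the
        stream manager; events are consumed by the optional `event_handler`.
        """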
        data: Dict[str, Any] = {
            "thread_id": thread_id,
            "assistant_id": assistant_id,
            "additional_instructions": additional_instructions,
            "instructions": instructions,
            "metadata": metadata,
            "model": model,
            "tools": tools,
        }
        if event_handler is not None:
            data["event_handler"] = event_handler
        return client.beta.threads.runs.stream(**data)  # type: ignore

    def run_thread_stream(
        self,
        client: AzureOpenAI,
        thread_id: str,
        assistant_id: str,
        additional_instructions: Optional[str],
        instructions: Optional[str],
        metadata: Optional[Dict],
        model: Optional[str],
        tools: Optional[Iterable[AssistantToolParam]],
        event_handler: Optional[AssistantEventHandler],
        litellm_params: Optional[dict] = None,
    ) -> AssistantStreamManager[AssistantEventHandler]:
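        """
        Synchronous counterpart of `async_run_thread_stream`: start a
        streaming run on the given sync client and return the stream manager.
        """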
        data: Dict[str, Any] = {
            "thread_id": thread_id,
            "assistant_id": assistant_id,
            "additional_instructions": additional_instructions,
            "instructions": instructions,
            "metadata": metadata,
            "model": model,
            "tools": tools,
        }
        if event_handler is not None:
            data["event_handler"] = event_handler
        return client.beta.threads.runs.stream(**data)  # type: ignore

    # fmt: off

    @overload
    def run_thread(
        self,
        thread_id: str,
        assistant_id: str,
        additional_instructions: Optional[str],
        instructions: Optional[str],
        metadata: Optional[Dict],
        model: Optional[str],
        stream: Optional[bool],
        tools: Optional[Iterable[AssistantToolParam]],
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[AsyncAzureOpenAI],
        arun_thread: Literal[True],
    ) -> Coroutine[None, None, Run]:
        ...

    @overload
    def run_thread(
        self,
        thread_id: str,
        assistant_id: str,
        additional_instructions: Optional[str],
        instructions: Optional[str],
        metadata: Optional[Dict],
        model: Optional[str],
        stream: Optional[bool],
        tools: Optional[Iterable[AssistantToolParam]],
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[AzureOpenAI],
        arun_thread: Optional[Literal[False]],
    ) -> Run:
        ...

    # fmt: on

    def run_thread(
        self,
        thread_id: str,
        assistant_id: str,
        additional_instructions: Optional[str],
        instructions: Optional[str],
        metadata: Optional[Dict],
        model: Optional[str],
        stream: Optional[bool],
        tools: Optional[Iterable[AssistantToolParam]],
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client=None,
        arun_thread=None,
        event_handler: Optional[AssistantEventHandler] = None,
        litellm_params: Optional[dict] = None,
    ):
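        """
        Run an assistant on a thread.

        Dispatch rules:
        - `arun_thread=True` and `stream=True`: returns an async stream manager.
        - `arun_thread=True` and `stream` falsy: returns a coroutine that
          creates the run and polls it to completion.
        - otherwise: same behavior, but synchronous.
        """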
        if arun_thread is not None and arun_thread is True:
            if stream is not None and stream is True:
                azure_client = self.async_get_azure_client(
                    api_key=api_key,
                    api_base=api_base,
                    api_version=api_version,
                    azure_ad_token=azure_ad_token,
                    timeout=timeout,
                    max_retries=max_retries,
                    client=client,
                    litellm_params=litellm_params,
                )
                return self.async_run_thread_stream(
                    client=azure_client,
                    thread_id=thread_id,
                    assistant_id=assistant_id,
                    additional_instructions=additional_instructions,
                    instructions=instructions,
                    metadata=metadata,
                    model=model,
                    tools=tools,
                    event_handler=event_handler,
                    litellm_params=litellm_params,
                )
            return self.arun_thread(
                thread_id=thread_id,
                assistant_id=assistant_id,
                additional_instructions=additional_instructions,
                instructions=instructions,
                metadata=metadata,  # type: ignore
                model=model,
                stream=stream,
                tools=tools,
                api_key=api_key,
                api_base=api_base,
                api_version=api_version,
                azure_ad_token=azure_ad_token,
                timeout=timeout,
                max_retries=max_retries,
                client=client,
                litellm_params=litellm_params,
            )

        openai_client = self.get_azure_client(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            litellm_params=litellm_params,
        )

        if stream is not None and stream is True:
            return self.run_thread_stream(
                client=openai_client,
                thread_id=thread_id,
                assistant_id=assistant_id,
                additional_instructions=additional_instructions,
                instructions=instructions,
                metadata=metadata,
                model=model,
                tools=tools,
                event_handler=event_handler,
                litellm_params=litellm_params,
            )

        response = openai_client.beta.threads.runs.create_and_poll(  # type: ignore
            thread_id=thread_id,
            assistant_id=assistant_id,
            additional_instructions=additional_instructions,
            instructions=instructions,
            metadata=metadata,  # type: ignore
            model=model,
            tools=tools,
        )

        return response

    # Create Assistant
    async def async_create_assistants(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[AsyncAzureOpenAI],
        create_assistant_data: dict,
        litellm_params: Optional[dict] = None,
    ) -> Assistant:
        azure_openai_client = self.async_get_azure_client(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            litellm_params=litellm_params,
        )

        response = await azure_openai_client.beta.assistants.create(
            **create_assistant_data
        )
        return response

    def create_assistants(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        create_assistant_data: dict,
        client=None,
        async_create_assistants=None,
        litellm_params: Optional[dict] = None,
    ):
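        """
        Create an assistant from `create_assistant_data` (passed through to
        the Azure OpenAI SDK). Set `async_create_assistants=True` to get a
        coroutine instead of a synchronous result.
        """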
        if async_create_assistants is not None and async_create_assistants is True:
            return self.async_create_assistants(
                api_key=api_key,
                api_base=api_base,
                api_version=api_version,
                azure_ad_token=azure_ad_token,
                timeout=timeout,
                max_retries=max_retries,
                client=client,
                create_assistant_data=create_assistant_data,
                litellm_params=litellm_params,
            )
        azure_openai_client = self.get_azure_client(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            litellm_params=litellm_params,
        )

        response = azure_openai_client.beta.assistants.create(**create_assistant_data)
        return response

    # Delete Assistant
    async def async_delete_assistant(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[AsyncAzureOpenAI],
        assistant_id: str,
        litellm_params: Optional[dict] = None,
    ):
        azure_openai_client = self.async_get_azure_client(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            litellm_params=litellm_params,
        )

        response = await azure_openai_client.beta.assistants.delete(
            assistant_id=assistant_id
        )
        return response

    def delete_assistant(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        assistant_id: str,
        async_delete_assistants: Optional[bool] = None,
        client=None,
        litellm_params: Optional[dict] = None,
    ):
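        """
        Delete an assistant by id. Set `async_delete_assistants=True` to get
        a coroutine instead of a synchronous result.
        """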
        if async_delete_assistants is not None and async_delete_assistants is True:
            return self.async_delete_assistant(
                api_key=api_key,
                api_base=api_base,
                api_version=api_version,
                azure_ad_token=azure_ad_token,
                timeout=timeout,
                max_retries=max_retries,
                client=client,
                assistant_id=assistant_id,
                litellm_params=litellm_params,
            )
        azure_openai_client = self.get_azure_client(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            litellm_params=litellm_params,
        )

        response = azure_openai_client.beta.assistants.delete(
            assistant_id=assistant_id
        )
        return response