|
import uuid |
|
from typing import Any, Optional |
|
|
|
from openai import AsyncAzureOpenAI, AzureOpenAI |
|
from pydantic import BaseModel |
|
|
|
import litellm |
|
from litellm.litellm_core_utils.audio_utils.utils import get_audio_file_name |
|
from litellm.types.utils import FileTypes |
|
from litellm.utils import TranscriptionResponse, convert_to_model_response_object |
|
|
|
from .azure import ( |
|
AzureChatCompletion, |
|
get_azure_ad_token_from_oidc, |
|
select_azure_base_url_or_endpoint, |
|
) |
|
|
|
|
|
class AzureAudioTranscription(AzureChatCompletion):
    """Handler for Azure OpenAI audio transcription (Whisper deployments).

    Builds an Azure client from the supplied credentials, logs the request
    via ``logging_obj``, calls the transcription endpoint, and normalizes the
    result into a LiteLLM ``TranscriptionResponse``.
    """

    def audio_transcriptions(
        self,
        model: str,
        audio_file: FileTypes,
        optional_params: dict,
        logging_obj: Any,
        model_response: TranscriptionResponse,
        timeout: float,
        max_retries: int,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        client=None,
        azure_ad_token: Optional[str] = None,
        atranscription: bool = False,
    ) -> TranscriptionResponse:
        """Transcribe ``audio_file`` using an Azure Whisper deployment.

        Args:
            model: Azure deployment name (also sent as the ``model`` field).
            audio_file: Audio payload accepted by the OpenAI SDK.
            optional_params: Extra provider params merged into the request.
            logging_obj: LiteLLM logging object (``pre_call``/``post_call``).
            model_response: Response object to populate and return.
            timeout: Per-request timeout in seconds.
            max_retries: Retry count passed to the Azure client, if not None.
            api_key / api_base / api_version: Azure credentials/endpoint.
            client: Pre-built (Async)AzureOpenAI client to reuse, if given.
            azure_ad_token: Azure AD token; ``oidc/``-prefixed values are
                exchanged for a real token first.
            atranscription: When True, returns the coroutine from
                ``async_audio_transcriptions`` (caller is expected to await it).

        Returns:
            A ``TranscriptionResponse`` (or a coroutine producing one when
            ``atranscription`` is True — the annotation reflects the sync path).
        """
        data = {"model": model, "file": audio_file, **optional_params}

        # Base client kwargs; select_azure_base_url_or_endpoint decides whether
        # api_base is a full base_url or an azure_endpoint.
        azure_client_params = {
            "api_version": api_version,
            "azure_endpoint": api_base,
            "azure_deployment": model,
            "timeout": timeout,
        }

        azure_client_params = select_azure_base_url_or_endpoint(
            azure_client_params=azure_client_params
        )
        if api_key is not None:
            azure_client_params["api_key"] = api_key
        elif azure_ad_token is not None:
            # "oidc/..." tokens must be exchanged for an Azure AD token first.
            if azure_ad_token.startswith("oidc/"):
                azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
            azure_client_params["azure_ad_token"] = azure_ad_token

        if max_retries is not None:
            azure_client_params["max_retries"] = max_retries

        if atranscription is True:
            # Hand off to the async implementation; returns a coroutine.
            return self.async_audio_transcriptions(
                audio_file=audio_file,
                data=data,
                model_response=model_response,
                timeout=timeout,
                api_key=api_key,
                api_base=api_base,
                client=client,
                azure_client_params=azure_client_params,
                max_retries=max_retries,
                logging_obj=logging_obj,
            )

        if client is None:
            azure_client = AzureOpenAI(http_client=litellm.client_session, **azure_client_params)
        else:
            azure_client = client

        # NOTE(review): "atranscription": True is logged even on this sync
        # path — looks copy-pasted from the async branch; confirm intent
        # before changing the logged payload.
        logging_obj.pre_call(
            input=f"audio_file_{uuid.uuid4()}",
            api_key=azure_client.api_key,
            additional_args={
                "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
                "api_base": azure_client._base_url._uri_reference,
                "atranscription": True,
                "complete_input_dict": data,
            },
        )

        response = azure_client.audio.transcriptions.create(
            **data, timeout=timeout
        )

        # SDK may return a pydantic model or a bare string depending on
        # response_format; normalize both to a plain dict.
        if isinstance(response, BaseModel):
            stringified_response = response.model_dump()
        else:
            stringified_response = TranscriptionResponse(text=response).model_dump()

        logging_obj.post_call(
            input=get_audio_file_name(audio_file),
            api_key=api_key,
            additional_args={"complete_input_dict": data},
            original_response=stringified_response,
        )
        hidden_params = {"model": "whisper-1", "custom_llm_provider": "azure"}
        final_response: TranscriptionResponse = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription")
        return final_response

    async def async_audio_transcriptions(
        self,
        audio_file: FileTypes,
        data: dict,
        model_response: TranscriptionResponse,
        timeout: float,
        azure_client_params: dict,
        logging_obj: Any,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        client=None,
        max_retries=None,
    ):
        """Async counterpart of ``audio_transcriptions``.

        Expects ``data`` and ``azure_client_params`` already prepared by the
        sync entry point. Captures raw response headers so they can be
        attached to the returned ``TranscriptionResponse``.

        Raises:
            Exception: re-raises any failure after logging it via
                ``logging_obj.post_call``.
        """
        response = None
        try:
            if client is None:
                async_azure_client = AsyncAzureOpenAI(
                    **azure_client_params,
                    http_client=litellm.aclient_session,
                )
            else:
                async_azure_client = client

            logging_obj.pre_call(
                input=f"audio_file_{uuid.uuid4()}",
                api_key=async_azure_client.api_key,
                additional_args={
                    "headers": {
                        "Authorization": f"Bearer {async_azure_client.api_key}"
                    },
                    "api_base": async_azure_client._base_url._uri_reference,
                    "atranscription": True,
                    "complete_input_dict": data,
                },
            )

            # with_raw_response gives access to HTTP headers alongside the
            # parsed body.
            raw_response = (
                await async_azure_client.audio.transcriptions.with_raw_response.create(
                    **data, timeout=timeout
                )
            )

            headers = dict(raw_response.headers)
            response = raw_response.parse()

            # Normalize pydantic-model vs. bare-string responses (see sync path).
            if isinstance(response, BaseModel):
                stringified_response = response.model_dump()
            else:
                stringified_response = TranscriptionResponse(text=response).model_dump()

            logging_obj.post_call(
                input=get_audio_file_name(audio_file),
                api_key=api_key,
                additional_args={
                    "headers": {
                        "Authorization": f"Bearer {async_azure_client.api_key}"
                    },
                    "api_base": async_azure_client._base_url._uri_reference,
                    "atranscription": True,
                    "complete_input_dict": data,
                },
                original_response=stringified_response,
            )
            hidden_params = {"model": "whisper-1", "custom_llm_provider": "azure"}
            response = convert_to_model_response_object(
                _response_headers=headers,
                response_object=stringified_response,
                model_response_object=model_response,
                hidden_params=hidden_params,
                response_type="audio_transcription",
            )
            return response
        except Exception as e:
            # BUG FIX: previously passed the builtin `input` function here;
            # log the audio file name, matching the success path.
            logging_obj.post_call(
                input=get_audio_file_name(audio_file),
                api_key=api_key,
                original_response=str(e),
            )
            raise e
|
|