File size: 3,614 Bytes
447ebeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import httpx
import json
import pytest
import sys
from typing import Any, Dict, List
from unittest.mock import MagicMock, Mock, patch
import os
import uuid

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm
from litellm import transcription
from litellm.litellm_core_utils.get_supported_openai_params import (
    get_supported_openai_params,
)
from litellm.llms.base_llm.audio_transcription.transformation import (
    BaseAudioTranscriptionConfig,
)
from litellm.utils import ProviderConfigManager
from abc import ABC, abstractmethod

# Absolute directory of this test module (symlinks resolved by realpath);
# the audio fixture is looked up relative to it.
pwd = os.path.dirname(os.path.realpath(__file__))
print(pwd)

# Path to the WAV fixture expected to sit next to this file.
file_path = os.path.join(pwd, "gettysburg.wav")

# Binary handle opened once at import time and not closed anywhere in this
# file. Being a single shared object, its read position persists across every
# consumer that reads from it.
audio_file = open(file_path, "rb")


class BaseLLMAudioTranscriptionTest(ABC):
    """
    Shared audio transcription test suite.

    Subclasses implement the two abstract hooks below; the ``test_*`` methods
    defined here then run against whatever model/provider those hooks return.
    """

    @abstractmethod
    def get_base_audio_transcription_call_args(self) -> dict:
        """Must return the base audio transcription call args (including a "model" key)"""
        pass

    @abstractmethod
    def get_custom_llm_provider(self) -> litellm.LlmProviders:
        """Must return the custom llm provider"""
        pass

    def test_audio_transcription(self) -> None:
        """
        Test that a synchronous transcription call returns a transcript with text.
        """
        litellm.set_verbose = True
        transcription_call_args = self.get_base_audio_transcription_call_args()
        # Open a fresh handle per test instead of reusing a shared module-level
        # one: every subclass inherits this test, and a handle that has already
        # been read would hand later runs a stream positioned past the data.
        with open(file_path, "rb") as audio:
            transcript = transcription(**transcription_call_args, file=audio)
        print(f"transcript: {transcript.model_dump()}")
        print(f"transcript hidden params: {transcript._hidden_params}")

        assert transcript.text is not None

    @pytest.mark.asyncio
    async def test_audio_transcription_async(self) -> None:
        """
        Test that an asynchronous transcription call returns a transcript with text.
        """

        litellm.set_verbose = True
        litellm._turn_on_debug()
        transcription_call_args = self.get_base_audio_transcription_call_args()
        # The context manager closes the file even when the call raises, so a
        # failing provider does not leak a descriptor.
        with open(file_path, "rb") as audio:
            transcript = await litellm.atranscription(
                **transcription_call_args, file=audio
            )
        print(f"transcript: {transcript.model_dump()}")
        print(f"transcript hidden params: {transcript._hidden_params}")

        assert transcript.text is not None

    def test_audio_transcription_optional_params(self) -> None:
        """
        Test that supported OpenAI params are resolved for the transcription
        request type and are not the default chat completion params.
        """
        transcription_args = self.get_base_audio_transcription_call_args()
        model = transcription_args["model"]
        custom_llm_provider = self.get_custom_llm_provider()
        optional_params = get_supported_openai_params(
            model=model,
            custom_llm_provider=custom_llm_provider.value,
            request_type="transcription",
        )
        print(f"optional_params: {optional_params}")
        assert optional_params is not None
        assert (
            "max_completion_tokens" not in optional_params
        )  # assert default chat completion response not returned

    def test_audio_transcription_config(self) -> None:
        """
        Test that the audio transcription config is implemented and correctly instrumented.
        """
        transcription_args = self.get_base_audio_transcription_call_args()
        model = transcription_args["model"]
        custom_llm_provider = self.get_custom_llm_provider()
        config = ProviderConfigManager.get_provider_audio_transcription_config(
            model=model,
            provider=custom_llm_provider,
        )
        print(f"config: {config}")
        assert config is not None
        assert isinstance(config, BaseAudioTranscriptionConfig)